# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from logging import Logger
from queue import Queue
from threading import Event, Lock, Thread
from typing import Any, Callable, Dict, Tuple

from ax.core.types import TEvaluationOutcome, TParameterization

from ax.exceptions.core import DataRequiredError
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.service.ax_client import AxClient
from ax.utils.common.logger import get_logger

logger: Logger = get_logger(__name__)

# Polling interval, in seconds, for the `ax_client_candidate_generator` and
# `ax_client_data_attacher` worker loops. It is also the back-off used before
# retrying candidate generation after a recoverable Ax error.
IDLE_SLEEP_SEC: float = 0.1


def interactive_optimize(
    num_trials: int,
    candidate_queue_maxsize: int,
    # Callable[[Queue[Tuple[TParameterization, int]], Event, int, ...], None]
    candidate_generator_function: Callable[..., None],
    candidate_generator_kwargs: Dict[str, Any],
    # Callable[[Queue[Tuple[int, TEvaluationOutcome]], Event, ...], None]
    data_attacher_function: Callable[..., None],
    data_attacher_kwargs: Dict[str, Any],
    elicitation_function: Callable[[TParameterization], TEvaluationOutcome],
) -> None:
    """
    Run an Ax experiment with candidate pregeneration, i.e. the generation of
    candidate points while waiting for trial evaluation.

    This can be useful in many contexts, especially in interactive experiments
    in which trial evaluation entails eliciting a response from a human.
    Candidate pregeneration uses the time spent waiting for trial evaluation to
    generate new candidates from the data available. Note that this is a
    tradeoff -- a larger pregeneration queue will result in more trials being
    generated with less data compared to a smaller pregeneration queue (or no
    pregeneration as in conventional Ax usage) and should only be used in
    contexts where it is necessary for the user to not experience any "lag"
    while candidates are being generated.

    Two worker threads are started, one running `candidate_generator_function`
    and one running `data_attacher_function`. The calling thread takes exactly
    `num_trials` candidates off the candidate queue, evaluates each with
    `elicitation_function`, and puts the outcome onto the data queue. It then
    waits until every outcome has been marked done before stopping the
    attacher.

    Args:
        num_trials: The total number of trials to be run.
        candidate_queue_maxsize: The maximum number of candidates to
            pregenerate (used as the `maxsize` of the candidate queue).
        candidate_generator_function: A function called as
            `f(candidate_queue, stop_event, num_trials,
            **candidate_generator_kwargs)` that enqueues
            `(parameterization, trial_index)` candidates (generated by any
            means) and is expected to exit once `stop_event` is set. It must
            enqueue at least `num_trials` candidates; otherwise this function
            blocks forever waiting for one. See `ax_client_candidate_generator`
            for an example.
        candidate_generator_kwargs: kwargs to be passed into
            `candidate_generator_function` when it is spawned as a thread.
        data_attacher_function: A function called as
            `f(data_queue, stop_event, **data_attacher_kwargs)` that takes
            `(trial_index, outcome)` pairs off the queue, attaches them
            (e.g. to Ax), and calls `task_done()` for each one. It is expected
            to exit once `stop_event` is set. See `ax_client_data_attacher`
            for an example.
        data_attacher_kwargs: kwargs to be passed into `data_attacher_function`
            when it is spawned as a thread.
        elicitation_function: Function from parameterization (as returned by
            `AxClient.get_next_trial`) to outcome (as expected by
            `AxClient.complete_trial`).

    Raises:
        Any exception raised by `elicitation_function` (or while blocked on a
        queue, e.g. `KeyboardInterrupt`) propagates to the caller after both
        worker threads have been signalled to stop.
    """
    # Bounded buffer of pregenerated candidates; its maxsize limits how far
    # candidate generation may run ahead of evaluation.
    candidate_queue: "Queue[Tuple[TParameterization, int]]" = Queue(
        maxsize=candidate_queue_maxsize
    )
    # Unbounded buffer of evaluated outcomes waiting to be attached.
    data_queue: "Queue[Tuple[int, TEvaluationOutcome]]" = Queue()

    # Events used to ask the generator and attacher threads to stop.
    candidate_generator_stop_event = Event()
    data_attacher_stop_event = Event()

    # Threads that run candidate pregeneration and data attaching respectively.
    candidate_generator_thread = Thread(
        target=candidate_generator_function,
        args=(candidate_queue, candidate_generator_stop_event, num_trials),
        kwargs=candidate_generator_kwargs,
    )
    data_attacher_thread = Thread(
        target=data_attacher_function,
        args=(data_queue, data_attacher_stop_event),
        kwargs=data_attacher_kwargs,
    )

    candidate_generator_thread.start()
    data_attacher_thread.start()

    try:
        for _ in range(num_trials):
            parameterization, trial_index = candidate_queue.get()
            raw_data = elicitation_function(parameterization)
            data_queue.put((trial_index, raw_data))
            candidate_queue.task_done()

        # Every candidate has been consumed: stop generating, then wait until
        # the attacher has called task_done() for every queued outcome.
        candidate_generator_stop_event.set()
        data_queue.join()
    finally:
        # Always signal both workers, including when elicitation raises.
        # Otherwise they would keep polling forever, and as non-daemon threads
        # (when started from the main thread) they would block interpreter exit.
        # Setting an Event twice is harmless.
        candidate_generator_stop_event.set()
        data_attacher_stop_event.set()
def interactive_optimize_with_client(
    ax_client: AxClient,
    num_trials: int,
    candidate_queue_maxsize: int,
    elicitation_function: Callable[[TParameterization], TEvaluationOutcome],
) -> None:
    """Run `interactive_optimize` with workers that drive a single AxClient.

    Nothing is returned; read the experiment's results from `ax_client` once
    this call finishes.

    Three threads cooperate:
      * a generator thread acquires a lock on the AxClient, generates a
        candidate, and enqueues it onto the candidate queue;
      * an attacher thread acquires the same lock, drains the outcome queue,
        and attaches every outcome to the AxClient;
      * the calling thread pops a candidate off the candidate queue, elicits a
        response through `elicitation_function`, and puts the response onto
        the outcome queue.
    """
    # A single lock shared by both workers, so that at most one thread
    # accesses the AxClient at any moment.
    client_lock = Lock()

    def _worker_kwargs() -> Dict[str, Any]:
        # Build a fresh kwargs mapping for each worker thread.
        return {"ax_client": ax_client, "lock": client_lock}

    interactive_optimize(
        num_trials=num_trials,
        candidate_queue_maxsize=candidate_queue_maxsize,
        candidate_generator_function=ax_client_candidate_generator,
        candidate_generator_kwargs=_worker_kwargs(),
        data_attacher_function=ax_client_data_attacher,
        data_attacher_kwargs=_worker_kwargs(),
        elicitation_function=elicitation_function,
    )
def ax_client_candidate_generator(
    queue: "Queue[Tuple[TParameterization, int]]",
    stop_event: Event,
    num_trials: int,
    ax_client: AxClient,
    lock: Lock,
) -> None:
    """Thread-safe loop that generates trials from the AxClient and enqueues
    them onto the candidate queue.

    The number of candidates pregenerated is bounded by the maximum size of
    `queue`. Generation stops once this call has enqueued `num_trials`
    candidates, which is exactly the number `interactive_optimize` consumes.
    It also stops earlier if another thread sets `stop_event`. Candidates are
    counted locally rather than via the experiment's arms, so trials already
    on a reused or resumed experiment cannot end generation early and leave
    the consumer waiting forever. `stop_event` is set before returning.

    Args:
        queue: Candidate queue receiving `(parameterization, trial_index)`
            pairs as returned by `AxClient.get_next_trial`.
        stop_event: Event that, when set, ends the loop.
        num_trials: Number of candidates to generate.
        ax_client: Client used to generate trials.
        lock: Lock guarding all access to `ax_client`.
    """
    num_generated = 0
    while num_generated < num_trials and not stop_event.is_set():
        # When driven by `interactive_optimize`, this function is the only
        # producer for `queue`, so if it is not full here the `put` below does
        # not block while holding the lock.
        if not queue.full():
            with lock:
                try:
                    parameterization_with_trial_index = ax_client.get_next_trial()
                except (MaxParallelismReachedException, DataRequiredError) as e:
                    # Recoverable: Ax needs more completed data before it can
                    # generate again, so back off and retry on a later pass.
                    logger.warning(
                        f"Encountered error {e}, sleeping for {IDLE_SLEEP_SEC} "
                        "seconds and trying again."
                    )
                else:
                    queue.put(parameterization_with_trial_index)
                    num_generated += 1
        time.sleep(IDLE_SLEEP_SEC)
    # Signal completion, mirroring the original behaviour of setting the
    # event once enough candidates exist.
    stop_event.set()
def ax_client_data_attacher(
    queue: "Queue[Tuple[int, TEvaluationOutcome]]",
    stop_event: Event,
    ax_client: AxClient,
    lock: Lock,
) -> None:
    """Thread-safe loop that moves evaluation outcomes from the outcome queue
    into the AxClient.

    Whenever the queue has items, the AxClient's lock is taken once and the
    whole queue is drained under it. Each `(trial_index, raw_data)` pair is
    passed to `AxClient.complete_trial` and marked done with `task_done()`.
    The loop polls every `IDLE_SLEEP_SEC` seconds until `stop_event` is set.

    Args:
        queue: Outcome queue of `(trial_index, raw_data)` pairs.
        stop_event: Event that, when set, ends the loop.
        ax_client: Client the outcomes are attached to.
        lock: Lock guarding all access to `ax_client`.
    """
    while not stop_event.is_set():
        if queue.empty():
            # Nothing to attach yet; wait for the next poll.
            time.sleep(IDLE_SLEEP_SEC)
            continue
        with lock:
            # Drain everything currently queued in one lock acquisition.
            while not queue.empty():
                index, outcome = queue.get()
                ax_client.complete_trial(trial_index=index, raw_data=outcome)
                queue.task_done()
        time.sleep(IDLE_SLEEP_SEC)