# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import time
from collections.abc import Callable
from logging import Logger
from queue import Queue
from threading import Event, Lock, Thread
from typing import Any
from ax.core.types import TEvaluationOutcome, TParameterization
from ax.exceptions.core import DataRequiredError
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.service.ax_client import AxClient
from ax.utils.common.logger import get_logger
logger: Logger = get_logger(__name__)
IDLE_SLEEP_SEC = 0.1
def interactive_optimize(
num_trials: int,
candidate_queue_maxsize: int,
candidate_generator_function: Callable[..., None],
data_attacher_function: Callable[..., None],
# pyre-ignore[2]: Missing parameter annotation
elicitation_function: Callable[..., Any],
candidate_generator_kwargs: dict[str, Any] | None = None,
data_attacher_kwargs: dict[str, Any] | None = None,
elicitation_function_kwargs: dict[str, Any] | None = None,
) -> bool:
"""
Function to facilitate running Ax experiments with candidate pregeneration (the
generation of candidate points while waiting for trial evaluation). This can be
useful in many contexts, especially in interactive experiments in which trial
evaluation entails eliciting a response from a human. Candidate pregeneration
    uses the time spent waiting for trial evaluation to generate new candidates from
    the data available. Note that this is a tradeoff: a larger pregeneration queue
    means more trials are generated from less data than with a smaller queue (or no
    pregeneration, as in conventional Ax usage). Pregeneration should therefore only
    be used in contexts where it is important that the user not experience any
    "lag" while candidates are being generated.
Args:
num_trials: The total number of trials to be run.
candidate_queue_maxsize: The maximum number of candidates to pregenerate.
candidate_generator_function: A function taking in a queue and event that
enqueues candidates (generated by any means). See
`ax_client_candidate_generator` for an example.
data_attacher_function: A function taking in a queue and event that attaches
observations to Ax. See `ax_client_data_attacher` for an example.
elicitation_function: Function from parameterization (as returned by
`AxClient.get_next_trial`) to outcome (as expected by
            `AxClient.complete_trial`). If the function returns None, the user has
            aborted elicitation and the optimization is stopped.
candidate_generator_kwargs: kwargs to be passed into
`candidate_generator_function` when it is spawned as a thread.
data_attacher_kwargs: kwargs to be passed into `data_attacher_function` when
it is spawned as a thread.
elicitation_function_kwargs: kwargs to be passed into `elicitation_function`
Returns:
True if optimization was completed and False if aborted.
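
    Example:
        A minimal sketch of the thread-function signatures expected here (the
        names and bodies below are illustrative placeholders; see
        `ax_client_candidate_generator` and `ax_client_data_attacher` for
        working implementations)::

            def my_candidate_generator(queue, stop_event, num_trials, **kwargs):
                # Enqueue candidate items until stop_event is set; enqueue None
                # to signal an unrecoverable failure and abort the optimization.
                ...

            def my_data_attacher(queue, stop_event, **kwargs):
                # Drain outcome items from the queue, attach them to the
                # experiment, and call queue.task_done() for each item.
                ...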
"""
optimization_completed = True
# Construct queues to buffer arbitrary inputs and outputs
candidate_queue = Queue(maxsize=candidate_queue_maxsize)
data_queue = Queue()
# Construct events to allow us to stop the generator and attacher threads
candidate_generator_stop_event = Event()
data_attacher_stop_event = Event()
    # Construct threads to run thread-safe candidate pregeneration and thread-safe
    # data attaching, respectively
candidate_generator_thread = Thread(
target=candidate_generator_function,
args=(
candidate_queue,
candidate_generator_stop_event,
num_trials,
),
kwargs=(candidate_generator_kwargs or {}),
)
data_attacher_thread = Thread(
target=data_attacher_function,
args=(
data_queue,
data_attacher_stop_event,
),
kwargs=(data_attacher_kwargs or {}),
)
candidate_generator_thread.start()
data_attacher_thread.start()
for _i in range(num_trials):
candidate_item = candidate_queue.get()
if candidate_item is None:
            # A None item means the candidate generator has failed and stopped;
            # abort the optimization
optimization_completed = False
break
response = elicitation_function(
candidate_item, **(elicitation_function_kwargs or {})
)
candidate_queue.task_done()
if response is not None:
data_queue.put(response)
else:
            # If response is None, the user has stopped the elicitation;
            # abort the optimization
optimization_completed = False
break
# Clean up threads (if they have not been stopped already)
candidate_generator_stop_event.set()
data_queue.join()
data_attacher_stop_event.set()
return optimization_completed
def interactive_optimize_with_client(
ax_client: AxClient,
num_trials: int,
candidate_queue_maxsize: int,
    elicitation_function: Callable[
        [tuple[TParameterization, int]], tuple[int, TEvaluationOutcome] | None
    ],
) -> bool:
"""
    Implementation of `interactive_optimize` using the AxClient. Results of the
    experiment can be extracted from the AxClient passed in.
The basic structure is as follows: One thread tries for a lock on the AxClient,
generates a candidate, and enqueues it to a candidate queue. Another thread tries
for the same lock, takes all the trial outcomes in the outcome queue, and attaches
them to the AxClient. The main thread pops a candidate off the candidate queue,
    elicits a response from the user, and puts the response onto the outcome queue.
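
    Example:
        An illustrative end-to-end sketch. The experiment setup, the metric name
        "rating", and the console-based elicitation function below are assumptions
        made for this example rather than part of this module::

            from ax.service.ax_client import AxClient, ObjectiveProperties

            def elicit(parameterization_with_index):
                parameterization, trial_index = parameterization_with_index
                answer = input(f"Rate {parameterization} (blank to stop): ")
                if answer == "":
                    return None  # user aborted; optimization stops
                return trial_index, {"rating": (float(answer), None)}

            ax_client = AxClient()
            ax_client.create_experiment(
                name="interactive_demo",
                parameters=[
                    {"name": "x", "type": "range", "bounds": [0.0, 1.0]},
                    {"name": "y", "type": "range", "bounds": [0.0, 1.0]},
                ],
                objectives={"rating": ObjectiveProperties(minimize=False)},
            )
            interactive_optimize_with_client(
                ax_client=ax_client,
                num_trials=10,
                candidate_queue_maxsize=3,
                elicitation_function=elicit,
            )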
"""
# Construct a lock to ensure only one thread my access the AxClient at any moment
ax_client_lock = Lock()
return interactive_optimize(
num_trials=num_trials,
candidate_queue_maxsize=candidate_queue_maxsize,
candidate_generator_function=ax_client_candidate_generator,
candidate_generator_kwargs={"ax_client": ax_client, "lock": ax_client_lock},
data_attacher_function=ax_client_data_attacher,
data_attacher_kwargs={"ax_client": ax_client, "lock": ax_client_lock},
elicitation_function=elicitation_function,
)
def ax_client_candidate_generator(
queue: Queue[tuple[TParameterization, int]],
stop_event: Event,
num_trials: int,
ax_client: AxClient,
lock: Lock,
) -> None:
"""Thread-safe method for generating the next trial from the AxClient and
enqueueing it to the candidate queue. The number of candidates pre-generated is
controlled by the maximum size of the queue. Generation stops when num_trials
trials are attached to the AxClient's experiment.
"""
while not stop_event.is_set():
if not queue.full():
with lock:
try:
parameterization_with_trial_index = ax_client.get_next_trial()
queue.put(parameterization_with_trial_index)
# Check if candidate generation can end
if len(ax_client.experiment.arms_by_name) >= num_trials:
stop_event.set()
except (MaxParallelismReachedException, DataRequiredError) as e:
logger.warning(
f"Encountered error {e}, sleeping for {IDLE_SLEEP_SEC} "
"seconds and trying again."
)
pass # Try again later
time.sleep(IDLE_SLEEP_SEC)
def ax_client_data_attacher(
queue: Queue[tuple[int, TEvaluationOutcome]],
stop_event: Event,
ax_client: AxClient,
lock: Lock,
) -> None:
"""Thread-safe method for attaching evaluation outcomes to the AxClient from the
    outcome queue. Once the AxClient's lock is acquired, all data in the outcome
    queue is attached at once, and then the lock is released. Stops when the stop
    event is set.
"""
while not stop_event.is_set():
if not queue.empty():
with lock:
while not queue.empty():
trial_index, raw_data = queue.get()
ax_client.complete_trial(
trial_index=trial_index,
raw_data=raw_data,
)
queue.task_done()
time.sleep(IDLE_SLEEP_SEC)