ax.runners

BoTorch Test Problem

class ax.runners.botorch_test_problem.BotorchTestProblemRunner(test_problem: botorch.test_functions.base.BaseTestProblem)

Bases: ax.core.runner.Runner

A Runner for evaluating BoTorch BaseTestProblems. Given a trial, the Runner evaluates the BaseTestProblem.forward method for each arm in the trial and returns some metadata about the underlying BoTorch problem, such as the noise_std. We compute the full result on the Runner (as opposed to the Metric, as is typical in synthetic test problems) because in the multi-objective (MOO) case the BoTorch problem computes all metrics in one stacked tensor, and we wish to avoid recomputing it for each metric.
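
To make the recomputation argument concrete, here is a stdlib-only sketch that does not touch Ax or BoTorch. `toy_forward`, `ToyProblemRunner`, `read_metric`, and the metadata keys "Ys" and "noise_std" are hypothetical stand-ins, not the library's API. The runner evaluates a two-objective function once per arm and stores the stacked outputs, so each metric only reads its own column.

```python
from typing import Any, Dict, List, Sequence

# Counts forward evaluations so the saving is visible.
CALLS = {"forward": 0}


def toy_forward(x: Sequence[float]) -> List[float]:
    # Hypothetical two-objective function standing in for BaseTestProblem.forward:
    # one call returns every objective value for a point ("stacked" output).
    CALLS["forward"] += 1
    f1 = sum(xi ** 2 for xi in x)
    f2 = sum((xi - 1.0) ** 2 for xi in x)
    return [f1, f2]


class ToyProblemRunner:
    """Hypothetical runner: evaluates each arm once and stores all outcomes."""

    def __init__(self, noise_std: float = 0.0) -> None:
        self.noise_std = noise_std

    def run(self, arms: Dict[str, Sequence[float]]) -> Dict[str, Any]:
        # One forward call per arm yields all objectives at once.
        ys = {name: toy_forward(x) for name, x in arms.items()}
        return {"Ys": ys, "noise_std": self.noise_std}


def read_metric(run_metadata: Dict[str, Any], metric_index: int) -> Dict[str, float]:
    # Each metric reads its column from the stored results; nothing is re-evaluated.
    return {name: values[metric_index] for name, values in run_metadata["Ys"].items()}


metadata = ToyProblemRunner(noise_std=0.1).run({"0_0": [0.0, 0.0], "0_1": [1.0, 2.0]})
objective_0 = read_metric(metadata, 0)
objective_1 = read_metric(metadata, 1)
print(CALLS["forward"])  # 2: one call per arm, not one per (arm, metric) pair
```

With two metrics, evaluating inside each metric would call `toy_forward` four times here; storing the stacked result on the runner keeps it at two.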

classmethod deserialize_init_args(args: Dict[str, Any]) → Dict[str, Any]

Given a dictionary, deserialize the properties needed to initialize the runner. Used for storage.

poll_trial_status(trials: Iterable[ax.core.base_trial.BaseTrial]) → Dict[ax.core.base_trial.TrialStatus, Set[int]]

Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.

NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.

Parameters

trials – Trials to poll.

Returns

A dictionary mapping TrialStatus to the set of trial indices that have the respective status at the time of polling. This does not need to include trials that already have a terminal (ABANDONED, FAILED, COMPLETED) status at the time of polling (but it may).
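
A minimal sketch of this polling contract, assuming trials are identified by index and that the backend's view of each trial is available as a plain mapping. The `TrialStatus` enum below is a simplified stand-in for Ax's, and `backend_states` and `known_statuses` are hypothetical inputs; a real runner would query its own backend instead. The function performs exactly one pass and leaves any waiting between polls to the caller.

```python
from enum import Enum
from typing import Dict, Iterable, Mapping, Set


class TrialStatus(Enum):
    # Simplified stand-in for ax.core.base_trial.TrialStatus (assumption).
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    ABANDONED = "ABANDONED"


TERMINAL = {TrialStatus.COMPLETED, TrialStatus.FAILED, TrialStatus.ABANDONED}


def poll_trial_status(
    trial_indices: Iterable[int],
    backend_states: Mapping[int, TrialStatus],
    known_statuses: Mapping[int, TrialStatus],
) -> Dict[TrialStatus, Set[int]]:
    """Perform a single poll; no sleeping or retrying happens here."""
    result: Dict[TrialStatus, Set[int]] = {}
    for index in trial_indices:
        # Trials already known to be terminal may be skipped (the contract allows either).
        if known_statuses.get(index) in TERMINAL:
            continue
        status = backend_states[index]
        result.setdefault(status, set()).add(index)
    return result


known = {0: TrialStatus.COMPLETED, 1: TrialStatus.RUNNING, 2: TrialStatus.RUNNING}
backend = {0: TrialStatus.COMPLETED, 1: TrialStatus.COMPLETED, 2: TrialStatus.RUNNING}
result = poll_trial_status([0, 1, 2], backend, known)
```

Here trial 0 is skipped because it was already terminal, trial 1 is reported as newly completed, and trial 2 is still running.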

run(trial: ax.core.base_trial.BaseTrial) → Dict[str, Any]

Deploys a trial according to the runner subclass's implementation.

Parameters

trial – The trial to deploy.

Returns

Dict of run metadata from the deployment process.

classmethod serialize_init_args(obj: Any) → Dict[str, Any]

Serialize the properties needed to initialize the runner. Used for storage.
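
The serialize/deserialize pair is meant to round-trip through storage. The sketch below assumes a JSON store and a hypothetical `ToyRunner` whose only init argument is a test problem; the field names and the `PROBLEM_REGISTRY` lookup are illustrative assumptions, not how Ax encodes `BotorchTestProblemRunner`.

```python
import json
from typing import Any, Dict


class ToyProblem:
    # Hypothetical test problem whose constructor arguments must be stored.
    def __init__(self, dim: int = 2, noise_std: float = 0.0) -> None:
        self.dim = dim
        self.noise_std = noise_std


# Hypothetical registry mapping a stored class name back to a class.
PROBLEM_REGISTRY = {"ToyProblem": ToyProblem}


class ToyRunner:
    def __init__(self, test_problem: ToyProblem) -> None:
        self.test_problem = test_problem

    @classmethod
    def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
        # Keep only JSON-friendly data needed to rebuild the runner.
        problem = obj.test_problem
        return {
            "test_problem_class": type(problem).__name__,
            "test_problem_kwargs": {"dim": problem.dim, "noise_std": problem.noise_std},
        }

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        # Turn stored data back into the keyword arguments __init__ expects.
        problem_cls = PROBLEM_REGISTRY[args["test_problem_class"]]
        return {"test_problem": problem_cls(**args["test_problem_kwargs"])}


original = ToyRunner(ToyProblem(dim=3, noise_std=0.5))
stored = json.dumps(ToyRunner.serialize_init_args(original))
restored = ToyRunner(**ToyRunner.deserialize_init_args(json.loads(stored)))
```

Storing a class name plus keyword arguments, rather than the object itself, keeps the stored form plain data that any backend can hold.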

Synthetic Runner

class ax.runners.synthetic.SyntheticRunner(dummy_metadata: Optional[str] = None)

Bases: ax.core.runner.Runner

A synthetic (dummy) runner.

It currently acts as a shell runner: it deploys nothing and only creates a name for each trial.
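
A shell runner of this kind can be sketched in a few lines. `ShellRunner`, the use of the trial index as the name, and the "dummy_metadata" key are assumptions for illustration; the sketch only shows that `run` deploys nothing and returns naming metadata.

```python
from typing import Any, Dict, Optional


class ShellRunner:
    """Hypothetical stand-in for a synthetic runner; nothing is deployed."""

    def __init__(self, dummy_metadata: Optional[str] = None) -> None:
        self.dummy_metadata = dummy_metadata

    def run(self, trial_index: int) -> Dict[str, Any]:
        # The only "deployment" work is deriving a name for the trial.
        metadata: Dict[str, Any] = {"name": str(trial_index)}
        if self.dummy_metadata is not None:
            metadata["dummy_metadata"] = self.dummy_metadata
        return metadata
```

Because no job exists, a runner like this has nothing to poll or stop; it exists so experiments that require a runner can still create and "run" trials.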

poll_trial_status(trials: Iterable[ax.core.base_trial.BaseTrial]) → Dict[ax.core.base_trial.TrialStatus, Set[int]]

Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.

NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.

Parameters

trials – Trials to poll.

Returns

A dictionary mapping TrialStatus to the set of trial indices that have the respective status at the time of polling. This does not need to include trials that already have a terminal (ABANDONED, FAILED, COMPLETED) status at the time of polling (but it may).

run(trial: ax.core.base_trial.BaseTrial) → Dict[str, Any]

Deploys a trial according to the runner subclass's implementation.

Parameters

trial – The trial to deploy.

Returns

Dict of run metadata from the deployment process.

Simulated Backend Runner

class ax.runners.simulated_backend.SimulatedBackendRunner(simulator: ax.utils.testing.backend_simulator.BackendSimulator, sample_runtime_func: Optional[Callable[[ax.core.base_trial.BaseTrial], float]] = None)

Bases: ax.core.runner.Runner

A runner that deploys trials to a BackendSimulator.

poll_trial_status(trials: Iterable[ax.core.base_trial.BaseTrial]) → Dict[ax.core.base_trial.TrialStatus, Set[int]]

Poll trial status from the BackendSimulator. NOTE: The Scheduler currently marks trials as running when they are created, but some of these trials may actually be queued on the BackendSimulator.

Returns

A Dict mapping statuses to sets of trial indices.
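
The queued-versus-running discrepancy in the note can be reproduced with a toy backend that has a fixed number of slots. `ToyBackend` and `poll` are hypothetical, statuses are plain strings, and reporting queued trials as "RUNNING" is an assumption chosen to match the Scheduler's view described above, not necessarily what SimulatedBackendRunner does.

```python
from typing import Dict, List, Set, Tuple


class ToyBackend:
    """Hypothetical backend with a fixed number of slots; extra trials wait in a queue."""

    def __init__(self, max_concurrency: int) -> None:
        self.max_concurrency = max_concurrency
        self.queued: List[Tuple[int, float]] = []  # (trial index, runtime) waiting for a slot
        self.running: Dict[int, float] = {}  # trial index -> remaining runtime
        self.completed: Set[int] = set()

    def submit(self, index: int, runtime: float) -> None:
        self.queued.append((index, runtime))
        self._fill_slots()

    def advance(self, dt: float) -> None:
        # Advance simulated time; finished trials free their slots for queued ones.
        for index in list(self.running):
            self.running[index] -= dt
            if self.running[index] <= 0:
                del self.running[index]
                self.completed.add(index)
        self._fill_slots()

    def _fill_slots(self) -> None:
        while self.queued and len(self.running) < self.max_concurrency:
            index, runtime = self.queued.pop(0)
            self.running[index] = runtime


def poll(backend: ToyBackend) -> Dict[str, Set[int]]:
    # Assumption: queued trials are reported as RUNNING, since the Scheduler
    # already treats every deployed trial as running.
    running = set(backend.running) | {index for index, _ in backend.queued}
    return {"RUNNING": running, "COMPLETED": set(backend.completed)}


backend = ToyBackend(max_concurrency=1)
backend.submit(0, runtime=2.0)
backend.submit(1, runtime=1.0)  # queued: the only slot is taken by trial 0
backend.advance(2.0)  # trial 0 finishes, trial 1 starts
```

Before the call to `advance`, trial 1 sits in the queue even though the Scheduler would already consider it running, which is exactly the situation the note describes.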

run(trial: ax.core.base_trial.BaseTrial) → Dict[str, Any]

Start a trial on the BackendSimulator.

Parameters

trial – Trial to deploy via the runner.

Returns

Dict containing the sampled runtime of the trial.

stop(trial: ax.core.base_trial.BaseTrial, reason: Optional[str] = None) → Dict[str, Any]

Stop a trial on the BackendSimulator.

Parameters
  • trial – Trial to stop on the simulator.

  • reason – A message explaining why the trial is being stopped.

Returns

A dictionary containing a single key “reason” that maps to the reason passed to the function. If no reason was given, returns an empty dictionary.
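
The documented return shape of `stop` is small enough to state directly. The hypothetical `stop_trial` helper below tracks running trials in a set, an assumption standing in for the BackendSimulator; only the returned dictionary follows the description above.

```python
from typing import Any, Dict, Optional, Set


def stop_trial(running: Set[int], index: int, reason: Optional[str] = None) -> Dict[str, Any]:
    # Hypothetical helper: drop the trial from the running set, then report
    # {"reason": reason} if a reason was given, else an empty dictionary.
    running.discard(index)
    return {"reason": reason} if reason is not None else {}
```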

ax.runners.simulated_backend.sample_runtime_unif(trial: ax.core.base_trial.BaseTrial, low: float = 1.0, high: float = 5.0) → float

Return a runtime sampled uniformly from [low, high].

Parameters
  • trial – Trial for which to sample runtime.

  • low – Lower bound of uniform runtime distribution.

  • high – Upper bound of uniform runtime distribution.

Returns

A float representing the simulated trial runtime.
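
A sketch of the same sampling rule using Python's `random.uniform`, which returns a float between low and high (the upper end point may or may not be included, depending on floating-point rounding). The name `sample_runtime_unif_sketch` is hypothetical and is not the Ax function; the `trial` argument is accepted only to match the documented signature. `functools.partial` shows one way to fix different bounds while keeping a one-argument callable of the shape `sample_runtime_func` expects.

```python
import random
from functools import partial
from typing import Any


def sample_runtime_unif_sketch(trial: Any, low: float = 1.0, high: float = 5.0) -> float:
    # `trial` is unused here; it only mirrors the documented signature.
    return random.uniform(low, high)


# A one-argument callable (trial -> runtime) with tighter, hypothetical bounds.
fast_runtimes = partial(sample_runtime_unif_sketch, low=0.5, high=1.0)
```

Keeping the bounds as keyword arguments means any callable that takes a trial and returns a float can be swapped in for `sample_runtime_func`, including one that ignores the trial entirely.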

TorchX Runner

class ax.runners.torchx.TorchXRunner(tracker_base: str, component: Callable[[...], torchx.specs.api.AppDef], component_const_params: Optional[Dict[str, Any]] = None, scheduler: str = 'local', cfg: Optional[Mapping[str, Optional[Union[str, int, float, bool, List[str]]]]] = None)

Bases: ax.core.runner.Runner

An implementation of ax.core.runner.Runner that delegates job submission to the TorchX Runner. This runner is coupled to a single TorchX component, since an Ax runner runs every trial with the same component and different parameters.

The experiment's parameter names and types are expected to match the component function's arguments EXACTLY. Component function arguments that are NOT part of the search space can be passed as component_const_params. The following arguments are passed automatically if declared in the component function’s signature:

  • trial_idx (int): current trial’s index

  • tracker_base (str): torchx tracker’s base (typically a URL indicating the base dir of the tracker)

Example:

```python
from torchx import specs


def trainer_component(
    x1: int,
    x2: float,
    trial_idx: int,
    tracker_base: str,
    x3: float,
    x4: str,
) -> specs.AppDef:
    # ... implementation omitted for brevity ...
    pass
```

The experiment should be set up as:

```python
parameters = [
    {
        "name": "x1",
        "value_type": "int",
        # ... other options ...
    },
    {
        "name": "x2",
        "value_type": "float",
        # ... other options ...
    },
]
```

And the rest of the arguments can be set as:

```python
TorchXRunner(
    tracker_base="s3://foo/bar",
    component=trainer_component,
    # trial_idx and tracker_base args are passed automatically
    # if the function signature declares those args
    component_const_params={"x3": 1.2, "x4": "barbaz"},
)
```

Running the experiment as set up above results in each trial running:

```python
appdef = trainer_component(
    x1=trial.params["x1"],
    x2=trial.params["x2"],
    trial_idx=trial.index,
    tracker_base="s3://foo/bar",
    x3=1.2,
    x4="barbaz",
)

torchx.runner.get_runner().run(appdef, ...)
```
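
The argument assembly described above can be sketched with `inspect.signature`. `build_component_kwargs` and `toy_component` are hypothetical; the sketch assumes that search-space parameters and component_const_params are merged into one dictionary and that trial_idx and tracker_base are added only when the component declares them, as the text states. It does not call TorchX.

```python
import inspect
from typing import Any, Callable, Dict, Mapping, Optional


def build_component_kwargs(
    component: Callable[..., Any],
    params: Mapping[str, Any],
    const_params: Optional[Mapping[str, Any]],
    trial_index: int,
    tracker_base: str,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the argument passing described for TorchXRunner."""
    kwargs: Dict[str, Any] = dict(params)  # search-space parameters
    kwargs.update(const_params or {})  # fixed, non-search-space arguments
    accepted = inspect.signature(component).parameters
    # trial_idx and tracker_base are only passed if the component declares them.
    if "trial_idx" in accepted:
        kwargs["trial_idx"] = trial_index
    if "tracker_base" in accepted:
        kwargs["tracker_base"] = tracker_base
    return kwargs


def toy_component(x1: int, x2: float, trial_idx: int, x3: float) -> str:
    # Stand-in for a real component; a TorchX component would return an AppDef.
    return f"trial {trial_idx}: x1={x1}, x2={x2}, x3={x3}"


kwargs = build_component_kwargs(
    toy_component, {"x1": 3, "x2": 0.5}, {"x3": 1.2}, trial_index=7, tracker_base="s3://foo/bar"
)
print(toy_component(**kwargs))  # trial 7: x1=3, x2=0.5, x3=1.2
```

Because `toy_component` does not declare tracker_base, it is left out of the keyword arguments rather than causing a TypeError.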

poll_trial_status(trials: Iterable[ax.core.base_trial.BaseTrial]) → Dict[ax.core.base_trial.TrialStatus, Set[int]]

Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.

NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.

Parameters

trials – Trials to poll.

Returns

A dictionary mapping TrialStatus to the set of trial indices that have the respective status at the time of polling. This does not need to include trials that already have a terminal (ABANDONED, FAILED, COMPLETED) status at the time of polling (but it may).
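
TorchX reports application states that have to be translated into trial statuses before they can be returned from poll_trial_status. The sketch below uses a local `AppState` enum whose member names are modeled on TorchX's application states, and the `STATE_TO_STATUS` mapping is an assumption for illustration, not the mapping TorchXRunner uses. States with no clear trial status (here UNSUBMITTED and UNKNOWN) are left out of the poll result.

```python
from enum import Enum
from typing import Dict, Mapping, Set


class AppState(Enum):
    # Local stand-in modeled on TorchX application states (assumption).
    UNSUBMITTED = 0
    SUBMITTED = 1
    PENDING = 2
    RUNNING = 3
    SUCCEEDED = 4
    FAILED = 5
    CANCELLED = 6
    UNKNOWN = 7


# Hypothetical translation from app state to trial status name.
STATE_TO_STATUS: Dict[AppState, str] = {
    AppState.SUBMITTED: "RUNNING",
    AppState.PENDING: "RUNNING",
    AppState.RUNNING: "RUNNING",
    AppState.SUCCEEDED: "COMPLETED",
    AppState.FAILED: "FAILED",
    AppState.CANCELLED: "ABANDONED",
}


def group_by_status(app_states: Mapping[int, AppState]) -> Dict[str, Set[int]]:
    result: Dict[str, Set[int]] = {}
    for index, state in app_states.items():
        status = STATE_TO_STATUS.get(state)
        if status is None:
            continue  # unmapped state: leave the trial out of this poll
        result.setdefault(status, set()).add(index)
    return result
```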

run(trial: ax.core.base_trial.BaseTrial) → Dict[str, Any]

Submits the trial (which maps to an AppDef) as a job onto the scheduler using torchx.runner.

Note

Only supports Trial (not BatchTrial).
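
One way to enforce this restriction is to require exactly one arm before building an AppDef. `ToyTrial` and `run_single_arm` are hypothetical stand-ins; the sketch assumes a batch trial is recognizable by its number of arms, and it returns a placeholder instead of submitting a job.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyTrial:
    # Hypothetical stand-in: a Trial has one arm, a BatchTrial may have many.
    index: int
    arms: List[Dict[str, Any]] = field(default_factory=list)


def run_single_arm(trial: ToyTrial) -> Dict[str, Any]:
    # One trial maps to one AppDef, so anything other than a single arm is rejected.
    if len(trial.arms) != 1:
        raise ValueError(
            f"Trial {trial.index} has {len(trial.arms)} arms; "
            "only trials with exactly one arm are supported."
        )
    # A real runner would build and submit an AppDef here; this returns a placeholder.
    return {"submitted_params": trial.arms[0]}
```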

stop(trial: ax.core.base_trial.BaseTrial, reason: Optional[str] = None) → Dict[str, Any]

Kill the given trial.