ax.runners

BoTorch Test Problem

SingleRunningTrialMixin

class ax.runners.single_running_trial_mixin.SingleRunningTrialMixin

Bases: object

Mixin for Runners with a single running trial.

This mixin implements a simple poll_trial_status method that allows for a single running trial (the latest running trial). On each poll, every trial currently marked as RUNNING is reported as COMPLETED.
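A minimal sketch of that polling rule follows. It assumes a stand-in `TrialStatus` enum and a hypothetical `FakeTrial` record in place of Ax's `BaseTrial`, so it runs without Ax installed. Every polled trial that is currently RUNNING is reported back as COMPLETED, and nothing else is reported.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Iterable


class TrialStatus(Enum):
    # Stand-in for Ax's TrialStatus (only the members used here).
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"


@dataclass
class FakeTrial:
    # Hypothetical minimal trial record: an index and its current status.
    index: int
    status: TrialStatus


def poll_single_running(trials: Iterable[FakeTrial]) -> dict[TrialStatus, set[int]]:
    """Report every trial currently marked RUNNING as COMPLETED."""
    running = {t.index for t in trials if t.status is TrialStatus.RUNNING}
    # Return an empty mapping when there is nothing to report.
    return {TrialStatus.COMPLETED: running} if running else {}


trials = [FakeTrial(0, TrialStatus.COMPLETED), FakeTrial(1, TrialStatus.RUNNING)]
print(poll_single_running(trials))
```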

poll_trial_status(trials: Iterable[BaseTrial]) → dict[TrialStatus, set[int]]

Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.

NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.

Parameters:

trials – Trials to poll.

Returns:

A dictionary mapping TrialStatus to the set of trial indices that have the respective status at the time of polling. Trials that already have a terminal status (ABANDONED, FAILED, COMPLETED) at the time of polling need not be included, but may be.
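To illustrate this contract, the sketch below builds the status-to-indices mapping from a hypothetical backend lookup (`backend_states`, a dict from trial index to a job-state string). The `TrialStatus` stand-in, the backend state names, the `poll_once` helper, and the choice to skip already-terminal trials are assumptions for illustration. It performs one non-blocking pass and returns immediately, leaving any waiting between polls to the caller.

```python
from collections import defaultdict
from enum import Enum


class TrialStatus(Enum):
    # Stand-in for Ax's TrialStatus (only the members used below).
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    ABANDONED = "ABANDONED"


TERMINAL = {TrialStatus.COMPLETED, TrialStatus.FAILED, TrialStatus.ABANDONED}

# Hypothetical mapping from backend job states to trial statuses.
BACKEND_TO_STATUS = {
    "running": TrialStatus.RUNNING,
    "done": TrialStatus.COMPLETED,
    "error": TrialStatus.FAILED,
}


def poll_once(
    current: dict[int, TrialStatus], backend_states: dict[int, str]
) -> dict[TrialStatus, set[int]]:
    """Single, non-blocking poll: map each non-terminal trial to its new status."""
    result: dict[TrialStatus, set[int]] = defaultdict(set)
    for index, status in current.items():
        if status in TERMINAL:
            continue  # Terminal trials may be omitted from the result.
        result[BACKEND_TO_STATUS[backend_states[index]]].add(index)
    return dict(result)


current = {0: TrialStatus.COMPLETED, 1: TrialStatus.RUNNING, 2: TrialStatus.RUNNING}
backend = {0: "done", 1: "done", 2: "running"}
print(poll_once(current, backend))
```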

Synthetic Runner

class ax.runners.synthetic.SyntheticRunner(dummy_metadata: str | None = None)

Bases: Runner

Class for synthetic or dummy runner.

Currently acts as a shell runner that only creates a name.
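A minimal sketch of what a shell runner of this kind might look like. The class name `ShellRunner`, passing a bare trial index instead of a `BaseTrial`, the `"name"` metadata key, and deriving the name from the trial index are all assumptions for illustration, not the Ax implementation; `dummy_metadata` mirrors the constructor argument above.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ShellRunner:
    # Hypothetical stand-in for a synthetic runner; nothing is actually deployed.
    dummy_metadata: Optional[str] = None

    def run(self, trial_index: int) -> dict[str, Any]:
        # "Deploy" by creating a name only; optionally echo the dummy metadata.
        metadata: dict[str, Any] = {"name": str(trial_index)}
        if self.dummy_metadata is not None:
            metadata["dummy_metadata"] = self.dummy_metadata
        return metadata


print(ShellRunner().run(3))
print(ShellRunner(dummy_metadata="test").run(4))
```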

poll_trial_status(trials: Iterable[BaseTrial]) → dict[TrialStatus, set[int]]

Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.

NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.

Parameters:

trials – Trials to poll.

Returns:

A dictionary mapping TrialStatus to the set of trial indices that have the respective status at the time of polling. Trials that already have a terminal status (ABANDONED, FAILED, COMPLETED) at the time of polling need not be included, but may be.

run(trial: BaseTrial) → dict[str, Any]

Deploys a trial based on custom runner subclass implementation.

Parameters:

trial – The trial to deploy.

Returns:

Dict of run metadata from the deployment process.

property run_metadata_report_keys: list[str]

A list of keys of the metadata dict returned by run() that are relevant outside the runner-internal implementation. These can, e.g., be reported in Scheduler.report_results().
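For example, a caller that only wants the reportable subset of a run's metadata could filter on these keys. The `report_subset` helper and the sample metadata below are hypothetical.

```python
from typing import Any


def report_subset(metadata: dict[str, Any], report_keys: list[str]) -> dict[str, Any]:
    """Keep only the metadata entries whose keys are listed as reportable."""
    # Keys listed as reportable but absent from this run's metadata are skipped.
    return {k: metadata[k] for k in report_keys if k in metadata}


run_metadata = {"name": "exp_0", "job_id": 1234, "internal_token": "abc"}
print(report_subset(run_metadata, ["name", "job_id"]))
```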

Simulated Backend Runner

class ax.runners.simulated_backend.SimulatedBackendRunner(simulator: BackendSimulator, sample_runtime_func: Callable[[BaseTrial], float] | None = None)

Bases: Runner

Class for a runner that works with the BackendSimulator.

poll_trial_status(trials: Iterable[BaseTrial]) → dict[TrialStatus, set[int]]

Poll trial status from the BackendSimulator. NOTE: The Scheduler currently marks trials as running when they are created, but some of these trials may actually be queued on the BackendSimulator.

Returns:

A Dict mapping statuses to sets of trial indices.
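To see why a trial the Scheduler considers running may still be queued, consider a toy simulator with a fixed number of slots. `ToySimulator`, its `capacity`, and the string states are hypothetical stand-ins, not the `BackendSimulator` API. The sketch assumes all trials are submitted at time 0 and start in index order as slots free up; each poll reports the state at a given simulated time.

```python
from dataclasses import dataclass, field


@dataclass
class ToySimulator:
    # Hypothetical simulator: at most `capacity` trials run at once.
    capacity: int
    runtimes: dict[int, float] = field(default_factory=dict)  # index -> runtime

    def submit(self, index: int, runtime: float) -> None:
        self.runtimes[index] = runtime

    def poll(self, now: float) -> dict[str, set[int]]:
        """Return {"queued"|"running"|"completed": indices} at simulated time `now`."""
        slot_free_at = [0.0] * self.capacity  # Time at which each slot becomes free.
        status: dict[str, set[int]] = {"queued": set(), "running": set(), "completed": set()}
        for index in sorted(self.runtimes):  # Start trials in index order.
            slot = min(range(self.capacity), key=lambda s: slot_free_at[s])
            start = slot_free_at[slot]
            end = start + self.runtimes[index]
            slot_free_at[slot] = end
            if now < start:
                status["queued"].add(index)
            elif now < end:
                status["running"].add(index)
            else:
                status["completed"].add(index)
        return status


sim = ToySimulator(capacity=1)
sim.submit(0, 2.0)
sim.submit(1, 3.0)
print(sim.poll(1.0))
```

With one slot, trial 1 stays queued until trial 0 finishes at t = 2.0, even though a Scheduler that marks trials as running on creation would already consider it running.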

run(trial: BaseTrial) → dict[str, Any]

Start a trial on the BackendSimulator.

Parameters:

trial – Trial to deploy via the runner.

Returns:

Dict containing the sampled runtime of the trial.

stop(trial: BaseTrial, reason: str | None = None) → dict[str, Any]

Stop a trial on the BackendSimulator.

Parameters:
  • trial – Trial to stop on the simulator.

  • reason – A message containing information why the trial is to be stopped.

Returns:

A dictionary containing a single key “reason” that maps to the reason passed to the function. If no reason was given, returns an empty dictionary.
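The return contract of stop() can be sketched as a small helper. `stop_metadata` is hypothetical, and treating only `None` (not an empty string) as "no reason given" is an assumption.

```python
from typing import Any, Optional


def stop_metadata(reason: Optional[str] = None) -> dict[str, Any]:
    """Build the metadata dict returned when stopping a trial."""
    # Include "reason" only when one was actually given.
    return {"reason": reason} if reason is not None else {}


print(stop_metadata("early stopped"))
print(stop_metadata())
```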

ax.runners.simulated_backend.sample_runtime_unif(trial: BaseTrial, low: float = 1.0, high: float = 5.0) → float

Return a runtime sampled uniformly from [low, high].

Parameters:
  • trial – Trial for which to sample runtime.

  • low – Lower bound of uniform runtime distribution.

  • high – Upper bound of uniform runtime distribution.

Returns:

A float representing the simulated trial runtime.
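A minimal sketch of the same sampling rule using Python's `random.Random.uniform`. The name `sample_runtime_uniform` and the optional `rng` argument are hypothetical; the `trial` argument is omitted because a uniform draw does not depend on it.

```python
import random
from typing import Optional


def sample_runtime_uniform(
    low: float = 1.0, high: float = 5.0, rng: Optional[random.Random] = None
) -> float:
    """Sample a simulated trial runtime uniformly from [low, high]."""
    # Use a caller-supplied generator for reproducibility, else a fresh one.
    rng = rng if rng is not None else random.Random()
    return rng.uniform(low, high)


# Reproducible draw with a seeded generator.
print(sample_runtime_uniform(rng=random.Random(0)))
```

Passing a seeded `random.Random` keeps simulated runtimes reproducible across runs without touching the global random state.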

TorchX Runner

class ax.runners.torchx.TorchXRunner(tracker_base: str, component: Callable[..., AppDef], component_const_params: dict[str, Any] | None = None, scheduler: str = 'local', cfg: Mapping[str, str | int | float | bool | List[str] | Dict[str, str] | None] | None = None)

Bases: Runner

An implementation of ax.core.runner.Runner that delegates job submission to the TorchX Runner. This runner is coupled with the TorchX component since Ax runners run trials of a single component with different parameters.

It is expected that the experiment parameter names and types match EXACTLY with component’s function args. Component function args that are NOT part of the search space can be passed as component_const_params. The following args are passed automatically if declared in the component function’s signature:

  • trial_idx (int): current trial’s index

  • tracker_base (str): torchx tracker’s base (typically a URL indicating the base dir of the tracker)

Example:


```python
from torchx import specs


def trainer_component(
    x1: int, x2: float, trial_idx: int, tracker_base: str, x3: float, x4: str
) -> specs.AppDef:
    # ... implementation omitted for brevity ...
    pass
```

The experiment should be set up as:


```python
parameters=[
    {
        "name": "x1",
        "value_type": "int",
        # ... other options...
    },
    {
        "name": "x2",
        "value_type": "float",
        # ... other options...
    },
]
```

And the rest of the arguments can be set as:


```python
TorchXRunner(
    tracker_base="s3://foo/bar",
    component=trainer_component,
    # trial_idx and tracker_base args passed automatically
    # if the function signature declares those args
    component_const_params={"x3": 1.2, "x4": "barbaz"},
)
```

Running the experiment as set up above results in each trial running:


```python
appdef = trainer_component(
    x1=trial.params["x1"],
    x2=trial.params["x2"],
    trial_idx=trial.index,
    tracker_base="s3://foo/bar",
    x3=1.2,
    x4="barbaz",
)

torchx.runner.get_runner().run(appdef, ...)
```
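The argument assembly described above can be sketched with `inspect.signature`: trial parameters and `component_const_params` are always passed, while `trial_idx` and `tracker_base` are added only if the component declares them. `build_component_kwargs` is a hypothetical helper, not the runner's internal code, and `toy_component` returns a plain dict instead of a TorchX `AppDef` so the sketch runs without TorchX installed.

```python
import inspect
from typing import Any, Callable, Optional


def build_component_kwargs(
    component: Callable[..., Any],
    trial_params: dict[str, Any],
    trial_idx: int,
    tracker_base: str,
    const_params: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Assemble keyword arguments for one trial's component call."""
    kwargs = {**trial_params, **(const_params or {})}
    declared = inspect.signature(component).parameters
    # Pass the automatic arguments only if the component's signature declares them.
    if "trial_idx" in declared:
        kwargs["trial_idx"] = trial_idx
    if "tracker_base" in declared:
        kwargs["tracker_base"] = tracker_base
    return kwargs


def toy_component(x1: int, x2: float, trial_idx: int, x3: float) -> dict[str, Any]:
    # Hypothetical stand-in for a TorchX component; it does not declare tracker_base.
    return {"x1": x1, "x2": x2, "trial_idx": trial_idx, "x3": x3}


kwargs = build_component_kwargs(
    toy_component,
    {"x1": 3, "x2": 0.5},
    trial_idx=7,
    tracker_base="s3://foo/bar",
    const_params={"x3": 1.2},
)
print(toy_component(**kwargs))
```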

poll_trial_status(trials: Iterable[BaseTrial]) → dict[TrialStatus, set[int]]

Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.

NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.

Parameters:

trials – Trials to poll.

Returns:

A dictionary mapping TrialStatus to the set of trial indices that have the respective status at the time of polling. Trials that already have a terminal status (ABANDONED, FAILED, COMPLETED) at the time of polling need not be included, but may be.
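When polling, a TorchX-backed runner has to translate TorchX application states into trial statuses. The sketch below uses state names mirroring `torchx.specs.AppState` as plain strings so it runs without TorchX. The particular mapping (for example, treating `PENDING` as running and `CANCELLED` as abandoned) and the `poll_from_app_states` helper are assumptions, not necessarily what `TorchXRunner` does.

```python
from collections import defaultdict
from enum import Enum


class TrialStatus(Enum):
    # Stand-in for Ax's TrialStatus (only the members used below).
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    ABANDONED = "ABANDONED"


# Assumed mapping from TorchX app-state names to trial statuses.
APP_STATE_TO_TRIAL_STATUS = {
    "SUBMITTED": TrialStatus.RUNNING,
    "PENDING": TrialStatus.RUNNING,
    "RUNNING": TrialStatus.RUNNING,
    "SUCCEEDED": TrialStatus.COMPLETED,
    "FAILED": TrialStatus.FAILED,
    "CANCELLED": TrialStatus.ABANDONED,
}


def poll_from_app_states(app_states: dict[int, str]) -> dict[TrialStatus, set[int]]:
    """Group trial indices by the trial status implied by each app's state."""
    result: dict[TrialStatus, set[int]] = defaultdict(set)
    for index, state in app_states.items():
        status = APP_STATE_TO_TRIAL_STATUS.get(state)
        if status is not None:  # Unmapped states (e.g. "UNKNOWN") stay unreported.
            result[status].add(index)
    return dict(result)


print(poll_from_app_states({0: "SUCCEEDED", 1: "RUNNING", 2: "PENDING", 3: "UNKNOWN"}))
```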

run(trial: BaseTrial) → dict[str, Any]

Submits the trial (which maps to an AppDef) as a job onto the scheduler using torchx.runner.

Note

Only supports Trial (not BatchTrial).

stop(trial: BaseTrial, reason: str | None = None) → dict[str, Any]

Kill the given trial.