ax.runners
BoTorch Test Problem
- class ax.runners.botorch_test_problem.BotorchTestProblemRunner(test_problem_class: Type[BaseTestProblem], test_problem_kwargs: Dict[str, Any])
Bases: Runner
A Runner for evaluating BoTorch BaseTestProblems. Given a trial, the Runner evaluates the BaseTestProblem.forward method for each arm in the trial and also returns some metadata about the underlying BoTorch problem, such as noise_std. We compute the full result on the Runner (as opposed to on the Metric, as is typical for synthetic test problems) because, in the multi-objective (MOO) case, the BoTorch problem computes all metrics in one stacked tensor, and we wish to avoid recomputing it for each metric.
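The design choice above (evaluate each arm once on the runner, then let every metric read from the stored result) can be illustrated without Ax or BoTorch. In this minimal sketch, `ToyMOOProblem` is a hypothetical stand-in for a multi-objective `BaseTestProblem` whose `forward` returns all objectives for a point as one stacked list; `OneShotRunner` and `metric_value` are likewise illustrative names, not Ax APIs.

```python
from typing import Dict, List, Tuple


class ToyMOOProblem:
    """Hypothetical stand-in for a multi-objective BaseTestProblem."""

    def __init__(self) -> None:
        self.num_evaluations = 0

    def forward(self, x: Tuple[float, float]) -> List[float]:
        # One call returns *all* objectives stacked together.
        self.num_evaluations += 1
        x1, x2 = x
        return [x1**2 + x2**2, (x1 - 1.0) ** 2 + x2**2]


class OneShotRunner:
    """Sketch: evaluate each arm once on the runner and cache the stack."""

    def __init__(self, problem: ToyMOOProblem) -> None:
        self.problem = problem
        self.results: Dict[str, List[float]] = {}

    def run(self, arms: Dict[str, Tuple[float, float]]) -> Dict[str, List[float]]:
        for arm_name, x in arms.items():
            self.results[arm_name] = self.problem.forward(x)
        return self.results


def metric_value(runner: OneShotRunner, arm_name: str, objective_idx: int) -> float:
    # A metric only indexes into the cached stack; it never calls forward().
    return runner.results[arm_name][objective_idx]


problem = ToyMOOProblem()
runner = OneShotRunner(problem)
runner.run({"0_0": (0.0, 0.0), "0_1": (1.0, 2.0)})
values = {(arm, m): metric_value(runner, arm, m) for arm in runner.results for m in range(2)}
# Two arms x two metrics = four values, but forward() ran only twice.
print(problem.num_evaluations, values)
```

With per-metric evaluation the toy problem would be called four times; caching on the runner keeps it at one call per arm.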
- classmethod deserialize_init_args(args: Dict[str, Any]) → Dict[str, Any]
Given a dictionary, deserialize the properties needed to initialize the runner. Used for storage.
- poll_trial_status(trials: Iterable[BaseTrial]) → Dict[TrialStatus, Set[int]]
Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.
NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.
- Parameters:
trials – Trials to poll.
- Returns:
A dictionary mapping TrialStatus to a set of trial indices that have the respective status at the time of polling. Trials that already have a terminal status (ABANDONED, FAILED, COMPLETED) at the time of polling do not need to be included (but may be).
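A minimal sketch of the polling contract described above: one pass over the given trials, no sleeping or retry loop, and indices grouped into sets keyed by status. The `TrialStatus` enum here is a local stand-in that mirrors only a few of Ax's status names, the function takes plain trial indices rather than `BaseTrial` objects, and `backend_states` is a hypothetical dictionary playing the role of whatever system actually runs the trials.

```python
from collections import defaultdict
from enum import Enum
from typing import Dict, Iterable, Set


class TrialStatus(Enum):
    # Local stand-in mirroring a subset of Ax's TrialStatus names.
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    ABANDONED = "ABANDONED"


def poll_trial_status(
    trial_indices: Iterable[int], backend_states: Dict[int, TrialStatus]
) -> Dict[TrialStatus, Set[int]]:
    """Perform a single poll: look each trial up once and group by status."""
    status_to_indices: Dict[TrialStatus, Set[int]] = defaultdict(set)
    for index in trial_indices:
        status_to_indices[backend_states[index]].add(index)
    # No waiting here; the caller (e.g. a scheduler) decides when to poll again.
    return dict(status_to_indices)


states = {0: TrialStatus.COMPLETED, 1: TrialStatus.RUNNING, 2: TrialStatus.RUNNING}
result = poll_trial_status([0, 1, 2], states)
print(result)
```

Keeping the function to a single poll leaves the polling interval and backoff policy entirely to the caller.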
- run(trial: BaseTrial) → Dict[str, Any]
Deploys a trial according to the custom runner subclass implementation.
- Parameters:
trial – The trial to deploy.
- Returns:
Dict of run metadata from the deployment process.
- classmethod serialize_init_args(obj: Any) → Dict[str, Any]
Serialize the properties needed to initialize the runner. Used for storage.
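For storage, constructor arguments have to be reduced to plain, JSON-friendly data and restored later. The sketch below shows one possible way to do that for a class-valued argument like `test_problem_class`: store its module and class name, then re-import it. This encoding is an assumption for illustration, not Ax's actual storage format; the dictionary keys are hypothetical, and the standard library's `fractions.Fraction` merely stands in for a BoTorch test-problem class.

```python
import importlib
import json
from fractions import Fraction
from typing import Any, Dict


def serialize_init_args(test_problem_class: type, test_problem_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical keys: a class object is not JSON-serializable, so keep its import path.
    return {
        "test_problem_module": test_problem_class.__module__,
        "test_problem_class_name": test_problem_class.__name__,
        "test_problem_kwargs": dict(test_problem_kwargs),
    }


def deserialize_init_args(args: Dict[str, Any]) -> Dict[str, Any]:
    # Re-import the class from its stored path (works for top-level classes only).
    module = importlib.import_module(args["test_problem_module"])
    test_problem_class = getattr(module, args["test_problem_class_name"])
    return {
        "test_problem_class": test_problem_class,
        "test_problem_kwargs": dict(args["test_problem_kwargs"]),
    }


stored = json.dumps(serialize_init_args(Fraction, {"numerator": 3, "denominator": 4}))
restored = deserialize_init_args(json.loads(stored))
print(restored["test_problem_class"](**restored["test_problem_kwargs"]))  # 3/4
```

Storing an import path rather than pickling the class keeps the stored record human-readable, at the cost of requiring the class to be importable when the record is loaded.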
- test_problem: BaseTestProblem
Synthetic Runner
- class ax.runners.synthetic.SyntheticRunner(dummy_metadata: Optional[str] = None)
Bases: Runner
Class for a synthetic or dummy runner.
Currently acts as a shell runner that only creates a name.
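As a rough illustration of a shell runner that deploys nothing and only creates a name, consider the sketch below. The naming scheme (experiment name and trial index joined by an underscore) and the `"name"` / `"dummy_metadata"` metadata keys are assumptions made for illustration, and `FakeTrial` is a hypothetical stand-in for an Ax trial.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeTrial:
    # Hypothetical stand-in holding only what this sketch needs.
    index: int
    experiment_name: Optional[str] = None


class ShellRunner:
    """Sketch of a dummy runner: nothing is deployed, only a name is created."""

    def __init__(self, dummy_metadata: Optional[str] = None) -> None:
        self.dummy_metadata = dummy_metadata

    def run(self, trial: FakeTrial) -> Dict[str, Any]:
        # Assumed naming scheme: "<experiment>_<index>", or just "<index>".
        name = (
            f"{trial.experiment_name}_{trial.index}"
            if trial.experiment_name
            else str(trial.index)
        )
        metadata: Dict[str, Any] = {"name": name}
        if self.dummy_metadata is not None:
            metadata["dummy_metadata"] = self.dummy_metadata
        return metadata


print(ShellRunner().run(FakeTrial(index=3, experiment_name="branin")))  # {'name': 'branin_3'}
print(ShellRunner(dummy_metadata="x").run(FakeTrial(index=0)))  # {'name': '0', 'dummy_metadata': 'x'}
```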
- poll_trial_status(trials: Iterable[BaseTrial]) → Dict[TrialStatus, Set[int]]
Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.
NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.
- Parameters:
trials – Trials to poll.
- Returns:
A dictionary mapping TrialStatus to a set of trial indices that have the respective status at the time of polling. Trials that already have a terminal status (ABANDONED, FAILED, COMPLETED) at the time of polling do not need to be included (but may be).
Simulated Backend Runner
- class ax.runners.simulated_backend.SimulatedBackendRunner(simulator: BackendSimulator, sample_runtime_func: Optional[Callable[[BaseTrial], float]] = None)
Bases: Runner
Class for a runner that works with the BackendSimulator.
- poll_trial_status(trials: Iterable[BaseTrial]) → Dict[TrialStatus, Set[int]]
Poll trial status from the BackendSimulator. NOTE: The Scheduler currently marks trials as running when they are created, but some of these trials may actually be queued on the BackendSimulator.
- Returns:
A Dict mapping statuses to sets of trial indices.
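The NOTE above says a trial that is still queued on the simulator may already be marked as running by the Scheduler. A toy, capacity-limited backend makes that distinction concrete. Everything here (`ToyBackend`, its `max_concurrency` setting, the explicit `now` clock, string status names, and the choice to report queued trials as RUNNING) is a hypothetical sketch, not the actual BackendSimulator API.

```python
from collections import defaultdict
from typing import Dict, List, Set


class ToyBackend:
    """Hypothetical FIFO backend: at most `max_concurrency` trials run at once."""

    def __init__(self, max_concurrency: int) -> None:
        self.max_concurrency = max_concurrency
        self.queue: List[int] = []  # submitted but not yet started
        self.running: Dict[int, float] = {}  # trial index -> finish time
        self.completed: Set[int] = set()
        self.runtimes: Dict[int, float] = {}

    def submit(self, index: int, runtime: float) -> None:
        self.runtimes[index] = runtime
        self.queue.append(index)

    def update(self, now: float) -> None:
        # Retire trials whose finish time has passed.
        for index, finish in list(self.running.items()):
            if finish <= now:
                del self.running[index]
                self.completed.add(index)
        # Start queued trials while capacity allows. Simplification: a trial
        # dequeued here is treated as starting at `now`, the time of the poll.
        while self.queue and len(self.running) < self.max_concurrency:
            index = self.queue.pop(0)
            self.running[index] = now + self.runtimes[index]

    def poll(self, now: float) -> Dict[str, Set[int]]:
        """Single poll: advance the simulation to `now`, then report statuses."""
        self.update(now)
        statuses: Dict[str, Set[int]] = defaultdict(set)
        for index in self.completed:
            statuses["COMPLETED"].add(index)
        # Queued trials are reported as RUNNING too: from the scheduler's
        # point of view they have already been launched.
        for index in [*self.running, *self.queue]:
            statuses["RUNNING"].add(index)
        return dict(statuses)


backend = ToyBackend(max_concurrency=1)
backend.submit(0, runtime=2.0)
backend.submit(1, runtime=2.0)
first = backend.poll(now=0.0)  # trial 0 running, trial 1 still queued
queued_after_first = list(backend.queue)
second = backend.poll(now=2.5)  # trial 0 done, trial 1 started at t=2.5
third = backend.poll(now=5.0)  # trial 1 finished at t=4.5
print(first, queued_after_first, second, third)
```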
- run(trial: BaseTrial) → Dict[str, Any]
Start a trial on the BackendSimulator.
- Parameters:
trial – Trial to deploy via the runner.
- Returns:
Dict containing the sampled runtime of the trial.
- stop(trial: BaseTrial, reason: Optional[str] = None) → Dict[str, Any]
Stop a trial on the BackendSimulator.
- Parameters:
trial – Trial to stop on the simulator.
reason – A message containing information why the trial is to be stopped.
- Returns:
A dictionary containing a single key “reason” that maps to the reason passed to the function. If no reason was given, returns an empty dictionary.
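The `run` and `stop` contracts above can be sketched as follows. The `"runtime"` metadata key, the fallback uniform sampler, and the `FakeTrial` stand-in are assumptions for illustration; only the shape of `stop`'s return value (a single `"reason"` key, or an empty dictionary when no reason is given) comes from the documentation above.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class FakeTrial:
    index: int  # hypothetical stand-in for BaseTrial


class SketchSimulatedRunner:
    """Hypothetical runner sketch; not the actual SimulatedBackendRunner."""

    def __init__(self, sample_runtime_func: Optional[Callable[[FakeTrial], float]] = None) -> None:
        # Assumed fallback sampler: uniform runtime in [1.0, 5.0].
        self.sample_runtime_func = sample_runtime_func or (lambda trial: random.uniform(1.0, 5.0))

    def run(self, trial: FakeTrial) -> Dict[str, Any]:
        runtime = self.sample_runtime_func(trial)
        # A real runner would submit the trial to the simulator at this point.
        return {"runtime": runtime}

    def stop(self, trial: FakeTrial, reason: Optional[str] = None) -> Dict[str, Any]:
        # A real runner would ask the simulator to stop the trial at this point.
        return {"reason": reason} if reason is not None else {}


runner = SketchSimulatedRunner(sample_runtime_func=lambda trial: 2.0 + trial.index)
print(runner.run(FakeTrial(index=1)))  # {'runtime': 3.0}
print(runner.stop(FakeTrial(index=1), "early stop"))  # {'reason': 'early stop'}
print(runner.stop(FakeTrial(index=2)))  # {}
```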
- ax.runners.simulated_backend.sample_runtime_unif(trial: BaseTrial, low: float = 1.0, high: float = 5.0) → float
Return a runtime sampled uniformly from [low, high].
- Parameters:
trial – Trial for which to sample runtime.
low – Lower bound of uniform runtime distribution.
high – Upper bound of uniform runtime distribution.
- Returns:
A float representing the simulated trial runtime.
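One standard-library way to get the documented behavior is `random.uniform`, sketched below. The extra `rng` argument is a hypothetical addition for reproducible draws and is not part of the documented signature; `trial` is accepted only to match that signature. Note that `random.uniform` may return `high` itself because of floating-point rounding, which is consistent with the closed interval [low, high].

```python
import random
from typing import Any, Optional


def sample_runtime_unif(
    trial: Any, low: float = 1.0, high: float = 5.0, rng: Optional[random.Random] = None
) -> float:
    """Sketch: draw a simulated runtime uniformly from [low, high].

    `trial` is unused here; `rng` is a hypothetical extra argument.
    """
    if rng is None:
        rng = random.Random()
    return rng.uniform(low, high)


rng = random.Random(0)
samples = [sample_runtime_unif(trial=None, rng=rng) for _ in range(1000)]
print(min(samples) >= 1.0, max(samples) <= 5.0)  # True True
```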
TorchX Runner
- class ax.runners.torchx.TorchXRunner(tracker_base: str, component: Callable[[...], AppDef], component_const_params: Optional[Dict[str, Any]] = None, scheduler: str = 'local', cfg: Optional[Mapping[str, Optional[Union[str, int, float, bool, List[str]]]]] = None)
Bases: Runner
An implementation of ax.core.runner.Runner that delegates job submission to the TorchX Runner. This runner is coupled with the TorchX component, since Ax runners run trials of a single component with different parameters.
It is expected that the experiment parameter names and types match EXACTLY with the component function's args. Component function args that are NOT part of the search space can be passed as component_const_params. The following args are passed automatically if declared in the component function's signature:
- trial_idx (int): the current trial's index
- tracker_base (str): the TorchX tracker's base (typically a URL indicating the base dir of the tracker)
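Because parameter names and types must match the component function's arguments exactly, a pre-flight check with `inspect.signature` could catch mismatches before any job is launched. The helper `check_component_args` is hypothetical (not part of Ax or TorchX); it compares a plain `{name: type}` mapping of search-space parameters, plus the constant parameters, against the signature, and treats `trial_idx` and `tracker_base` as automatically supplied. The component returns a plain dict here instead of a TorchX `AppDef` so the sketch runs without TorchX.

```python
import inspect
from typing import Any, Callable, Dict, List

AUTO_ARGS = {"trial_idx", "tracker_base"}  # supplied by the runner if declared


def check_component_args(
    component: Callable[..., Any],
    search_space_types: Dict[str, type],
    const_params: Dict[str, Any],
) -> List[str]:
    """Return a list of human-readable problems (empty if everything lines up)."""
    problems: List[str] = []
    params = inspect.signature(component).parameters
    for name, value_type in search_space_types.items():
        if name not in params:
            problems.append(f"search-space parameter {name!r} is not a component arg")
        elif params[name].annotation is not value_type:
            problems.append(
                f"{name!r}: search space has {value_type.__name__}, "
                f"component expects {params[name].annotation!r}"
            )
    for name in params:
        if (
            name not in search_space_types
            and name not in const_params
            and name not in AUTO_ARGS
        ):
            problems.append(f"component arg {name!r} has no value source")
    return problems


def trainer_component(x1: int, x2: float, trial_idx: int, tracker_base: str, x3: float, x4: str) -> dict:
    return {}  # hypothetical stand-in; a real component returns a TorchX AppDef


ok = check_component_args(trainer_component, {"x1": int, "x2": float}, {"x3": 1.2, "x4": "barbaz"})
bad = check_component_args(trainer_component, {"x1": float, "x2": float}, {"x3": 1.2})
print(ok)  # []
print(bad)
```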
Example:
    def trainer_component(
        x1: int,
        x2: float,
        trial_idx: int,
        tracker_base: str,
        x3: float,
        x4: str,
    ) -> spec.AppDef:
        # … implementation omitted for brevity …
        pass
The experiment should be set up as:
    parameters=[
        {
            "name": "x1",
            "value_type": "int",
            # … other options…
        },
        {
            "name": "x2",
            "value_type": "float",
            # … other options…
        },
    ]
And the rest of the arguments can be set as:
    TorchXRunner(
        tracker_base="s3://foo/bar",
        component=trainer_component,
        # trial_idx and tracker_base args passed automatically
        # if the function signature declares those args
        component_const_params={"x3": 1.2, "x4": "barbaz"},
    )
Running the experiment as set up above results in each trial running:
    appdef = trainer_component(
        x1=trial.params["x1"],
        x2=trial.params["x2"],
        trial_idx=trial.index,
        tracker_base="s3://foo/bar",
        x3=1.2,
        x4="barbaz",
    )
    torchx.runner.get_runner().run(appdef, …)
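The argument assembly described above can be sketched with the standard library alone: trial parameters and `component_const_params` are always passed, while `trial_idx` and `tracker_base` are added only when the component's signature declares them. `build_component_kwargs` is a hypothetical helper, not the runner's actual internals, and the component below returns a plain dict in place of a TorchX `AppDef` so the sketch runs without TorchX installed.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def build_component_kwargs(
    component: Callable[..., Any],
    trial_params: Dict[str, Any],
    trial_index: int,
    tracker_base: str,
    component_const_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    declared = inspect.signature(component).parameters
    kwargs: Dict[str, Any] = dict(trial_params)
    kwargs.update(component_const_params or {})
    # Automatic args are injected only if the component declares them.
    if "trial_idx" in declared:
        kwargs["trial_idx"] = trial_index
    if "tracker_base" in declared:
        kwargs["tracker_base"] = tracker_base
    return kwargs


def trainer_component(x1: int, x2: float, trial_idx: int, tracker_base: str, x3: float, x4: str) -> Dict[str, Any]:
    # Hypothetical stand-in: a real component would build and return an AppDef.
    return {"x1": x1, "x2": x2, "trial_idx": trial_idx, "tracker_base": tracker_base, "x3": x3, "x4": x4}


kwargs = build_component_kwargs(
    trainer_component,
    trial_params={"x1": 4, "x2": 0.5},
    trial_index=7,
    tracker_base="s3://foo/bar",
    component_const_params={"x3": 1.2, "x4": "barbaz"},
)
appdef = trainer_component(**kwargs)
print(appdef)
```

A component that does not declare `trial_idx` or `tracker_base` simply receives its own parameters, so the same assembly works for both kinds of signature.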
- poll_trial_status(trials: Iterable[BaseTrial]) → Dict[TrialStatus, Set[int]]
Checks the status of any non-terminal trials and returns their indices as a mapping from TrialStatus to a set of indices. Required for runners used with the Ax Scheduler.
NOTE: Does not need to handle waiting between polling calls while trials are running; this function should just perform a single poll.
- Parameters:
trials – Trials to poll.
- Returns:
A dictionary mapping TrialStatus to a set of trial indices that have the respective status at the time of polling. Trials that already have a terminal status (ABANDONED, FAILED, COMPLETED) at the time of polling do not need to be included (but may be).