# Source code for ax.runners.torchx

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import inspect
from collections.abc import Callable, Iterable, Mapping

from logging import Logger
from typing import Any

from ax.core import Trial
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.runner import Runner
from ax.utils.common.logger import get_logger
from pyre_extensions import none_throws

# Module-level logger for this file; among other things it is used to warn
# when the optional ``torchx`` import fails.
logger: Logger = get_logger(__name__)


try:
    from torchx.runner import get_runner, Runner as torchx_Runner
    from torchx.specs import AppDef, AppState, AppStatus, CfgVal

    # Keys under which ``TorchXRunner.run`` records its results in a trial's
    # run metadata.
    TORCHX_APP_HANDLE: str = "torchx_app_handle"
    TORCHX_RUNNER: str = "torchx_runner"
    TORCHX_TRACKER_BASE: str = "torchx_tracker_base"

    # Maps TorchX AppState to Ax's TrialStatus.
    APP_STATE_TO_TRIAL_STATUS: dict[AppState, TrialStatus] = {
        AppState.UNSUBMITTED: TrialStatus.CANDIDATE,
        AppState.SUBMITTED: TrialStatus.STAGED,
        AppState.PENDING: TrialStatus.STAGED,
        AppState.RUNNING: TrialStatus.RUNNING,
        AppState.SUCCEEDED: TrialStatus.COMPLETED,
        AppState.CANCELLED: TrialStatus.ABANDONED,
        AppState.FAILED: TrialStatus.FAILED,
        AppState.UNKNOWN: TrialStatus.FAILED,
    }

    class TorchXRunner(Runner):
        """
        An implementation of ``ax.core.runner.Runner`` that delegates job
        submission to the TorchX Runner. This runner is coupled with the TorchX
        component since Ax runners run trials of a single component with
        different parameters.

        It is expected that the experiment parameter names and types match
        EXACTLY with component's function args. Component function args that
        are NOT part of the search space can be passed as
        ``component_const_params``. The following args are passed automatically
        if declared in the component function's signature (either as regular or
        keyword-only parameters):

        * ``trial_idx (int)``: current trial's index
        * ``experiment_name (str)``: name of the experiment the trial belongs to
        * ``tracker_base (str)``: torchx tracker's base (typically a URL
          indicating the base dir of the tracker)

        Example:

        .. code-block:: python

            def trainer_component(
                x1: int,
                x2: float,
                trial_idx: int,
                tracker_base: str,
                x3: float,
                x4: str) -> specs.AppDef:
                # ... implementation omitted for brevity ...
                pass

        The experiment should be set up as:

        .. code-block:: python

            parameters=[
                {
                    "name": "x1",
                    "value_type": "int",
                    # ... other options...
                },
                {
                    "name": "x2",
                    "value_type": "float",
                    # ... other options...
                }
            ]

        And the rest of the arguments can be set as:

        .. code-block:: python

            TorchXRunner(
                tracker_base="s3://foo/bar",
                component=trainer_component,
                # trial_idx and tracker_base args passed automatically
                # if the function signature declares those args
                component_const_params={"x3": 1.2, "x4": "barbaz"})

        Running the experiment as set up above results in each trial running:

        .. code-block:: python

            appdef = trainer_component(
                x1=trial.params["x1"],
                x2=trial.params["x2"],
                trial_idx=trial.index,
                tracker_base="s3://foo/bar",
                x3=1.2,
                x4="barbaz")

            torchx.runner.get_runner().run(appdef, ...)
        """

        def __init__(
            self,
            tracker_base: str,
            component: Callable[..., AppDef],
            component_const_params: dict[str, Any] | None = None,
            scheduler: str = "local",
            cfg: Mapping[str, CfgVal] | None = None,
        ) -> None:
            """
            Args:
                tracker_base: Value passed as ``tracker_base`` to the component
                    (when it declares that argument) and recorded in each
                    trial's run metadata.
                component: Callable that builds the ``AppDef`` for a trial.
                component_const_params: Fixed keyword arguments passed to the
                    component on every trial. Arm parameters with the same name
                    take precedence over these.
                scheduler: Name of the TorchX scheduler jobs are submitted to.
                cfg: Scheduler run config forwarded to the TorchX runner's
                    ``run`` call.
            """
            self._component: Callable[..., AppDef] = component
            self._scheduler: str = scheduler
            self._cfg: Mapping[str, CfgVal] | None = cfg
            # need to use the same runner in case it has state
            # e.g. torchx's local_scheduler has state hence need to poll status
            # on the same scheduler instance
            self._torchx_runner: torchx_Runner = get_runner()
            self._tracker_base = tracker_base
            self._component_const_params: dict[str, Any] = (
                component_const_params or {}
            )

        def run(self, trial: BaseTrial) -> dict[str, Any]:
            """
            Submits the trial (which maps to an AppDef) as a job
            onto the scheduler using ``torchx.runner``.

            .. note:: only supports `Trial` (not `BatchTrial`).

            Args:
                trial: The trial to launch.

            Returns:
                Run metadata holding the TorchX app handle, the TorchX runner
                used for submission, and the tracker base.

            Raises:
                ValueError: If ``trial`` is not a ``Trial``.
            """
            if not isinstance(trial, Trial):
                raise ValueError(
                    f"{type(trial)} is not supported. Check your experiment setup"
                )

            # Precedence (lowest to highest): constant params, the arm's
            # parameters, then the automatically supplied arguments below.
            parameters = dict(self._component_const_params)
            parameters.update(none_throws(trial.arm).parameters)

            # ``inspect.signature`` (unlike ``getfullargspec(...).args``) also
            # reports keyword-only parameters and follows ``functools.wraps``
            # chains, so decorated or keyword-only components still receive
            # the automatic arguments. Only kinds that can be passed by keyword
            # are considered, since the component is called with ``**kwargs``.
            keyword_kinds = (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
            component_args = {
                param.name
                for param in inspect.signature(self._component).parameters.values()
                if param.kind in keyword_kinds
            }

            if "trial_idx" in component_args:
                parameters["trial_idx"] = trial.index

            if "experiment_name" in component_args:
                parameters["experiment_name"] = trial.experiment.name

            if "tracker_base" in component_args:
                parameters["tracker_base"] = self._tracker_base

            appdef = self._component(**parameters)
            app_handle = self._torchx_runner.run(appdef, self._scheduler, self._cfg)
            return {
                TORCHX_APP_HANDLE: app_handle,
                TORCHX_RUNNER: self._torchx_runner,
                TORCHX_TRACKER_BASE: self._tracker_base,
            }

        def poll_trial_status(
            self, trials: Iterable[BaseTrial]
        ) -> dict[TrialStatus, set[int]]:
            """Query TorchX for the state of each trial's app.

            Each trial is polled through the TorchX runner stored in its run
            metadata by ``run``, and the resulting ``AppState`` is translated
            via ``APP_STATE_TO_TRIAL_STATUS``.

            Args:
                trials: Trials previously launched by ``run``.

            Returns:
                Mapping from ``TrialStatus`` to the indices of the trials in
                that status. Statuses with no trials are omitted.
            """
            trial_statuses: dict[TrialStatus, set[int]] = {}
            for trial in trials:
                app_handle: str = trial.run_metadata[TORCHX_APP_HANDLE]
                torchx_runner = trial.run_metadata[TORCHX_RUNNER]
                app_status: AppStatus | None = torchx_runner.status(app_handle)
                if app_status is None:
                    # TorchX returns ``None`` when the app no longer exists or
                    # was never submitted. Treat it like ``AppState.UNKNOWN``
                    # instead of crashing on ``None.state``.
                    logger.warning(
                        "TorchX returned no status for app %s (trial %s); "
                        "treating it as %s.",
                        app_handle,
                        trial.index,
                        AppState.UNKNOWN,
                    )
                    app_state = AppState.UNKNOWN
                else:
                    app_state = app_status.state
                trial_status = APP_STATE_TO_TRIAL_STATUS[app_state]
                trial_statuses.setdefault(trial_status, set()).add(trial.index)
            return trial_statuses

        def stop(
            self, trial: BaseTrial, reason: str | None = None
        ) -> dict[str, Any]:
            """Kill the given trial.

            Args:
                trial: Trial whose TorchX app should be stopped.
                reason: Optional reason for stopping the trial.

            Returns:
                ``{"reason": reason}`` when a non-empty reason is given,
                otherwise an empty dict.
            """
            app_handle: str = trial.run_metadata[TORCHX_APP_HANDLE]
            self._torchx_runner.stop(app_handle)
            return {"reason": reason} if reason else {}

except ImportError:
    logger.warning(
        "torchx package not found. If you would like to use TorchXRunner, please "
        "install torchx."
    )