Source code for ax.service.managed_loop

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import inspect
import logging
from typing import Iterable, List, Optional, Tuple

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.experiment import Experiment
from ax.core.formatting_utils import data_and_evaluations_from_raw_data
from ax.core.trial import Trial
from ax.core.types import (
    TEvaluationFunction,
    TEvaluationOutcome,
    TModelPredictArm,
    TParameterization,
)
from ax.core.utils import get_pending_observation_features
from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION
from ax.exceptions.core import SearchSpaceExhausted, UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.utils.best_point import (
    get_best_parameters_from_model_predictions,
    get_best_raw_objective_point,
)
from ax.service.utils.instantiation import InstantiationBase, TParameterRepresentation
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none


logger: logging.Logger = get_logger(__name__)


[docs]class OptimizationLoop: """Managed optimization loop, in which Ax oversees deployment of trials and gathering data.""" def __init__( self, experiment: Experiment, evaluation_function: TEvaluationFunction, total_trials: int = 20, arms_per_trial: int = 1, random_seed: Optional[int] = None, wait_time: int = 0, run_async: bool = False, # TODO[Lena], generation_strategy: Optional[GenerationStrategy] = None, ) -> None: assert not run_async, "OptimizationLoop does not yet support async." self.wait_time = wait_time self.total_trials = total_trials self.arms_per_trial = arms_per_trial self.random_seed = random_seed self.evaluation_function: TEvaluationFunction = evaluation_function assert len(experiment.trials) == 0, ( "Optimization Loop should not be initialized with an experiment " "that has trials already." ) self.experiment = experiment if generation_strategy is None: # pyre-fixme[4]: Attribute must be annotated. self.generation_strategy = choose_generation_strategy( search_space=experiment.search_space, use_batch_trials=self.arms_per_trial > 1, random_seed=self.random_seed, experiment=experiment, ) else: self.generation_strategy = generation_strategy self.current_trial = 0
[docs] @staticmethod def with_evaluation_function( parameters: List[TParameterRepresentation], evaluation_function: TEvaluationFunction, experiment_name: Optional[str] = None, objective_name: Optional[str] = None, minimize: bool = False, parameter_constraints: Optional[List[str]] = None, outcome_constraints: Optional[List[str]] = None, total_trials: int = 20, arms_per_trial: int = 1, wait_time: int = 0, random_seed: Optional[int] = None, generation_strategy: Optional[GenerationStrategy] = None, ) -> "OptimizationLoop": """Constructs a synchronous `OptimizationLoop` using an evaluation function.""" experiment = InstantiationBase.make_experiment( name=experiment_name, parameters=parameters, objectives={objective_name: "minimize" if minimize else "maximize"} if objective_name else None, parameter_constraints=parameter_constraints, outcome_constraints=outcome_constraints, ) return OptimizationLoop( experiment=experiment, total_trials=total_trials, arms_per_trial=arms_per_trial, random_seed=random_seed, wait_time=wait_time, generation_strategy=generation_strategy, evaluation_function=evaluation_function, )
[docs] @classmethod def with_runners_and_metrics( cls, parameters: List[TParameterRepresentation], path_to_runner: str, paths_to_metrics: List[str], experiment_name: Optional[str] = None, objective_name: Optional[str] = None, minimize: bool = False, parameter_constraints: Optional[List[str]] = None, outcome_constraints: Optional[List[str]] = None, total_trials: int = 20, arms_per_trial: int = 1, wait_time: int = 0, random_seed: Optional[int] = None, ) -> "OptimizationLoop": """Constructs an asynchronous `OptimizationLoop` using Ax runners and metrics.""" # NOTE: Could use `Scheduler` to implement this if needed. raise NotImplementedError
def _call_evaluation_function( self, parameterization: TParameterization, weight: Optional[float] = None ) -> TEvaluationOutcome: signature = inspect.signature(self.evaluation_function) num_evaluation_function_params = len(signature.parameters.items()) if num_evaluation_function_params == 1: # pyre-ignore [20]: Can't run instance checks on subscripted generics. evaluation = self.evaluation_function(parameterization) elif num_evaluation_function_params == 2: # pyre-ignore [19]: Can't run instance checks on subscripted generics. evaluation = self.evaluation_function(parameterization, weight) else: raise UserInputError( "Evaluation function must take either one parameter " "(parameterization) or two parameters (parameterization and weight)." ) return evaluation def _get_new_trial(self) -> BaseTrial: if self.arms_per_trial == 1: return self.experiment.new_trial( generator_run=self.generation_strategy.gen( experiment=self.experiment, pending_observations=get_pending_observation_features( experiment=self.experiment ), ) ) elif self.arms_per_trial > 1: return self.experiment.new_batch_trial( generator_run=self.generation_strategy.gen( experiment=self.experiment, n=self.arms_per_trial ) ) else: raise UserInputError( f"Invalid number of arms per trial: {self.arms_per_trial}" ) def _get_weights_by_arm( self, trial: BaseTrial ) -> Iterable[Tuple[Arm, Optional[float]]]: if isinstance(trial, Trial): if trial.arm is not None: return [(not_none(trial.arm), None)] return [] elif isinstance(trial, BatchTrial): return trial.normalized_arm_weights().items() else: raise UserInputError(f"Invalid trial type: {type(trial)}")
[docs] @retry_on_exception( logger=logger, exception_types=(RuntimeError,), suppress_all_errors=False, wrap_error_message_in=CHOLESKY_ERROR_ANNOTATION, ) def run_trial(self) -> None: """Run a single step of the optimization plan.""" if self.current_trial >= self.total_trials: raise ValueError("Optimization is complete, cannot run another trial.") logger.info(f"Running optimization trial {self.current_trial + 1}...") trial = self._get_new_trial() trial.mark_running(no_runner_required=True) _, data = data_and_evaluations_from_raw_data( raw_data={ arm.name: self._call_evaluation_function(arm.parameters, weight) for arm, weight in self._get_weights_by_arm(trial) }, trial_index=self.current_trial, sample_sizes={}, data_type=self.experiment.default_data_type, metric_names=not_none( self.experiment.optimization_config ).objective.metric_names, ) self.experiment.attach_data(data=data) trial.mark_completed() self.current_trial += 1
[docs] def full_run(self) -> OptimizationLoop: """Runs full optimization loop as defined in the provided optimization plan.""" num_steps = self.total_trials logger.info(f"Started full optimization with {num_steps} steps.") for _ in range(num_steps): try: self.run_trial() except SearchSpaceExhausted as err: logger.info( f"Stopped optimization as the search space is exhaused. Message " f"from generation strategy: {err}." ) return self except Exception: logger.exception("Encountered exception during optimization: ") return self return self
[docs] def get_best_point(self) -> Tuple[TParameterization, Optional[TModelPredictArm]]: """Obtains the best point encountered in the course of this optimization.""" # Find latest trial which has a generator_run attached and get its predictions model_predictions = get_best_parameters_from_model_predictions( experiment=self.experiment, models_enum=Models ) if model_predictions is not None: return model_predictions # Could not find through model, default to using raw objective. parameterization, values = get_best_raw_objective_point( experiment=self.experiment ) # For values, grab just the means to conform to TModelPredictArm format. return ( parameterization, ( {k: v[0] for k, v in values.items()}, # v[0] is mean {k: {k: v[1] * v[1]} for k, v in values.items()}, # v[1] is sem ), )
[docs] def get_current_model(self) -> Optional[ModelBridge]: """Obtain the most recently used model in optimization.""" return self.generation_strategy.model
[docs]def optimize( parameters: List[TParameterRepresentation], evaluation_function: TEvaluationFunction, experiment_name: Optional[str] = None, objective_name: Optional[str] = None, minimize: bool = False, parameter_constraints: Optional[List[str]] = None, outcome_constraints: Optional[List[str]] = None, total_trials: int = 20, arms_per_trial: int = 1, random_seed: Optional[int] = None, generation_strategy: Optional[GenerationStrategy] = None, ) -> Tuple[ TParameterization, Optional[TModelPredictArm], Experiment, Optional[ModelBridge] ]: """Construct and run a full optimization loop.""" loop = OptimizationLoop.with_evaluation_function( parameters=parameters, objective_name=objective_name, evaluation_function=evaluation_function, experiment_name=experiment_name, minimize=minimize, parameter_constraints=parameter_constraints, outcome_constraints=outcome_constraints, total_trials=total_trials, arms_per_trial=arms_per_trial, random_seed=random_seed, generation_strategy=generation_strategy, ) loop.full_run() parameterization, values = loop.get_best_point() return parameterization, values, loop.experiment, loop.get_current_model()