Source code for ax.core.trial

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from functools import partial

from logging import Logger

from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, immutable_once_run
from import Data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.exceptions.core import UnsupportedError
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import not_none
from pyre_extensions import override

logger: Logger = get_logger(__name__)


# pyre-fixme[5]: Global expression must be annotated.
round_floats_for_logging = partial(

    # import as module to make sphinx-autodoc-typehints happy
    from ax import core  # noqa F401

[docs]class Trial(BaseTrial): """Trial that only has one attached arm and no arm weights. Args: experiment: Experiment, to which this trial is attached. generator_run: GeneratorRun, associated with this trial. Trial has only one generator run (of just one arm) attached to it. This can also be set later through `add_arm` or `add_generator_run`, but a trial's associated genetor run is immutable once set. trial_type: Type of this trial, if used in MultiTypeExperiment. ttl_seconds: If specified, trials will be considered failed after this many seconds since the time the trial was ran, unless the trial is completed before then. Meant to be used to detect 'dead' trials, for which the evaluation process might have crashed etc., and which should be considered failed after their 'time to live' has passed. index: If specified, the trial's index will be set accordingly. This should generally not be specified, as in the index will be automatically determined based on the number of existing trials. This is only used for the purpose of loading from storage. """ def __init__( self, experiment: core.experiment.Experiment, generator_run: Optional[GeneratorRun] = None, trial_type: Optional[str] = None, ttl_seconds: Optional[int] = None, index: Optional[int] = None, ) -> None: super().__init__( experiment=experiment, trial_type=trial_type, ttl_seconds=ttl_seconds, index=index, ) # pyre-fixme[4]: Attribute must be annotated. self._generator_run = None if generator_run is not None: self.add_generator_run(generator_run=generator_run) @property def generator_run(self) -> Optional[GeneratorRun]: """Generator run attached to this trial.""" return self._generator_run # pyre-ignore[6]: T77111662. @copy_doc(BaseTrial.generator_runs) @property def generator_runs(self) -> List[GeneratorRun]: gr = self._generator_run return [gr] if gr is not None else [] @property def arm(self) -> Optional[Arm]: """The arm associated with this batch.""" if self.generator_run is None: return None generator_run = not_none(self.generator_run) if len(generator_run.arms) == 0: return None elif len(generator_run.arms) > 1: raise ValueError( "Generator run associated with this trial included multiple " "arms, but trial expects only one." ) return generator_run.arms[0] @immutable_once_run def add_arm( self, arm: Arm, candidate_metadata: Optional[Dict[str, Any]] = None ) -> Trial: """Add arm to the trial. Returns: The trial instance. """ return self.add_generator_run( generator_run=GeneratorRun( arms=[arm],, # pyre-ignore[6]: In call `GeneratorRun.__init__`, for 3rd parameter # `candidate_metadata_by_arm_signature` # expected `Optional[Dict[str, Optional[Dict[str, typing.Any]]]]` # but got `Optional[Dict[str, Dict[str, typing.Any]]]` candidate_metadata_by_arm_signature=( None if candidate_metadata is None else {arm.signature: candidate_metadata.copy()} ), ) ) @immutable_once_run def add_generator_run( self, generator_run: GeneratorRun, multiplier: float = 1.0 ) -> Trial: """Add a generator run to the trial. Note: since trial includes only one arm, this will raise a ValueError if the generator run includes multiple arms. Returns: The trial instance. """ if len(generator_run.arms) > 1: raise ValueError( "Trial includes only one arm, but this generator run " "included multiple." ) self.experiment.search_space.check_types( generator_run.arms[0].parameters, raise_error=True ) self._check_existing_and_name_arm(generator_run.arms[0]) self._generator_run = generator_run generator_run.index = 0 self._set_generation_step_index( generation_step_index=generator_run._generation_step_index ) return self @property def arms(self) -> List[Arm]: """All arms attached to this trial. Returns: arms: list of a single arm attached to this trial if there is one, else None. """ return [self.arm] if self.arm is not None else [] @property def arms_by_name(self) -> Dict[str, Arm]: """Dictionary of all arms attached to this trial with their names as keys. Returns: arms: dictionary of a single arm name to arm if one is attached to this trial, else None. """ return { self.arm} if self.arm is not None else {} @property def abandoned_arms(self) -> List[Arm]: """Abandoned arms attached to this trial.""" return ( [not_none(self.arm)] if self.generator_run is not None and self.arm is not None and self.is_abandoned else [] ) @property def objective_mean(self) -> float: """Objective mean for the arm attached to this trial, retrieved from the latest data available for the objective for the trial. Note: the retrieved objective is the experiment-level objective at the time of the call to `objective_mean`, which is not necessarily the objective that was set at the time the trial was created or ran. """ # For SimpleExperiment, fetch_data just executes eval_trial. opt_config = self.experiment.optimization_config if opt_config is None: raise ValueError( "Experiment optimization config (and thus the objective) is not set." ) return self.get_metric_mean(
[docs] def get_metric_mean(self, metric_name: str) -> float: """Metric mean for the arm attached to this trial, retrieved from the latest data available for the metric for the trial. """ fetch_result = self.lookup_data() try: df = fetch_result.df return df.loc[df["metric_name"] == metric_name].iloc[0]["mean"] except IndexError: raise ValueError(f"Metric {metric_name} not yet in data for trial.")
def __repr__(self) -> str: return ( "Trial(" f"experiment_name='{self._experiment._name}', " f"index={self._index}, " f"status={self._status}, " f"arm={self.arm})" ) def _get_candidate_metadata_from_all_generator_runs( self, ) -> Dict[str, TCandidateMetadata]: """Retrieves candidate metadata from the generator run on this batch trial in the form of { arm name -> candidate metadata} mapping. """ gr = self.generator_run if gr is None or gr.candidate_metadata_by_arm_signature is None: return {} cand_metadata = not_none(gr.candidate_metadata_by_arm_signature) return { cand_metadata.get(a.signature) for a in gr.arms} def _get_candidate_metadata(self, arm_name: str) -> TCandidateMetadata: """Retrieves candidate metadata for a specific arm.""" gr = self.generator_run if gr is None or gr.arms[0].name != arm_name: raise ValueError( f"Arm by name {arm_name} is not part of trial #{self.index}." ) if gr.candidate_metadata_by_arm_signature is None: return None arm = gr.arms[0] return not_none(gr.candidate_metadata_by_arm_signature).get(arm.signature)
[docs] def validate_data_for_trial(self, data: Data) -> None: """Utility method to validate data before further processing.""" for metric_name in data.df["metric_name"].values: if metric_name not in self.experiment.metrics: f"Data was logged for metric {metric_name} that was not yet " "tracked on the experiment. Please specify `tracking_metric_" "names` argument in AxClient.create_experiment to add tracking " "metrics to the experiment. Without those, all data users " "specify is still attached to the experiment, but will not be " "fetched in `experiment.fetch_data()`, but you can still use " "`experiment.lookup_data_for_trial` to get all attached data." )
[docs] def update_trial_data( self, raw_data: TEvaluationOutcome, metadata: Optional[Dict[str, Union[str, int]]] = None, sample_size: Optional[int] = None, combine_with_last_data: bool = False, ) -> str: """Utility method that attaches data to a trial and returns an update message. Args: raw_data: Evaluation data for the trial. Can be a mapping from metric name to a tuple of mean and SEM, just a tuple of mean and SEM if only one metric in optimization, or just the mean if SEM is unknown (then Ax will infer observation noise level). Can also be a list of (fidelities, mapping from metric name to a tuple of mean and SEM). metadata: Additional metadata to track about this run, optional. sample_size: Number of samples collected for the underlying arm, optional. combine_with_last_data: Whether to combine the given data with the data that was previously attached to the trial. See `Experiment.attach_data` for a detailed explanation. Returns: A string message summarizing the update. """ arm_name = not_none(self.arm).name sample_sizes = {arm_name: sample_size} if sample_size else {} raw_data_by_arm = {arm_name: raw_data} evaluations, data = self._make_evaluations_and_data( raw_data=raw_data_by_arm, metadata=metadata, sample_sizes=sample_sizes, ) self.validate_data_for_trial(data=data) self.update_run_metadata(metadata=metadata or {}) self.experiment.attach_data( data=data, combine_with_last_data=combine_with_last_data ) return str( round_floats_for_logging(item=evaluations[next(iter(evaluations.keys()))]) )
[docs] def clone_to( self, experiment: Optional[core.experiment.Experiment] = None, ) -> Trial: """Clone the trial and attach it to the specified experiment. If no experiment is provided, the original experiment will be used. Args: experiment: The experiment to which the cloned trial will belong. If unspecified, uses the current experiment. Returns: A new instance of the trial. """ experiment = self._experiment if experiment is None else experiment new_trial = experiment.new_trial( ttl_seconds=self.ttl_seconds, trial_type=self.trial_type ) if self.generator_run is not None: new_trial.add_generator_run(self.generator_run.clone()) self._update_trial_attrs_on_clone(new_trial=new_trial) return new_trial
@override def _raise_cant_attach_if_completed(self) -> None: """ Helper method used by `validate_can_attach_data` to raise an error if the user tries to attach data to a completed trial. Subclasses such as `Trial` override this by suggesting a remediation. """ raise UnsupportedError( f"Trial {self.index} has already been completed with data. " "To add more data to it (for example, for a different metric), " f"use `{self.__class__.__name__}.update_trial_data()`." )