# Source code for ax.core.trial
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Dict, List, Optional
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, immutable_once_run
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.utils.common.typeutils import not_none
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import core # noqa F401 # pragma: no cover
class Trial(BaseTrial):
    """Trial that only has one attached arm and no arm weights.

    Args:
        experiment: Experiment, to which this trial is attached.
        generator_run: GeneratorRun, associated with this trial.
            Trial has only one generator run (of just one arm)
            attached to it. This can also be set later through `add_arm`
            or `add_generator_run`, but a trial's associated generator run is
            immutable once set.
        trial_type: Type of this trial, if used in MultiTypeExperiment.
        ttl_seconds: If specified, trials will be considered failed after
            this many seconds since the time the trial was run, unless the
            trial is completed before then. Meant to be used to detect
            'dead' trials, for which the evaluation process might have
            crashed etc., and which should be considered failed after
            their 'time to live' has passed.
    """

    def __init__(
        self,
        experiment: "core.experiment.Experiment",
        generator_run: Optional[GeneratorRun] = None,
        trial_type: Optional[str] = None,
        ttl_seconds: Optional[int] = None,
    ) -> None:
        super().__init__(
            experiment=experiment, trial_type=trial_type, ttl_seconds=ttl_seconds
        )
        # Start with no generator run; route any provided one through
        # `add_generator_run` so it gets the same validation as later additions.
        self._generator_run = None
        if generator_run is not None:
            self.add_generator_run(generator_run=generator_run)

    @property
    def generator_run(self) -> Optional[GeneratorRun]:
        """Generator run attached to this trial, or None if not yet set."""
        return self._generator_run

    @property
    def arm(self) -> Optional[Arm]:
        """The arm associated with this trial, or None if no generator run
        is attached.

        Raises:
            ValueError: If the attached generator run holds more than one arm.
        """
        generator_run = self.generator_run
        if generator_run is None:
            return None
        if len(generator_run.arms) > 1:
            raise ValueError(  # pragma: no cover
                "Generator run associated with this trial included multiple "
                "arms, but trial expects only one."
            )
        return generator_run.arms[0]

    @immutable_once_run
    def add_arm(self, arm: Arm) -> "Trial":
        """Add arm to the trial.

        The arm is wrapped in a single-arm generator run of type
        ``GeneratorRunType.MANUAL`` and attached via `add_generator_run`.

        Returns:
            The trial instance.
        """
        return self.add_generator_run(
            generator_run=GeneratorRun(arms=[arm], type=GeneratorRunType.MANUAL.name)
        )

    @immutable_once_run
    def add_generator_run(
        self, generator_run: GeneratorRun, multiplier: float = 1.0
    ) -> "Trial":
        """Add a generator run to the trial.

        Note: since trial includes only one arm, this will raise a ValueError if
        the generator run includes multiple arms or no arms at all.

        Args:
            generator_run: The generator run to attach; must hold exactly
                one arm.
            multiplier: Not used by this method.
                NOTE(review): presumably kept so the signature lines up with
                other trial types -- confirm against `BaseTrial`.

        Returns:
            The trial instance.

        Raises:
            ValueError: If the generator run does not contain exactly one arm.
        """
        arms = generator_run.arms
        if len(arms) > 1:
            raise ValueError(
                "Trial includes only one arm, but this generator run "
                "included multiple."
            )
        # Guard the `arms[0]` access below: an empty generator run would
        # otherwise surface as an uninformative IndexError.
        if not arms:
            raise ValueError(
                "Trial requires exactly one arm, but this generator run "
                "included none."
            )
        arm = arms[0]
        self.experiment.search_space.check_types(arm.parameters, raise_error=True)
        self._check_existing_and_name_arm(arm)
        self._generator_run = generator_run
        # A trial holds a single generator run, so its index is always 0.
        generator_run.index = 0
        self._set_generation_step_index(
            generation_step_index=generator_run._generation_step_index
        )
        return self

    @property
    def arms(self) -> List[Arm]:
        """All arms attached to this trial.

        Returns:
            arms: list of a single arm attached to this trial if there is
                one, else an empty list.
        """
        arm = self.arm
        return [arm] if arm is not None else []

    @property
    def arms_by_name(self) -> Dict[str, Arm]:
        """Dictionary of all arms attached to this trial with their names
        as keys.

        Returns:
            arms: dictionary of a single arm name to arm if one is attached
                to this trial, else an empty dictionary.
        """
        arm = self.arm
        return {arm.name: arm} if arm is not None else {}

    @property
    def abandoned_arms(self) -> List[Arm]:
        """Abandoned arms attached to this trial.

        Returns:
            A one-element list with this trial's arm if an arm is attached
            and the trial is abandoned, else an empty list.
        """
        # `self.arm` is already None when no generator run is attached, so a
        # separate check on the generator run is not needed.
        arm = self.arm
        if arm is not None and self.is_abandoned:
            return [arm]
        return []

    @property
    def objective_mean(self) -> float:
        """Objective mean for the arm attached to this trial, retrieved from the
        latest data available for the objective for the trial.

        Note: the retrieved objective is the experiment-level objective at the
        time of the call to `objective_mean`, which is not necessarily the
        objective that was set at the time the trial was created or ran.

        Raises:
            ValueError: If no data was retrieved for the trial, or if the
                experiment has no optimization config.
        """
        # For SimpleExperiment, fetch_data just executes eval_trial.
        df = self.fetch_data().df
        if df.empty:
            raise ValueError(f"No data was retrieved for trial {self.index}.")
        opt_config = self.experiment.optimization_config
        if opt_config is None:
            raise ValueError(  # pragma: no cover
                "Experiment optimization config (and thus the objective) is not set."
            )
        return self.get_metric_mean(metric_name=opt_config.objective.metric.name)

    def get_metric_mean(self, metric_name: str) -> float:
        """Metric mean for the arm attached to this trial, retrieved from the
        latest data available for the metric for the trial.

        Args:
            metric_name: Name of the metric to look up in the fetched data.

        Returns:
            The ``mean`` value of the first fetched row whose ``metric_name``
            matches.

        Raises:
            ValueError: If the fetched data has no row for ``metric_name``.
        """
        # For SimpleExperiment, fetch_data just executes eval_trial.
        df = self.fetch_data().df
        try:
            return df.loc[df["metric_name"] == metric_name].iloc[0]["mean"]
        except IndexError as err:  # pragma: no cover
            # `.iloc[0]` on an empty selection means the metric is absent.
            raise ValueError(
                f"Metric {metric_name} not yet in data for trial."
            ) from err

    def __repr__(self) -> str:
        """Return a string showing experiment name, index, status and arm."""
        return (
            "Trial("
            f"experiment_name='{self._experiment._name}', "
            f"index={self._index}, "
            f"status={self._status}, "
            f"arm={self.arm})"
        )