Source code for ax.core.trial
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Dict, List, Optional

import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, immutable_once_run
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.utils.common.typeutils import not_none

if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import core  # noqa F401  # pragma: no cover
class Trial(BaseTrial):
    """Trial that only has one attached arm and no arm weights.

    Args:
        experiment: Experiment, to which this trial is attached.
        generator_run: Generator run associated with this trial. Trial has
            only one generator run (and thus arm) attached to it. This can
            also be set later through `add_arm` or `add_generator_run`, but a
            trial's associated generator run is immutable once set.
        trial_type: Type of this trial; forwarded unchanged to
            `BaseTrial.__init__`.
    """

    def __init__(
        self,
        experiment: "core.experiment.Experiment",
        generator_run: Optional[GeneratorRun] = None,
        trial_type: Optional[str] = None,
    ) -> None:
        super().__init__(experiment=experiment, trial_type=trial_type)
        self._generator_run: Optional[GeneratorRun] = None
        if generator_run is not None:
            # Route through `add_generator_run` so the same validation applies
            # as when the generator run is attached after construction.
            self.add_generator_run(generator_run=generator_run)

    @property
    def generator_run(self) -> Optional[GeneratorRun]:
        """Generator run attached to this trial, or None if none is attached."""
        return self._generator_run

    @property
    def arm(self) -> Optional[Arm]:
        """The arm associated with this trial, or None if no generator run
        is attached.

        Raises:
            ValueError: If the attached generator run contains multiple arms.
        """
        # Bind once so the Optional is narrowed for the rest of the method.
        generator_run = self.generator_run
        if generator_run is None:
            return None
        arms = generator_run.arms
        if len(arms) > 1:
            raise ValueError(  # pragma: no cover
                "Generator run associated with this trial included multiple "
                "arms, but trial expects only one."
            )
        return arms[0]

    @immutable_once_run
    def add_arm(self, arm: Arm) -> "Trial":
        """Add arm to the trial.

        The arm is wrapped in a single-arm generator run of type
        `GeneratorRunType.MANUAL` and attached via `add_generator_run`.

        Returns:
            The trial instance.
        """
        return self.add_generator_run(
            generator_run=GeneratorRun(arms=[arm], type=GeneratorRunType.MANUAL.name)
        )

    @immutable_once_run
    def add_generator_run(
        self, generator_run: GeneratorRun, multiplier: float = 1.0
    ) -> "Trial":
        """Add a generator run to the trial.

        Note: since trial includes only one arm, this will raise a ValueError if
        the generator run does not include exactly one arm.

        Args:
            generator_run: Generator run to attach; must contain exactly one
                arm. Its `index` is set to 0 as a side effect.
            multiplier: Not used by this method; a single-arm trial has no
                arm weights.

        Returns:
            The trial instance.

        Raises:
            ValueError: If the generator run includes no arms or multiple arms.
        """
        arms = generator_run.arms
        if len(arms) > 1:
            raise ValueError(
                "Trial includes only one arm, but this generator run "
                "included multiple."
            )
        if not arms:
            # Without this guard the `arms[0]` lookups below would fail with
            # an uninformative IndexError.
            raise ValueError(
                "Trial requires exactly one arm, but this generator run "
                "included none."
            )
        arm = arms[0]
        self.experiment.search_space.check_types(arm.parameters, raise_error=True)
        self._check_existing_and_name_arm(arm)
        self._generator_run = generator_run
        generator_run.index = 0
        return self

    @property
    def arms(self) -> List[Arm]:
        """All arms attached to this trial.

        Returns:
            arms: list of a single arm attached to this trial if there is one,
                else an empty list.
        """
        arm = self.arm
        return [arm] if arm is not None else []

    @property
    def arms_by_name(self) -> Dict[str, Arm]:
        """Dictionary of all arms attached to this trial with their names
        as keys.

        Returns:
            arms: dictionary of a single arm name to arm if one is attached to
                this trial, else an empty dictionary.
        """
        arm = self.arm
        return {arm.name: arm} if arm is not None else {}

    @property
    def abandoned_arms(self) -> List[Arm]:
        """Abandoned arms attached to this trial.

        Returns:
            A one-element list holding this trial's arm if an arm is attached
            and the trial is abandoned, else an empty list.
        """
        if self.arm is None or not self.is_abandoned:
            return []
        return [not_none(self.arm)]

    @property
    def objective_mean(self) -> float:
        """Objective mean for the arm attached to this trial, retrieved from the
        latest data available for the objective for the trial.

        Note: the retrieved objective is the experiment-level objective at the
        time of the call to `objective_mean`, which is not necessarily the
        objective that was set at the time the trial was created or ran.

        Raises:
            ValueError: If no data was retrieved for the trial, if the
                experiment has no optimization config, or if the objective
                metric is not present in the retrieved data.
        """
        # For SimpleExperiment, fetch_data just executes eval_trial, so fetch
        # exactly once and reuse the same frame for the lookup below.
        df = self.fetch_data().df
        if df.empty:
            raise ValueError(f"No data was retrieved for trial {self.index}.")
        opt_config = self.experiment.optimization_config
        if opt_config is None:
            raise ValueError(  # pragma: no cover
                "Experiment optimization config (and thus the objective) is not set."
            )
        return self._metric_mean_from_df(
            df=df, metric_name=opt_config.objective.metric.name
        )

    def get_metric_mean(self, metric_name: str) -> float:
        """Metric mean for the arm attached to this trial, retrieved from the
        latest data available for the metric for the trial.

        Args:
            metric_name: Name of the metric to look up in the fetched data.

        Raises:
            ValueError: If the metric is not present in the fetched data.
        """
        # For SimpleExperiment, fetch_data just executes eval_trial.
        df = self.fetch_data().df
        return self._metric_mean_from_df(df=df, metric_name=metric_name)

    @staticmethod
    def _metric_mean_from_df(df: pd.DataFrame, metric_name: str) -> float:
        """Return the "mean" value of the first row of `df` whose
        "metric_name" column equals `metric_name`.

        Raises:
            ValueError: If no row matches `metric_name`.
        """
        try:
            return df.loc[df["metric_name"] == metric_name].iloc[0]["mean"]
        except IndexError as err:  # pragma: no cover
            # `.iloc[0]` on an empty selection raises IndexError; surface it
            # as the ValueError callers of the public accessors expect.
            raise ValueError(
                f"Metric {metric_name} not yet in data for trial."
            ) from err

    def __repr__(self) -> str:
        return (
            "Trial("
            f"experiment_name='{self._experiment._name}', "
            f"index={self._index}, "
            f"status={self._status}, "
            f"arm={self.arm})"
        )