#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
from collections import OrderedDict, defaultdict
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    DefaultDict,
    Dict,
    List,
    MutableMapping,
    NamedTuple,
    Optional,
    Union,
)

import numpy as np

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.trial import immutable_once_run
from ax.utils.common.equality import datetime_equals, equality_typechecker
from ax.utils.common.typeutils import checked_cast, not_none


if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import core  # noqa F401  # pragma: no cover
class AbandonedArm(NamedTuple):
    """Record describing an arm that was abandoned within a BatchTrial.

    Attributes:
        name: Name of the abandoned arm.
        time: Timestamp recorded when the arm was abandoned.
        reason: Optional free-form explanation for abandoning the arm.
    """

    name: str
    time: datetime
    reason: Optional[str] = None

    @equality_typechecker
    def __eq__(self, other: "AbandonedArm") -> bool:
        # Compare the plain fields first; the timestamp goes through the
        # project's ``datetime_equals`` helper instead of ``==``.
        if self.name != other.name or self.reason != other.reason:
            return False
        return datetime_equals(self.time, other.time)
class GeneratorRunStruct(NamedTuple):
    """Pairs a GeneratorRun attached to a trial with the weight it was added with.

    Attributes:
        generator_run: The generator run that was attached to the trial.
        weight: Multiplier supplied when the run was added; BatchTrial scales
            every arm weight of ``generator_run`` by this value when merging.
    """

    generator_run: GeneratorRun
    weight: float
class BatchTrial(BaseTrial):
    """A trial that holds a batch of weighted arms.

    The trial's arms and weights are assembled from the generator runs
    attached to it (each scaled by the multiplier it was added with) plus an
    optional status quo arm that carries its own, separately tracked weight.
    """

    def __init__(
        self,
        experiment: "core.experiment.Experiment",
        generator_run: Optional[GeneratorRun] = None,
        trial_type: Optional[str] = None,
    ) -> None:
        """Create a batch trial.

        Args:
            experiment: The experiment this trial belongs to.
            generator_run: Optional generator run whose arms are added to the
                trial on construction (with multiplier 1.0).
            trial_type: Optional trial type, forwarded to ``BaseTrial``.
        """
        super().__init__(experiment=experiment, trial_type=trial_type)
        self._generator_run_structs: List[GeneratorRunStruct] = []
        # Abandoned-arm records, keyed by arm name.
        self._abandoned_arms_metadata: Dict[str, AbandonedArm] = {}
        self._status_quo: Optional[Arm] = None
        # Weight the status quo contributes on top of any weight the same arm
        # already receives from generator runs (see ``arm_weights``).
        self._status_quo_weight: float = 0.0
        if generator_run is not None:
            self.add_generator_run(generator_run=generator_run)
        # Default to the experiment's status quo (which may be None).
        self.status_quo = experiment.status_quo

    @property
    def experiment(self) -> "core.experiment.Experiment":
        """The experiment this batch belongs to."""
        return self._experiment

    @property
    def index(self) -> int:
        """The index of this batch within the experiment's batch list."""
        return self._index

    @property
    def generator_run_structs(self) -> List[GeneratorRunStruct]:
        """List of generator run structs attached to this trial.

        Each struct holds a generator_run object and the weight (multiplier)
        with which it was added.
        """
        return self._generator_run_structs

    @property
    def arm_weights(self) -> Optional[MutableMapping[Arm, float]]:
        """The set of arms and associated weights for the trial.

        These are constructed by merging the arms and weights from each
        generator run attached to the trial (scaled by that run's multiplier)
        and then adding the status quo weight to the status quo arm.

        Returns:
            An ordered mapping from arm to weight, or None if the trial has
            neither generator runs nor a status quo.
        """
        if len(self._generator_run_structs) == 0 and self.status_quo is None:
            return None
        arm_weights: MutableMapping[Arm, float] = OrderedDict()
        for struct in self._generator_run_structs:
            multiplier = struct.weight
            for arm, weight in struct.generator_run.arm_weights.items():
                scaled_weight = weight * multiplier
                # The same arm may come from several generator runs; its
                # weights are summed.
                if arm in arm_weights:
                    arm_weights[arm] += scaled_weight
                else:
                    arm_weights[arm] = scaled_weight
        if self.status_quo is not None:
            arm_weights[self.status_quo] = self._status_quo_weight + arm_weights.get(
                self.status_quo, 0.0
            )
        return arm_weights

    @arm_weights.setter
    def arm_weights(self, arm_weights: MutableMapping[Arm, float]) -> None:
        raise NotImplementedError("Use `trial.add_arms_and_weights`")

    @immutable_once_run
    def add_arm(self, arm: Arm, weight: float = 1.0) -> "BatchTrial":
        """Add an arm to the trial.

        Args:
            arm: The arm to be added.
            weight: The weight with which this arm should be added.

        Returns:
            The trial instance.
        """
        return self.add_arms_and_weights(arms=[arm], weights=[weight])

    @immutable_once_run
    def add_arms_and_weights(
        self,
        arms: List[Arm],
        weights: Optional[List[float]] = None,
        multiplier: float = 1.0,
    ) -> "BatchTrial":
        """Add arms and weights to the trial.

        The arms are wrapped in a MANUAL generator run and attached via
        ``add_generator_run``.

        Args:
            arms: The arms to be added.
            weights: The weights associated with the arms.
            multiplier: The multiplier applied to input weights before merging with
                the current set of arms and weights.

        Returns:
            The trial instance.
        """
        return self.add_generator_run(
            generator_run=GeneratorRun(
                arms=arms, weights=weights, type=GeneratorRunType.MANUAL.name
            ),
            multiplier=multiplier,
        )

    @immutable_once_run
    def add_generator_run(
        self, generator_run: GeneratorRun, multiplier: float = 1.0
    ) -> "BatchTrial":
        """Add a generator run to the trial.

        The arms and weights from the generator run will be merged with
        the existing arms and weights on the trial, and the generator run
        object will be linked to the trial for tracking.

        Args:
            generator_run: The generator run to be added. It is cloned, so the
                caller's object is not mutated.
            multiplier: The multiplier applied to input weights before merging with
                the current set of arms and weights.

        Returns:
            The trial instance.
        """
        # Copy the generator run, to preserve initial and skip mutations to arms.
        generator_run = generator_run.clone()

        # First validate generator run arms.
        for arm in generator_run.arms:
            self.experiment.search_space.check_types(arm.parameters, raise_error=True)

        # Add names to arms.
        # For those not yet added to this experiment, create a new name;
        # else, use the name of the existing arm.
        for arm in generator_run.arms:
            self._check_existing_and_name_arm(arm)

        self._generator_run_structs.append(
            GeneratorRunStruct(generator_run=generator_run, weight=multiplier)
        )
        generator_run.index = len(self._generator_run_structs) - 1

        # Re-derive the status quo weight, since it depends on the new arms.
        self.reweight_status_quo()
        return self

    @property
    def status_quo(self) -> Optional[Arm]:
        """The control arm for this batch."""
        return self._status_quo

    @status_quo.setter
    def status_quo(self, status_quo: Optional[Arm]) -> None:
        """Sets status quo arm with the default weight."""
        self.set_status_quo_with_weight(status_quo)

    @immutable_once_run
    def set_status_quo_with_weight(
        self, status_quo: Optional[Arm], weight: Optional[float] = None
    ) -> "BatchTrial":
        """Sets status quo arm.

        Defaults weight to average of existing weights or 1.0 if no weights exist.

        Args:
            status_quo: The status quo arm, or None to unset it.
            weight: Explicit status quo weight; must be positive if given.

        Returns:
            The trial instance.

        Raises:
            ValueError: If ``weight`` is given and is not positive.
        """
        if weight is not None and weight <= 0.0:
            raise ValueError("Status quo weight must be positive.")

        if status_quo is not None:
            self.experiment.search_space.check_types(
                status_quo.parameters, raise_error=True
            )
            # Assign a name to this arm if none exists.
            self.experiment._name_and_store_arm_if_not_exists(
                arm=status_quo, proposed_name="status_quo_" + str(self.index)
            )
        self._status_quo = status_quo
        self.reweight_status_quo(weight)
        return self

    @immutable_once_run
    def reweight_status_quo(self, weight: Optional[float] = None) -> "BatchTrial":
        """Update status quo weight.

        If arms have been added since the status quo was initially added,
        the optimal weight of the status quo may change.

        Args:
            weight: Explicit weight to use. If None, the weight is the mean of
                the trial's other arm weights, or 1.0 if there are none.

        Returns:
            The trial instance.
        """
        if weight is None:
            status_quo = self._status_quo
            # Unset status_quo so the average is taken over the generator-run
            # weights only; always restore it, even if the computation fails.
            self._status_quo = None
            try:
                weights = self.weights
            finally:
                self._status_quo = status_quo
            weight = 1.0 if len(weights) == 0 else float(sum(weights)) / len(weights)
        self._status_quo_weight = weight
        return self

    @immutable_once_run
    def set_status_quo_and_optimize_power(self, status_quo: Arm) -> "BatchTrial":
        """Adds a status quo arm to the batch and optimizes for power.

        This function will maximize power across the multiple pair-wise
        comparisons of existing arms against the status_quo.

        Specifically, this function assigns sqrt(sum_weights) weight to the
        status quo, where sum_weights is the sum of the weights of the existing
        arms, excluding the status quo. This will be optimal in terms of
        statistical power in the case where:
        1) status quo is the only arm to compare against
        2) all other arms are of equal interest

        Args:
            status_quo: The status quo arm.

        Returns:
            The trial instance.
        """
        self.status_quo = status_quo
        if len(self.arms) == 1:
            # If status quo is the only arm, don't adjust weights
            # (will end up setting status quo weight to 0.0)
            return self

        # arm_weights should always have at least one arm now, the status quo
        arm_weights = not_none(self.arm_weights)
        sum_weights = sum(w for arm, w in arm_weights.items() if arm != status_quo)
        optimal_status_quo_weight = np.sqrt(sum_weights)

        # arm_weights[status_quo] will be equal to the weight that the status quo
        # has from the generator runs, plus _status_quo_weight.
        # Thus, to make sure that arm_weights[status_quo] = optimal_status_quo_weight,
        # _status_quo_weight should be equal to optimal_status_quo_weight - the amount
        # of weight that the status quo has from the generator runs.
        status_quo_weight_from_generator_runs = (
            arm_weights[status_quo] - self._status_quo_weight
        )
        # Technically it's possible for _status_quo_weight to be negative here,
        # if the weight of the status quo in the generator runs is larger than
        # the optimal weight. But that's okay, because that's the only way
        # to get the status quo to the correct weight.
        self._status_quo_weight = (
            optimal_status_quo_weight - status_quo_weight_from_generator_runs
        )
        return self

    @property
    def arms(self) -> List[Arm]:
        """All arms contained in the trial."""
        arm_weights = self.arm_weights
        return [] if arm_weights is None else list(arm_weights.keys())

    @property
    def weights(self) -> List[float]:
        """Weights corresponding to arms contained in the trial."""
        arm_weights = self.arm_weights
        return [] if arm_weights is None else list(arm_weights.values())

    @property
    def arms_by_name(self) -> Dict[str, Arm]:
        """Map from arm name to object for all arms in trial.

        Raises:
            ValueError: If any arm in the trial has no name.
        """
        arms_by_name = {}
        for arm in self.arms:
            if not arm.has_name:
                raise ValueError(  # pragma: no cover
                    "Arms attached to a trial must have a name."
                )
            arms_by_name[arm.name] = arm
        return arms_by_name

    @property
    def abandoned_arms(self) -> List[Arm]:
        """List of arms that have been abandoned within this trial."""
        # Build the name -> arm map once rather than once per abandoned arm.
        arms_by_name = self.arms_by_name
        return [
            arms_by_name[abandoned.name]
            for abandoned in self._abandoned_arms_metadata.values()
        ]

    @property
    def abandoned_arms_metadata(self) -> List[AbandonedArm]:
        """AbandonedArm records for the arms abandoned within this trial."""
        return list(self._abandoned_arms_metadata.values())

    @property
    def is_factorial(self) -> bool:
        """Return true if the trial's arms are a factorial design with
        no linked factors.
        """
        # To match the model behavior, this should probably actually be pulled
        # from exp.parameters. However, that seems rather ugly when this function
        # intuitively should just depend on the arms.
        arms = self.arms  # computed once; each access rebuilds arm_weights
        sufficient_factors = all(len(arm.parameters or []) >= 2 for arm in arms)
        if not sufficient_factors:
            return False
        # Distinct values seen for each parameter (dict used as an ordered set).
        param_levels: DefaultDict[str, Dict[Union[str, float], int]] = defaultdict(
            dict
        )
        for arm in arms:
            for param_name, param_value in arm.parameters.items():
                # Expected `Union[float, str]` for 2nd anonymous parameter to call
                # `dict.__setitem__` but got `Optional[Union[bool, float, str]]`.
                # pyre-fixme[6]: Expected `Union[float, str]` for 1st param but got `...
                param_levels[param_name][param_value] = 1
        param_cardinality = 1
        for param_values in param_levels.values():
            param_cardinality *= len(param_values)
        return len(arms) == param_cardinality

    def run(self) -> "BatchTrial":
        """Run the trial via ``BaseTrial.run`` and return it typed as a BatchTrial."""
        return checked_cast(BatchTrial, super().run())

    def normalized_arm_weights(
        self, total: float = 1, trunc_digits: Optional[int] = None
    ) -> MutableMapping[Arm, float]:
        """Returns arms with a new set of weights normalized
        to the given total.

        This method is useful for many runners where we need to normalize weights
        to a certain total without mutating the weights attached to a trial.

        Args:
            total: The total weight to which to normalize.
                Default is 1, in which case arm weights
                can be interpreted as probabilities.
            trunc_digits: The number of digits to keep. If the
                resulting total weight is not equal to `total`, the units of
                weight lost to truncation are handed back one at a time to
                the first arms (in trial order).

        Returns:
            Mapping from arms to the new set of weights.
        """
        # Take one snapshot so arms and weights are guaranteed to line up.
        arm_weights = self.arm_weights
        if not arm_weights:
            return OrderedDict()
        arms = list(arm_weights.keys())
        weights = np.array(list(arm_weights.values()))
        if trunc_digits is not None:
            atomic_weight = 10 ** -trunc_digits
            units = total / atomic_weight
            # Number of atomic units making up ``total``. Plain truncation is
            # wrong when floating point lands just below an integer, e.g.
            # 0.3 / 0.1 == 2.9999999999999996 would give 2 units, not 3.
            n_units = int(round(units)) if math.isclose(units, round(units)) else int(units)
            # pyre-fixme[16]: `float` has no attribute `astype`.
            int_weights = (n_units * (weights / np.sum(weights))).astype(int)
            # Truncating each arm drops less than one unit per arm; give the
            # leftover units back, one each, starting from the first arm.
            n_leftover = n_units - int(np.sum(int_weights))
            int_weights[:n_leftover] += 1
            weights = int_weights * atomic_weight
        else:
            weights = weights * (total / np.sum(weights))
        return OrderedDict(zip(arms, weights))

    def mark_arm_abandoned(
        self, arm_name: str, reason: Optional[str] = None
    ) -> "BatchTrial":
        """Mark an arm abandoned.

        Usually done after deployment when one arm causes issues but
        user wants to continue running other arms in the batch.

        Args:
            arm_name: The name of the arm to abandon.
            reason: The reason for abandoning the arm.

        Returns:
            The batch instance.

        Raises:
            ValueError: If no arm with ``arm_name`` is in this batch.
        """
        if arm_name not in self.arms_by_name:
            raise ValueError("Arm must be contained in batch.")

        abandoned_arm = AbandonedArm(name=arm_name, time=datetime.now(), reason=reason)
        self._abandoned_arms_metadata[arm_name] = abandoned_arm
        return self

    def clone(self) -> "BatchTrial":
        """Clone the trial.

        The clone is a new batch trial on the same experiment with the same
        generator runs (and multipliers), trial type, runner, status quo arm
        and status quo weight. Abandoned-arm metadata is not copied.

        Returns:
            A new instance of the trial.
        """
        new_trial = self._experiment.new_batch_trial()
        for struct in self._generator_run_structs:
            new_trial.add_generator_run(struct.generator_run, struct.weight)
        new_trial.trial_type = self._trial_type
        new_trial.runner = self._runner
        # The new trial defaults to the experiment's status quo with an averaged
        # weight; carry over this trial's own status quo and weight instead.
        # Assign directly: the weight may legitimately be <= 0 (e.g. after
        # ``set_status_quo_and_optimize_power``), which
        # ``set_status_quo_with_weight`` would reject, and the arm is already
        # named and stored on this experiment.
        new_trial._status_quo = self._status_quo
        new_trial._status_quo_weight = self._status_quo_weight
        return new_trial

    def __repr__(self) -> str:
        return (
            "BatchTrial("
            f"experiment_name='{self._experiment._name}', "
            f"index={self._index}, "
            f"status={self._status})"
        )