Source code for ax.core.arm
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import hashlib
import json
from collections.abc import Mapping
from ax.core.types import TParameterization, TParamValue
from ax.utils.common.base import SortableBase
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type
[docs]
class Arm(SortableBase):
    """Base class for defining arms.

    Randomization in experiments assigns units to a given arm. Thus, the arm
    encapsulates the parametrization needed by the unit.
    """

    def __init__(self, parameters: TParameterization, name: str | None = None) -> None:
        """Inits Arm.

        Args:
            parameters: Mapping from parameter names to values.
            name: Defaults to None; will be set when arm is attached to a trial
        """
        # Values are passed through the NumPy -> Python conversion helper before
        # being stored, so everything derived from them (signature, repr,
        # equality) sees the converted values.
        self._parameters: TParameterization = _numpy_types_to_python_types(parameters)
        self._name: str | None = name

    @property
    def parameters(self) -> TParameterization:
        """Get mapping from parameter names to values."""
        # Make a copy before returning so it cannot be accidentally mutated
        return self._parameters.copy()

    @property
    def has_name(self) -> bool:
        """Return true if arm's name is not None."""
        return self._name is not None

    @property
    def name(self) -> str:
        """Get arm name.

        Raises:
            ValueError: If the name is None.
        """
        if self._name is None:
            raise ValueError("Arm's name is None.")
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """Set the arm name; allowed only while the name is still None.

        Raises:
            ValueError: If a name has already been set.
        """
        if self._name is not None:
            raise ValueError("Arm name is not mutable once set.")
        self._name = name

    @property
    def name_or_short_signature(self) -> str:
        """Returns arm name if exists; else last 8 characters of the hash.

        Used for presentation of candidates (e.g. plotting and tables),
        where the candidates do not yet have names (since names are
        automatically set upon addition to a trial).
        """
        # `or` means a falsy name (None or the empty string) falls back to the
        # signature suffix.
        return self._name or self.signature[-8:]

    @property
    def signature(self) -> str:
        """Get unique representation of an arm (MD5 hex digest of its parameters)."""
        return self.md5hash(self.parameters)

    @staticmethod
    def md5hash(parameters: Mapping[str, TParamValue]) -> str:
        """Return unique identifier for arm's parameters.

        Args:
            parameters: Parameterization; mapping of param name
                to value.

        Returns:
            Hash of arm's parameters: the hex MD5 digest of the key-sorted JSON
            serialization of the (NumPy-to-Python converted) parameter values.
        """
        # Serialize the *converted* values. The stdlib JSON encoder only handles
        # built-in Python types (and their subclasses), so hashing the raw
        # mapping would fail for NumPy scalars that are not such subclasses.
        python_parameters = {
            key: numpy_type_to_python_type(value) for key, value in parameters.items()
        }
        parameters_str = json.dumps(python_parameters, sort_keys=True)
        return hashlib.md5(parameters_str.encode("utf-8")).hexdigest()

    def clone(self, clear_name: bool = False) -> "Arm":
        """Create a copy of this arm.

        Args:
            clear_name: whether this cloned copy should set its
                name to None instead of the name of the arm being cloned.
                Defaults to False.

        Returns:
            A new ``Arm`` with a copy of this arm's parameters.
        """
        # `self.parameters` already returns a fresh copy. Reading `_name`
        # directly yields None for an unnamed arm, so no separate `has_name`
        # check is needed.
        return Arm(
            parameters=self.parameters, name=None if clear_name else self._name
        )

    def __repr__(self) -> str:
        parameters_str = f"parameters={self._parameters}"
        if self.has_name:
            name_str = f"name='{self.name}'"
            return f"Arm({name_str}, {parameters_str})"
        return f"Arm({parameters_str})"

    @equality_typechecker
    def __eq__(self, other: "Arm") -> bool:
        """Need to overwrite the default __eq__ method of Base,
        because accessing the "name" attribute of Arm
        can result in an error.

        Two arms are equal when their parameters are equal and either both are
        unnamed or both carry the same name.
        """
        if self.parameters != other.parameters:
            return False
        if self.has_name != other.has_name:
            return False
        # Only touch `.name` once both arms are known to have one; the getter
        # raises on unnamed arms.
        return not self.has_name or self.name == other.name

    def __hash__(self) -> int:
        # The signature is a hex digest string; interpret it as a base-16 integer.
        # The hash depends on the parameters only, not on the name.
        return int(self.signature, 16)

    @property
    def _unique_id(self) -> str:
        # NOTE(review): presumably the key SortableBase uses for ordering -
        # confirm against ax.utils.common.base.
        return self.signature
def _numpy_types_to_python_types(
    parameterization: TParameterization,
) -> TParameterization:
    """Build a new parameterization with each value passed through
    ``numpy_type_to_python_type`` (coercing Numpy int/float values to Python
    int/float where applicable).

    The input mapping is not modified; parameter names are carried over as-is.
    """
    converted: TParameterization = {}
    for param_name, raw_value in parameterization.items():
        converted[param_name] = numpy_type_to_python_type(raw_value)
    return converted