Source code for ax.core.arm

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import hashlib
import json

from ax.core.types import TParameterization
from ax.utils.common.base import SortableBase
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type


class Arm(SortableBase):
    """Base class for defining arms.

    Randomization in experiments assigns units to a given arm. Thus, the arm
    encapsulates the parametrization needed by the unit.
    """

    def __init__(self, parameters: TParameterization, name: str | None = None) -> None:
        """Inits Arm.

        Args:
            parameters: Mapping from parameter names to values. Values are
                passed through ``_numpy_types_to_python_types`` and stored in
                a new dict, so the caller's mapping is not aliased.
            name: Defaults to None; will be set when arm is attached to a trial
        """
        self._parameters: TParameterization = _numpy_types_to_python_types(
            parameters
        )
        self._name = name

    @property
    def parameters(self) -> TParameterization:
        """Get mapping from parameter names to values."""
        # Make a copy before returning so it cannot be accidentally mutated
        return self._parameters.copy()

    @property
    def has_name(self) -> bool:
        """Return true if arm's name is not None."""
        return self._name is not None

    @property
    def name(self) -> str:
        """Get arm name. Throws if name is None.

        Raises:
            ValueError: If the arm has not been named yet.
        """
        if self._name is None:
            raise ValueError("Arm's name is None.")
        return self._name

    @property
    def name_or_short_signature(self) -> str:
        """Returns arm name if exists; else last 8 characters of the hash.

        Used for presentation of candidates (e.g. plotting and tables),
        where the candidates do not yet have names (since names are
        automatically set upon addition to a trial).
        """
        return self._name or self.signature[-8:]

    @name.setter
    def name(self, name: str) -> None:
        """Set the arm name; only allowed while the name is still None.

        Raises:
            ValueError: If the arm already has a name.
        """
        if self._name is not None:
            raise ValueError("Arm name is not mutable once set.")
        self._name = name

    @property
    def signature(self) -> str:
        """Get unique representation of an arm (MD5 hex digest of its parameters)."""
        return self.md5hash(self.parameters)

    @staticmethod
    def md5hash(parameters: TParameterization) -> str:
        """Return unique identifier for arm's parameters.

        The input mapping is NOT modified; NumPy scalar values are coerced
        on a private copy before serialization.

        Args:
            parameters: Parameterization; mapping of param name to value.

        Returns:
            Hash of arm's parameters.
        """
        # Coerce into a new dict rather than writing back into ``parameters``:
        # this is a public static method, so mutating the caller's argument
        # would be an unexpected side effect.
        coerced = _numpy_types_to_python_types(parameters)
        # ``sort_keys`` makes the serialization (and thus the hash) independent
        # of dict insertion order.
        parameters_str = json.dumps(coerced, sort_keys=True)
        # The digest is only an identifier, not a security primitive;
        # ``usedforsecurity=False`` keeps it usable on FIPS-restricted builds
        # without changing the resulting digest.
        return hashlib.md5(
            parameters_str.encode("utf-8"), usedforsecurity=False
        ).hexdigest()

    def clone(self, clear_name: bool = False) -> "Arm":
        """Create a copy of this arm.

        Args:
            clear_name: whether this cloned copy should set its
                name to None instead of the name of the arm being cloned.
                Defaults to False.
        """
        # An unnamed arm always clones to an unnamed arm (reading ``self.name``
        # would raise in that case).
        clear_name = clear_name or not self.has_name
        # ``self.parameters`` already returns a fresh copy, so no extra copy.
        return Arm(
            parameters=self.parameters, name=None if clear_name else self.name
        )

    def __repr__(self) -> str:
        parameters_str = f"parameters={self._parameters}"
        if self.has_name:
            name_str = f"name='{self.name}'"
            return f"Arm({name_str}, {parameters_str})"
        return f"Arm({parameters_str})"

    @equality_typechecker
    def __eq__(self, other: "Arm") -> bool:
        """Need to overwrite the default __eq__ method of Base,
        because accessing the "name" attribute of Arm
        can result in an error.
        """
        parameters_equal = self.parameters == other.parameters
        names_equal = self.has_name == other.has_name
        if names_equal and self.has_name:
            names_equal = self.name == other.name
        return parameters_equal and names_equal

    def __hash__(self) -> int:
        # Hash depends only on parameters; arms that compare equal have equal
        # parameters and therefore equal signatures, so this agrees with __eq__.
        return int(self.signature, 16)

    @property
    def _unique_id(self) -> str:
        return self.signature
def _numpy_types_to_python_types(
    parameterization: TParameterization,
) -> TParameterization:
    """Build a new parameterization whose values have been normalized.

    Every value is passed through ``numpy_type_to_python_type`` so that, if
    applicable, Numpy int/float values become Python int/float. The input
    mapping is left untouched; a freshly built dict with the same keys (in
    the same order) is returned.
    """
    converted: TParameterization = {}
    for param_name, param_value in parameterization.items():
        converted[param_name] = numpy_type_to_python_type(param_value)
    return converted