Source code for ax.core.risk_measures

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from copy import deepcopy
from typing import Union

from ax.utils.common.base import SortableBase
from ax.utils.common.equality import equality_typechecker


class RiskMeasure(SortableBase):
    """Description of a risk measure, stored as a name plus its options.

    Meant to be used together with a `RobustSearchSpace`: the predictions
    over `ParameterDistribution`s are converted to robust metrics, and those
    metrics are then used in candidate generation to recommend robust
    candidates.

    The supported risk measures are listed in `RISK_MEASURE_NAME_TO_CLASS`
    in `ax/modelbridge/modelbridge_utils.py`. The `extract_risk_measure`
    helper in that same module extracts the BoTorch risk measure.

    Attributes:
        risk_measure: Name of the risk measure.
        options: Keyword arguments used to initialize the risk measure.
    """

    def __init__(
        self,
        risk_measure: str,
        options: dict[str, Union[int, float, bool, list[float]]],
    ) -> None:
        """Store the name and the options of the risk measure.

        Args:
            risk_measure: Name of the risk measure to use. It should have a
                corresponding entry in `RISK_MEASURE_NAME_TO_CLASS`.
            options: Keyword arguments for initializing the risk measure.
                Except for MARS, the risk measure is initialized as
                `RISK_MEASURE_NAME_TO_CLASS[risk_measure](**options)`.
                For MARS, additional attributes are needed to inform the
                scalarization.
        """
        super().__init__()
        self.risk_measure = risk_measure
        self.options = options

    def clone(self) -> RiskMeasure:
        """Return a new `RiskMeasure` with the same name and copied options.

        The options dictionary is deep-copied, so mutating the clone's
        options does not affect this instance.
        """
        options_copy = deepcopy(self.options)
        return RiskMeasure(risk_measure=self.risk_measure, options=options_copy)

    def __repr__(self) -> str:
        """Render as `<ClassName>(risk_measure=<name>, options=<dict repr>)`."""
        # Plain `+` concatenation is kept on purpose: the name is inserted
        # verbatim (no quotes) and a non-str name raises `TypeError`.
        opening = f"{self.__class__.__name__}("
        measure_part = "risk_measure=" + self.risk_measure + ", "
        options_part = "options=" + repr(self.options) + ")"
        return opening + measure_part + options_part

    @property
    def _unique_id(self) -> str:
        """Identifier of this risk measure: its string form."""
        return str(self)

    @equality_typechecker
    def __eq__(self, other: RiskMeasure) -> bool:
        """Compare two risk measures by their string forms."""
        own_text = str(self)
        other_text = str(other)
        return own_text == other_text