Source code for ax.core.objective

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import warnings
from typing import Any, Iterable, List, Optional, Tuple

from ax.core.metric import Metric
from ax.utils.common.base import SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none


# pyre-fixme[5]: Global expression must be annotated.
logger = get_logger(__name__)


[docs]class Objective(SortableBase): """Base class for representing an objective. Attributes: minimize: If True, minimize metric. """ def __init__(self, metric: Metric, minimize: Optional[bool] = None) -> None: """Create a new objective. Args: metric: The metric to be optimized. minimize: If True, minimize metric. If None, will be set based on the `lower_is_better` property of the metric (if that is not specified, will raise a DeprecationWarning). """ lower_is_better = metric.lower_is_better if minimize is None: if lower_is_better is None: warnings.warn( f"Defaulting to `minimize=False` for metric {metric.name} not " + "specifying `lower_is_better` property. This is a wild guess. " + "Specify either `lower_is_better` on the metric, or specify " + "`minimize` explicitly. This will become an error in the future.", DeprecationWarning, ) minimize = False else: minimize = lower_is_better if lower_is_better is not None: if lower_is_better and not minimize: warnings.warn( f"Attempting to maximize metric {metric.name} with property " "`lower_is_better=True`." ) elif not lower_is_better and minimize: warnings.warn( f"Attempting to minimize metric {metric.name} with property " "`lower_is_better=False`." ) self._metric = metric # pyre-fixme[4]: Attribute must be annotated. self.minimize = not_none(minimize) @property def metric(self) -> Metric: """Get the objective metric.""" return self._metric @property def metrics(self) -> List[Metric]: """Get a list of objective metrics.""" return [self._metric] @property def metric_names(self) -> List[str]: """Get a list of objective metric names.""" return [m.name for m in self.metrics]
[docs] def clone(self) -> Objective: """Create a copy of the objective.""" return Objective(self.metric.clone(), self.minimize)
def __repr__(self) -> str: return 'Objective(metric_name="{}", minimize={})'.format( self.metric.name, self.minimize )
[docs] def get_unconstrainable_metrics(self) -> List[Metric]: """Return a list of metrics that are incompatible with OutcomeConstraints.""" return self.metrics
@property def _unique_id(self) -> str: return str(self)
[docs]class MultiObjective(Objective): """Class for an objective composed of a multiple component objectives. The Acquisition function determines how the objectives are weighted. Attributes: objectives: List of objectives. """ weights: List[float] def __init__( self, objectives: Optional[List[Objective]] = None, **extra_kwargs: Any, # Here to satisfy serialization. ) -> None: """Create a new objective. Args: objectives: The list of objectives to be jointly optimized. """ # Support backwards compatibility for old API in which # MultiObjective constructor accepted `metrics` and `minimize` # rather than `objectives` if objectives is None: if "metrics" not in extra_kwargs: raise ValueError( "Must either specify `objectives` or `metrics` " "as input to `MultiObjective` constructor." ) metrics = extra_kwargs["metrics"] minimize = extra_kwargs.get("minimize", False) warnings.warn( "Passing `metrics` and `minimize` as input to the `MultiObjective` " "constructor will soon be deprecated. Instead, pass a list of " "`objectives`. This will become an error in the future.", DeprecationWarning, ) objectives = [] for metric in metrics: lower_is_better = metric.lower_is_better or False _minimize = not lower_is_better if minimize else lower_is_better objectives.append(Objective(metric=metric, minimize=_minimize)) # pyre-fixme[4]: Attribute must be annotated. self._objectives = not_none(objectives) # For now, assume all objectives are weighted equally. # This might be used in the future to change emphasis on the # relative focus of the exploration during the optimization. self.weights = [1.0 for _ in range(len(objectives))] @property def metric(self) -> Metric: """Override base method to error.""" raise NotImplementedError( f"{type(self).__name__} is composed of multiple metrics" ) @property def metrics(self) -> List[Metric]: """Get the objective metrics.""" return [o.metric for o in self._objectives] @property def objectives(self) -> List[Objective]: """Get the objectives.""" return self._objectives @property def objective_weights(self) -> Iterable[Tuple[Objective, float]]: """Get the objectives and weights.""" return zip(self.objectives, self.weights)
[docs] def clone(self) -> Objective: """Create a copy of the objective.""" return MultiObjective(objectives=[o.clone() for o in self.objectives])
def __repr__(self) -> str: return f"MultiObjective(objectives={self.objectives})"
[docs]class ScalarizedObjective(Objective): """Class for an objective composed of a linear scalarization of metrics. Attributes: metrics: List of metrics. weights: Weights for scalarization; default to 1. """ weights: List[float] def __init__( self, metrics: List[Metric], weights: Optional[List[float]] = None, minimize: bool = False, ) -> None: """Create a new objective. Args: metric: The metric to be optimized. weights: The weights for the linear combination of metrics. minimize: If true, minimize the linear combination. """ if weights is None: weights = [1.0 for i in range(len(metrics))] else: if len(weights) != len(metrics): raise ValueError("Length of weights must equal length of metrics") self._metrics = metrics self.weights = weights self.minimize = minimize @property def metric(self) -> Metric: """Override base method to error.""" raise NotImplementedError( f"{type(self).__name__} is composed of multiple metrics" ) @property def metrics(self) -> List[Metric]: """Get the metrics.""" return self._metrics @property def metric_weights(self) -> Iterable[Tuple[Metric, float]]: """Get the metrics and weights.""" return zip(self.metrics, self.weights)
[docs] def clone(self) -> Objective: """Create a copy of the objective.""" return ScalarizedObjective( metrics=[m.clone() for m in self.metrics], weights=self.weights.copy(), minimize=self.minimize, )
def __repr__(self) -> str: return "ScalarizedObjective(metric_names={}, weights={}, minimize={})".format( [metric.name for metric in self.metrics], self.weights, self.minimize )