Source code for ax.core.objective

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable, List, Optional, Tuple

from ax.core.base import Base
from ax.core.metric import Metric
from ax.utils.common.logger import get_logger


# Module-level logger, created via `get_logger` with this module's name.
logger = get_logger(__name__)


class Objective(Base):
    """Objective built around a single metric and an optimization direction.

    Attributes:
        minimize: If True, minimize metric.
    """

    def __init__(self, metric: Metric, minimize: bool = False) -> None:
        """Create a new objective.

        Args:
            metric: The metric to be optimized.
            minimize: If True, minimize metric.
        """
        self.minimize = minimize
        self._metric = metric

    @property
    def metric(self) -> Metric:
        """The single metric this objective optimizes."""
        return self._metric

    @property
    def metrics(self) -> List[Metric]:
        """The objective's metric wrapped in a one-element list."""
        return [self._metric]

    def clone(self) -> "Objective":
        """Return a new ``Objective`` holding a clone of the metric."""
        cloned_metric = self.metric.clone()
        return Objective(cloned_metric, self.minimize)

    def __repr__(self) -> str:
        template = 'Objective(metric_name="{}", minimize={})'
        return template.format(self.metric.name, self.minimize)
# TODO (jej): Support sqa_store encoding. Currently only single-metric objectives are supported.
class MultiObjective(Objective):
    """Class for an objective composed of multiple component objectives.

    The Acquisition function determines how the objectives are weighted.

    Attributes:
        metrics: List of metrics.
        weights: One weight per metric, derived from each metric's
            ``lower_is_better``: -1.0 when it is True, 1.0 otherwise.
    """

    weights: List[float]

    def __init__(
        self,
        metrics: List[Metric],
        minimize: bool = False,
        **extra_kwargs: Any,  # Here to satisfy serialization.
    ) -> None:
        """Create a new objective.

        Args:
            metrics: The list of metrics to be jointly optimized.
            minimize: If true, minimize the aggregate of these metrics.
            extra_kwargs: Accepted and ignored; present to satisfy
                serialization.
        """
        # Objective.__init__ is not called: this class stores a list of
        # metrics in `_metrics` rather than a single `_metric`.
        self._metrics = metrics
        self.weights = []
        for metric in metrics:
            # Set weights from "lower_is_better"
            if metric.lower_is_better is None:
                logger.warning(
                    f"metric {metric.name} has not set `lower_is_better`. "
                    "Treating as `False` (Metric should be maximized)."
                )
            # An unset `lower_is_better` gets the same weight as `False`, so
            # the behavior matches the warning above instead of silently
            # dropping the metric with a zero weight.
            self.weights.append(-1.0 if metric.lower_is_better is True else 1.0)
        self.minimize = minimize

    @property
    def metric_weights(self) -> Iterable[Tuple[Metric, float]]:
        """Get the objective metrics and weights."""
        return zip(self.metrics, self.weights)

    @property
    def metric(self) -> Metric:
        """Override base method to error.

        Raises:
            NotImplementedError: Always, since there is no single metric.
        """
        raise NotImplementedError(
            f"{type(self).__name__} is composed of multiple metrics"
        )

    @property
    def metrics(self) -> List[Metric]:
        """Get the objective metrics."""
        return self._metrics

    def clone(self) -> "Objective":
        """Create a copy of the objective, cloning each metric."""
        return MultiObjective(
            metrics=[m.clone() for m in self.metrics], minimize=self.minimize
        )

    def __repr__(self) -> str:
        return "MultiObjective(metric_names={}, minimize={})".format(
            [metric.name for metric in self.metrics], self.minimize
        )
class ScalarizedObjective(MultiObjective):
    """Objective defined as a weighted linear combination of metrics.

    Attributes:
        metrics: List of metrics.
        weights: Weights for scalarization; default to 1.
    """

    weights: List[float]

    def __init__(
        self,
        metrics: List[Metric],
        weights: Optional[List[float]] = None,
        minimize: bool = False,
    ) -> None:
        """Create a new objective.

        Args:
            metrics: The metrics to be combined and optimized.
            weights: The weights for the linear combination of metrics.
            minimize: If true, minimize the linear combination.

        Raises:
            ValueError: If ``weights`` is given and its length differs from
                the length of ``metrics``.
        """
        # Resolve and validate the weights before delegating to the parent
        # initializer, which assigns its own weights that are replaced below.
        if weights is None:
            resolved_weights = [1.0] * len(metrics)
        elif len(weights) == len(metrics):
            resolved_weights = weights
        else:
            raise ValueError("Length of weights must equal length of metrics")
        super().__init__(metrics, minimize)
        self.weights = resolved_weights

    def clone(self) -> "Objective":
        """Return a new ``ScalarizedObjective`` with cloned metrics and copied weights."""
        cloned_metrics = [m.clone() for m in self.metrics]
        copied_weights = self.weights.copy()
        return ScalarizedObjective(
            metrics=cloned_metrics,
            weights=copied_weights,
            minimize=self.minimize,
        )

    def __repr__(self) -> str:
        template = "ScalarizedObjective(metric_names={}, weights={}, minimize={})"
        metric_names = [metric.name for metric in self.metrics]
        return template.format(metric_names, self.weights, self.minimize)