Source code for ax.core.objective

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, List, Optional, Tuple

from ax.core.base import Base
from ax.core.metric import Metric


[docs]class Objective(Base): """Base class for representing an objective. Attributes: minimize: If True, minimize metric. """ def __init__(self, metric: Metric, minimize: bool = False) -> None: """Create a new objective. Args: metric: The metric to be optimized. minimize: If True, minimize metric. """ self._metric = metric self.minimize = minimize @property def metric(self) -> Metric: """Get the objective metric.""" return self._metric @property def metrics(self) -> List[Metric]: """Get a list of objective metrics.""" return [self._metric]
[docs] def clone(self) -> "Objective": """Create a copy of the objective.""" return Objective(self.metric.clone(), self.minimize)
def __repr__(self) -> str: return 'Objective(metric_name="{}", minimize={})'.format( self.metric.name, self.minimize )
[docs]class ScalarizedObjective(Objective): """Class for an objective composed of a linear scalarization of metrics. Attributes: metrics: List of metrics. weights: Weights for scalarization; default to 1. """ weights: List[float] def __init__( self, metrics: List[Metric], weights: Optional[List[float]] = None, minimize: bool = False, ) -> None: """Create a new objective. Args: metric: The metric to be optimized. weights: The weights for the linear combination of metrics. minimize: If true, minimize the linear combination. """ if weights is None: weights = [1.0 for i in range(len(metrics))] else: if len(weights) != len(metrics): raise ValueError("Length of weights must equal length of metrics") self._metrics = metrics self.minimize = minimize self.weights = weights @property def metric(self) -> Metric: """Override base method to error.""" raise NotImplementedError("ScalarizedObjective is composed of multiple metrics") @property def metrics(self) -> List[Metric]: """Get the objective metrics.""" return self._metrics @property def metric_weights(self) -> Iterable[Tuple[Metric, float]]: """Get the objective metrics and weights.""" return zip(self.metrics, self.weights)
[docs] def clone(self) -> "Objective": """Create a copy of the objective.""" return ScalarizedObjective( metrics=[m.clone() for m in self.metrics], weights=self.weights.copy(), minimize=self.minimize, )
def __repr__(self) -> str: return "ScalarizedObjective(metric_names={}, weights={}, minimize={})".format( [metric.name for metric in self.metrics], self.weights, self.minimize )