Source code for ax.core.optimization_config
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from itertools import groupby
from typing import Dict, List, Optional
from ax.core.base import Base
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.types import ComparisonOp
# Presumably the maximum number of objectives supported; it is not
# referenced in this section -- confirm against callers elsewhere.
MAX_OBJECTIVES: int = 4
# Format template used by ``OptimizationConfig.__repr__``; it expects an
# ``objective`` repr and a comma-joined ``constraints`` string.
OC_TEMPLATE: str = (
    "OptimizationConfig("
    "objective={objective}, "
    "outcome_constraints=[{constraints}])"
)
class OptimizationConfig(Base):
    """An optimization configuration, which comprises an objective
    and outcome constraints.

    There is no minimum or maximum number of outcome constraints, but an
    individual metric can have at most two constraints--which is how we
    represent metrics with both upper and lower bounds.
    """

    def __init__(
        self,
        objective: Objective,
        outcome_constraints: Optional[List[OutcomeConstraint]] = None,
    ) -> None:
        """Inits OptimizationConfig.

        Args:
            objective: Metric+direction to use for the optimization.
            outcome_constraints: Constraints on metrics. The list is
                shallow-copied, so mutating the caller's list afterwards
                does not change this config or bypass its validation.

        Raises:
            ValueError: If ``objective`` and the constraints fail
                ``_validate_optimization_config``.
        """
        # Keep our own list: storing the caller's object would let them
        # append an invalid constraint later without re-validation.
        constraints: List[OutcomeConstraint] = (
            [] if outcome_constraints is None else list(outcome_constraints)
        )
        self._validate_optimization_config(
            objective=objective, outcome_constraints=constraints
        )
        self._objective: Objective = objective
        self._outcome_constraints: List[OutcomeConstraint] = constraints

    def clone(self) -> "OptimizationConfig":
        """Make a copy of this optimization config.

        The objective and each outcome constraint are copied via their own
        ``clone()`` methods.
        """
        return OptimizationConfig(
            self.objective.clone(),
            [constraint.clone() for constraint in self.outcome_constraints],
        )

    @property
    def objective(self) -> Objective:
        """Get objective."""
        return self._objective

    @objective.setter
    def objective(self, objective: Objective) -> None:
        """Set objective if it is valid against the current outcome
        constraints (e.g. none of them constrain an objective metric).

        Raises:
            ValueError: If validation fails; the old objective is kept.
        """
        self._validate_optimization_config(
            objective=objective, outcome_constraints=self.outcome_constraints
        )
        self._objective = objective

    @property
    def outcome_constraints(self) -> List[OutcomeConstraint]:
        """Get outcome constraints."""
        return self._outcome_constraints

    @outcome_constraints.setter
    def outcome_constraints(self, outcome_constraints: List[OutcomeConstraint]) -> None:
        """Set outcome constraints if valid, else raise.

        The given list is shallow-copied so the caller cannot later mutate
        it and bypass validation.

        Raises:
            ValueError: If validation fails; the old constraints are kept.
        """
        constraints = list(outcome_constraints)
        self._validate_optimization_config(
            objective=self.objective, outcome_constraints=constraints
        )
        self._outcome_constraints = constraints

    @property
    def metrics(self) -> Dict[str, Metric]:
        """Map of metric name to Metric for every metric referenced by the
        objective or by an outcome constraint.

        If a name appears in both, the objective's Metric object is returned.
        """
        constraint_metrics = {
            oc.metric.name: oc.metric for oc in self._outcome_constraints
        }
        objective_metrics = {metric.name: metric for metric in self.objective.metrics}
        # Later keys win when unpacking, so objective metrics take precedence.
        return {**constraint_metrics, **objective_metrics}

    @staticmethod
    def _validate_optimization_config(
        objective: Objective, outcome_constraints: List[OutcomeConstraint]
    ) -> None:
        """Ensure outcome constraints are valid for the given objective.

        Rules enforced:
          * No constraint may reference a metric returned by
            ``objective.get_unconstrainable_metrics()``.
          * Either one or two outcome constraints can reference one metric.
          * If there are two constraints, they must have different 'ops':
            one LEQ and one GEQ, and the bound of the GEQ op must be
            strictly less than the bound of the LEQ op.

        Args:
            objective: Objective whose unconstrainable metrics must not
                appear in ``outcome_constraints``.
            outcome_constraints: Constraints to validate.

        Raises:
            ValueError: If any rule above is violated.
        """
        constrained_names = {
            constraint.metric.name for constraint in outcome_constraints
        }
        for metric in objective.get_unconstrainable_metrics():
            if metric.name in constrained_names:
                raise ValueError("Cannot constrain on objective metric.")

        def get_metric_name(oc: OutcomeConstraint) -> str:
            return oc.metric.name

        # groupby only merges adjacent items, so sort by the same key first.
        sorted_constraints = sorted(outcome_constraints, key=get_metric_name)
        for metric_name, constraints_itr in groupby(
            sorted_constraints, get_metric_name
        ):
            constraints: List[OutcomeConstraint] = list(constraints_itr)
            constraints_len = len(constraints)
            if constraints_len == 2:
                if constraints[0].op == constraints[1].op:
                    raise ValueError(f"Duplicate outcome constraints {metric_name}")
                # The ops differ; the GEQ one supplies the lower bound
                # (assumes ComparisonOp has only GEQ and LEQ -- confirm).
                lower_bound_idx = 0 if constraints[0].op == ComparisonOp.GEQ else 1
                upper_bound_idx = 1 - lower_bound_idx
                lower_bound = constraints[lower_bound_idx].bound
                upper_bound = constraints[upper_bound_idx].bound
                if lower_bound >= upper_bound:
                    raise ValueError(
                        f"Lower bound {lower_bound} is >= upper bound "
                        + f"{upper_bound} for {metric_name}"
                    )
            elif constraints_len > 2:
                raise ValueError(f"Duplicate outcome constraints {metric_name}")

    def __repr__(self) -> str:
        return OC_TEMPLATE.format(
            objective=repr(self.objective),
            constraints=", ".join(
                repr(constraint) for constraint in self.outcome_constraints
            ),
        )