# Source code for ax.core.parameter_constraint

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from typing import Dict, List, Union

from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, RangeParameter
from ax.core.types import ComparisonOp
from ax.utils.common.base import SortableBase

[docs]class ParameterConstraint(SortableBase):
"""Base class for linear parameter constraints.

Constraints are expressed using a map from parameter name to weight
followed by a bound.

The constraint is satisfied if sum_i(w_i * v_i) <= b where:
w is the vector of parameter weights.
v is a vector of parameter values.
b is the specified bound.
"""

def __init__(self, constraint_dict: Dict[str, float], bound: float) -> None:
"""Initialize ParameterConstraint

Args:
constraint_dict: Map from parameter name to weight.
bound: Bound of the inequality of the constraint.
"""
self._constraint_dict = constraint_dict
self._bound = bound

@property
def constraint_dict(self) -> Dict[str, float]:
"""Get mapping from parameter names to weights."""
return self._constraint_dict

@property
def bound(self) -> float:
"""Get bound of the inequality of the constraint."""
return self._bound

@bound.setter
def bound(self, bound: float) -> None:
"""Set bound."""
self._bound = bound

[docs]    def check(self, parameter_dict: Dict[str, Union[int, float]]) -> bool:
"""Whether or not the set of parameter values satisfies the constraint.

Does a weighted sum of the parameter values based on the constraint_dict
and checks that the sum is less than the bound.

Args:
parameter_dict: Map from parameter name to parameter value.

Returns:
Whether the constraint is satisfied.
"""
for parameter_name in self.constraint_dict.keys():
if parameter_name not in parameter_dict.keys():
raise ValueError(f"{parameter_name} not present in param_dict.")

weighted_sum = sum(
float(parameter_dict[param]) * weight
for param, weight in self.constraint_dict.items()
)
# Expected int for 2nd anonymous parameter to call int.__le__ but got
# float.
return weighted_sum <= self._bound + 1e-8  # allow for numerical imprecision

[docs]    def clone(self) -> ParameterConstraint:
"""Clone."""
return ParameterConstraint(
constraint_dict=self._constraint_dict.copy(), bound=self._bound
)

[docs]    def clone_with_transformed_parameters(
self, transformed_parameters: Dict[str, Parameter]
) -> ParameterConstraint:
"""Clone, but replaced parameters with transformed versions."""
return self.clone()

def __repr__(self) -> str:
return (
"ParameterConstraint("
+ " + ".join(
"{}*{}".format(v, k) for k, v in sorted(self.constraint_dict.items())
)
+ " <= {})".format(self._bound)
)

@property
def _unique_id(self) -> str:
return str(self)

[docs]class OrderConstraint(ParameterConstraint):
"""Constraint object for specifying one parameter to be smaller than another."""

_bound: float

def __init__(self, lower_parameter: Parameter, upper_parameter: Parameter) -> None:
"""Initialize OrderConstraint

Args:
lower_parameter: Parameter that should have the lower value.
upper_parameter: Parameter that should have the higher value.

Note:
The constraint p1 <= p2 can be expressed in matrix notation as
[1, -1] * [p1, p2]^T <= 0.
"""
validate_constraint_parameters([lower_parameter, upper_parameter])

self._lower_parameter = lower_parameter
self._upper_parameter = upper_parameter
self._bound = 0.0

@property
def lower_parameter(self) -> Parameter:
"""Parameter with lower value."""
return self._lower_parameter

@property
def upper_parameter(self) -> Parameter:
"""Parameter with higher value."""
return self._upper_parameter

@property
def parameters(self) -> List[Parameter]:
"""Parameters."""
return [self.lower_parameter, self.upper_parameter]

@property
def constraint_dict(self) -> Dict[str, float]:
"""Weights on parameters for linear constraint representation."""
return {self.lower_parameter.name: 1.0, self.upper_parameter.name: -1.0}

[docs]    def clone(self) -> OrderConstraint:
"""Clone."""
return OrderConstraint(
lower_parameter=self.lower_parameter.clone(),
upper_parameter=self._upper_parameter.clone(),
)

[docs]    def clone_with_transformed_parameters(
self, transformed_parameters: Dict[str, Parameter]
) -> OrderConstraint:
"""Clone, but replace parameters with transformed versions."""
return OrderConstraint(
lower_parameter=transformed_parameters[self.lower_parameter.name],
upper_parameter=transformed_parameters[self._upper_parameter.name],
)

def __repr__(self) -> str:
return "OrderConstraint({} <= {})".format(
self.lower_parameter.name, self.upper_parameter.name
)

[docs]class SumConstraint(ParameterConstraint):
"""Constraint on the sum of parameters being greater or less than a bound."""

def __init__(
self, parameters: List[Parameter], is_upper_bound: bool, bound: float
) -> None:
"""Initialize SumConstraint

Args:
parameters: List of parameters whose sum to constrain on.
is_upper_bound: Whether the bound is an upper or lower bound on the sum.
bound: The bound on the sum.
"""
validate_constraint_parameters(parameters)

self._parameters = parameters
self._is_upper_bound: bool = is_upper_bound
self._parameter_names: List[str] = [parameter.name for parameter in parameters]
self._bound: float = self._inequality_weight * bound
self._constraint_dict: Dict[str, float] = {
name: self._inequality_weight for name in self._parameter_names
}

@property
def parameters(self) -> List[Parameter]:
"""Parameters."""
return self._parameters

@property
def constraint_dict(self) -> Dict[str, float]:
"""Weights on parameters for linear constraint representation."""
return self._constraint_dict

@property
def op(self) -> ComparisonOp:
"""Whether the sum is constrained by a <= or >= inequality."""
return ComparisonOp.LEQ if self._is_upper_bound else ComparisonOp.GEQ

@property
def is_upper_bound(self) -> bool:
"""Whether the bound is an upper or lower bound on the sum."""
return self._is_upper_bound

[docs]    def clone(self) -> SumConstraint:
"""Clone.

To use the same constraint, we need to reconstruct the original bound.
We do this by re-applying the original bound weighting.
"""
return SumConstraint(
parameters=[p.clone() for p in self._parameters],
is_upper_bound=self._is_upper_bound,
bound=self._inequality_weight * self._bound,
)

[docs]    def clone_with_transformed_parameters(
self, transformed_parameters: Dict[str, Parameter]
) -> SumConstraint:
"""Clone, but replace parameters with transformed versions."""
return SumConstraint(
parameters=[transformed_parameters[p.name] for p in self._parameters],
is_upper_bound=self._is_upper_bound,
bound=self._inequality_weight * self._bound,
)

@property
def _inequality_weight(self) -> float:
"""Multiplier of all terms in the inequality.

If the constraint is an upper bound, it is v1 + v2 ... v_n <= b
If the constraint is an lower bound, it is -v1 + -v2 ... -v_n <= -b
This property returns 1 or -1 depending on the scenario
"""
return 1.0 if self._is_upper_bound else -1.0

def __repr__(self) -> str:
symbol = ">=" if self.op == ComparisonOp.GEQ else "<="
return (
"SumConstraint("
+ " + ".join(self._parameter_names)
+ " {} {})".format(
symbol, self._bound if self.op == ComparisonOp.LEQ else -self._bound
)
)

[docs]def validate_constraint_parameters(parameters: List[Parameter]) -> None:
"""Basic validation of parameters used in a constraint.

Args:
parameters: Parameters used in constraint.

Raises:
ValueError if the parameters are not valid for use.
"""
unique_parameter_names = {p.name for p in parameters}
if len(unique_parameter_names) != len(parameters):
raise ValueError("Duplicate parameter in constraint.")

for parameter in parameters:
if not parameter.is_numeric:
raise ValueError(
"Parameter constraints only supported for numeric parameters."
)

# Constraints on FixedParameters are non-sensical.
if isinstance(parameter, FixedParameter):
raise ValueError("Parameter constraints not supported for FixedParameter.")

# ChoiceParameters are transformed either using OneHotEncoding
# or the OrderedChoice transform. Both are non-linear, and
# Ax models only support linear constraints.
if isinstance(parameter, ChoiceParameter):
raise ValueError("Parameter constraints not supported for ChoiceParameter.")

# Log parameters require a non-linear transformation, and Ax
# models only support linear constraints.
if isinstance(parameter, RangeParameter) and parameter.log_scale is True:
raise ValueError(
"Parameter constraints not allowed on log scale parameters."
)