#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Set, Tuple
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, FixedParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TGenMetadata, TParamValueList
from ax.modelbridge.array import (
extract_objective_weights,
extract_outcome_constraints,
validate_optimization_config,
)
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.modelbridge_utils import array_to_observation_data
from ax.models.discrete_base import DiscreteModel
from ax.models.types import TConfig
# Error-message template for operations that need a fitted model; ``action``
# names the operation being attempted (e.g. "_gen").
FIT_MODEL_ERROR: str = "Model must be fit before {action}."
# pyre-fixme[13]: Attribute `model` is never initialized.
# pyre-fixme[13]: Attribute `outcomes` is never initialized.
# pyre-fixme[13]: Attribute `parameters` is never initialized.
# pyre-fixme[13]: Attribute `search_space` is never initialized.
class DiscreteModelBridge(ModelBridge):
    """A model bridge for using models based on discrete parameters.

    Requires that all parameters have been transformed to ChoiceParameters
    (FixedParameters are also accepted). Observations are handed to the model
    as per-outcome lists of points, where each point lists its parameter
    values in the order of ``self.parameters``.
    """

    # The wrapped discrete model; assigned in ``_fit``.
    model: DiscreteModel
    # Metric names; fixes the order of the per-outcome arrays. Set in ``_fit``.
    outcomes: List[str]
    # Parameter names; fixes the column order of every point. Set in ``_fit``.
    parameters: List[str]
    search_space: Optional[SearchSpace]

    def _features_to_rows(
        self, observation_features: List[ObservationFeatures]
    ) -> List[TParamValueList]:
        """Convert features to rows of parameter values ordered like
        ``self.parameters`` (raises KeyError if a parameter is missing)."""
        return [
            [of.parameters[param] for param in self.parameters]
            for of in observation_features
        ]

    def _fit(
        self,
        model: DiscreteModel,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
    ) -> None:
        """Convert the observations to arrays and fit ``model`` on them.

        Also records ``self.parameters`` (the search space's parameter names)
        and ``self.outcomes`` (every metric name found in
        ``observation_data``). Together these fix the array layout used by
        every other method of this bridge.
        """
        self.model = model
        self.parameters = list(search_space.parameters.keys())
        all_metric_names: Set[str] = {
            name for od in observation_data for name in od.metric_names
        }
        # Sorted rather than list(set): str hashes are randomized per process,
        # so a set's iteration order (and hence the outcome layout) would
        # otherwise differ between runs.
        self.outcomes = sorted(all_metric_names)
        # Convert observations to arrays
        Xs_array, Ys_array, Yvars_array = _convert_observations(
            observation_data=observation_data,
            observation_features=observation_features,
            outcomes=self.outcomes,
            parameters=self.parameters,
        )
        # Extract parameter values
        parameter_values = _get_parameter_values(search_space, self.parameters)
        self.model.fit(
            Xs=Xs_array,
            Ys=Ys_array,
            Yvars=Yvars_array,
            parameter_values=parameter_values,
            outcome_names=self.outcomes,
        )

    def _predict(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationData]:
        """Predict outcome means and covariances at ``observation_features``."""
        f, cov = self.model.predict(X=self._features_to_rows(observation_features))
        # Convert arrays to observations
        return array_to_observation_data(f=f, cov=cov, outcomes=self.outcomes)

    def _gen(
        self,
        n: int,
        search_space: SearchSpace,
        pending_observations: Dict[str, List[ObservationFeatures]],
        fixed_features: ObservationFeatures,
        model_gen_options: Optional[TConfig] = None,
        optimization_config: Optional[OptimizationConfig] = None,
    ) -> Tuple[
        List[ObservationFeatures],
        List[float],
        Optional[ObservationFeatures],
        TGenMetadata,
    ]:
        """Generate new candidates according to search_space and
        optimization_config.

        The outcome constraints should be transformed to no longer be relative.

        Returns:
            A tuple of (candidate features, their weights as returned by the
            model, a best point -- currently always None, and the model's
            gen metadata).

        Raises:
            ValueError: If the bridge has not been fit.
        """
        # Validation. ``parameters`` is only assigned in ``_fit``; getattr makes
        # an unfit bridge raise the intended ValueError, not AttributeError.
        if not getattr(self, "parameters", None):  # pragma: no cover
            raise ValueError(FIT_MODEL_ERROR.format(action="_gen"))
        # Extract parameter values
        parameter_values = _get_parameter_values(search_space, self.parameters)
        # Extract objective and outcome constraints
        if not self.outcomes or optimization_config is None:  # pragma: no cover
            objective_weights = None
            outcome_constraints = None
        else:
            validate_optimization_config(optimization_config, self.outcomes)
            objective_weights = extract_objective_weights(
                objective=optimization_config.objective, outcomes=self.outcomes
            )
            outcome_constraints = extract_outcome_constraints(
                outcome_constraints=optimization_config.outcome_constraints,
                outcomes=self.outcomes,
            )
        # Fixed features are keyed by column index in ``self.parameters``;
        # an empty mapping is passed to the model as None.
        fixed_features_dict = {
            self.parameters.index(p_name): val
            for p_name, val in fixed_features.parameters.items()
        } or None
        # Pending observations: one list of points per outcome, in outcome order.
        pending_array: Optional[List[List[TParamValueList]]] = None
        if pending_observations:
            pending_array = [[] for _ in self.outcomes]
            for metric_name, po_list in pending_observations.items():
                pending_array[
                    self.outcomes.index(metric_name)
                ] = self._features_to_rows(po_list)
        # Generate the candidates
        X, w, gen_metadata = self.model.gen(
            n=n,
            parameter_values=parameter_values,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            fixed_features=fixed_features_dict,
            pending_observations=pending_array,
            model_gen_options=model_gen_options,
        )
        observation_features = [
            ObservationFeatures(
                parameters={p: x[i] for i, p in enumerate(self.parameters)}
            )
            for x in X
        ]
        # TODO[drfreund, bletham]: implement best_point identification and
        # return best_point instead of None
        return observation_features, w, None, gen_metadata

    def _cross_validate(
        self,
        search_space: SearchSpace,
        obs_feats: List[ObservationFeatures],
        obs_data: List[ObservationData],
        cv_test_points: List[ObservationFeatures],
    ) -> List[ObservationData]:
        """Make predictions at cv_test_points using only the data in obs_feats
        and obs_data.
        """
        Xs_train, Ys_train, Yvars_train = _convert_observations(
            observation_data=obs_data,
            observation_features=obs_feats,
            outcomes=self.outcomes,
            parameters=self.parameters,
        )
        X_test = self._features_to_rows(cv_test_points)
        # Use the model to do the cross validation
        f_test, cov_test = self.model.cross_validate(
            Xs_train=Xs_train, Ys_train=Ys_train, Yvars_train=Yvars_train, X_test=X_test
        )
        # Convert array back to ObservationData
        return array_to_observation_data(f=f_test, cov=cov_test, outcomes=self.outcomes)
def _convert_observations(
observation_data: List[ObservationData],
observation_features: List[ObservationFeatures],
outcomes: List[str],
parameters: List[str],
) -> Tuple[List[List[TParamValueList]], List[List[float]], List[List[float]]]:
Xs: List[List[TParamValueList]] = [[] for _ in outcomes]
Ys: List[List[float]] = [[] for _ in outcomes]
Yvars: List[List[float]] = [[] for _ in outcomes]
for i, obsf in enumerate(observation_features):
try:
x = [obsf.parameters[param] for param in parameters]
except (KeyError, TypeError):
# Out of design point
raise ValueError("Out of design points cannot be converted.")
for j, m in enumerate(observation_data[i].metric_names):
k = outcomes.index(m)
Xs[k].append(x)
Ys[k].append(observation_data[i].means[j])
Yvars[k].append(observation_data[i].covariance[j, j])
return Xs, Ys, Yvars
def _get_parameter_values(
    search_space: SearchSpace, param_names: List[str]
) -> List[TParamValueList]:
    """Return the allowed values of each named parameter.

    Args:
        search_space: Search space whose parameters are looked up by name.
        param_names: Names of the parameters, in the desired output order.

    Returns:
        One list per name, in ``param_names`` order: a ChoiceParameter's
        ``values``, or a one-element list holding a FixedParameter's ``value``.

    Raises:
        ValueError: If a named parameter is neither a ChoiceParameter nor a
            FixedParameter.
    """

    def _values_of(name: str) -> TParamValueList:
        param = search_space.parameters[name]
        if isinstance(param, ChoiceParameter):
            return param.values
        if isinstance(param, FixedParameter):
            return [param.value]
        raise ValueError(f"{param} not ChoiceParameter or FixedParameter")

    return [_values_of(name) for name in param_names]