# Source code for ax.modelbridge.discrete

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Dict, List, Optional, Set, Tuple

from ax.core.observation import (
    Observation,
    ObservationData,
    ObservationFeatures,
    separate_observations,
)
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, FixedParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TParamValueList
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import GenResults, ModelBridge
from ax.modelbridge.modelbridge_utils import (
    array_to_observation_data,
    get_fixed_features,
)
from ax.modelbridge.torch import (
    extract_objective_weights,
    extract_outcome_constraints,
    validate_optimization_config,
)
from ax.models.discrete_base import DiscreteModel
from ax.models.types import TConfig


# Error-message template raised when an operation is attempted before the
# bridge has been fit; ``{action}`` is filled in with the operation's name.
FIT_MODEL_ERROR = "Model must be fit before {action}."


# pyre-fixme[13]: Attribute `model` is never initialized.
# pyre-fixme[13]: Attribute `outcomes` is never initialized.
# pyre-fixme[13]: Attribute `parameters` is never initialized.
# pyre-fixme[13]: Attribute `search_space` is never initialized.
class DiscreteModelBridge(ModelBridge):
    """A model bridge for using models based on discrete parameters.

    Requires that all parameters have been transformed to ChoiceParameters
    (FixedParameters are also accepted and are treated as single-value choices).
    """

    model: DiscreteModel
    outcomes: List[str]
    parameters: List[str]
    search_space: Optional[SearchSpace]

    def _fit(
        self,
        model: DiscreteModel,
        search_space: SearchSpace,
        observations: List[Observation],
    ) -> None:
        """Fit ``model`` to ``observations`` over ``search_space``.

        Records the parameter names (in search-space order) and the outcome
        names (every metric appearing in ``observations``), converts the data
        into per-outcome lists and passes them to ``model.fit``.
        """
        self.model = model
        self.parameters = list(search_space.parameters.keys())
        observation_features, observation_data = separate_observations(observations)
        all_metric_names: Set[str] = set()
        for od in observation_data:
            all_metric_names.update(od.metric_names)
        # Sort so the outcome order (and hence the layout of every per-outcome
        # array handed to the model) does not depend on set iteration order,
        # which can differ between interpreter runs (string hash randomization).
        self.outcomes = sorted(all_metric_names)
        # Convert observations to per-outcome lists.
        Xs_array, Ys_array, Yvars_array = self._convert_observations(
            observation_data=observation_data,
            observation_features=observation_features,
            outcomes=self.outcomes,
            parameters=self.parameters,
        )
        # Extract parameter values
        parameter_values = _get_parameter_values(search_space, self.parameters)
        self.model.fit(
            Xs=Xs_array,
            Ys=Ys_array,
            Yvars=Yvars_array,
            parameter_values=parameter_values,
            outcome_names=self.outcomes,
        )

    def _predict(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationData]:
        """Predict all outcomes at ``observation_features`` with the fitted model.

        Each feature set becomes a row of parameter values ordered like
        ``self.parameters``; a feature set missing one of those parameters
        raises ``KeyError``.
        """
        # Convert observations to array
        X = [
            [of.parameters[param] for param in self.parameters]
            for of in observation_features
        ]
        f, cov = self.model.predict(X=X)
        # Convert arrays to observations
        return array_to_observation_data(f=f, cov=cov, outcomes=self.outcomes)

    def _validate_gen_inputs(
        self,
        n: int,
        search_space: Optional[SearchSpace] = None,
        optimization_config: Optional[OptimizationConfig] = None,
        pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
        fixed_features: Optional[ObservationFeatures] = None,
        model_gen_options: Optional[TConfig] = None,
    ) -> None:
        """Validate inputs to `ModelBridge.gen`.

        Currently, this is only used to ensure that `n` is a positive integer or -1.

        Raises:
            UserInputError: If ``n`` is less than 1 and not -1.
        """
        if n < 1 and n != -1:
            raise UserInputError(
                f"Attempted to generate n={n} points. Number of points to generate "
                "must be either a positive integer or -1."
            )

    def _gen(
        self,
        n: int,
        search_space: SearchSpace,
        pending_observations: Dict[str, List[ObservationFeatures]],
        fixed_features: Optional[ObservationFeatures],
        model_gen_options: Optional[TConfig] = None,
        optimization_config: Optional[OptimizationConfig] = None,
    ) -> GenResults:
        """Generate new candidates according to search_space and optimization_config.

        The outcome constraints should be transformed to no longer be relative.
        Pending observations for metrics that are not among ``self.outcomes``
        (i.e. metrics that had no data when the bridge was fit) are ignored,
        since the model has no slot for them.

        Raises:
            ValueError: If the bridge has not been fit yet.
        """
        # Validation. ``getattr`` so that a bridge whose ``_fit`` never ran
        # (``parameters`` is only annotated, never assigned, on the class)
        # raises the intended error rather than an AttributeError.
        if not getattr(self, "parameters", None):
            raise ValueError(FIT_MODEL_ERROR.format(action="_gen"))
        # Extract parameter values
        parameter_values = _get_parameter_values(search_space, self.parameters)
        # Extract objective and outcome constraints
        if len(self.outcomes) == 0 or optimization_config is None:
            objective_weights = None
            outcome_constraints = None
        else:
            validate_optimization_config(optimization_config, self.outcomes)
            objective_weights = extract_objective_weights(
                objective=optimization_config.objective, outcomes=self.outcomes
            )
            outcome_constraints = extract_outcome_constraints(
                outcome_constraints=optimization_config.outcome_constraints,
                outcomes=self.outcomes,
            )
        # Get fixed features
        fixed_features_dict = get_fixed_features(
            fixed_features=fixed_features, param_names=self.parameters
        )
        # Pending observations: one list of parameter-value rows per outcome.
        pending_array: Optional[List[List[TParamValueList]]] = None
        if pending_observations:
            outcome_index = {m: k for k, m in enumerate(self.outcomes)}
            pending_array = [[] for _ in self.outcomes]
            for metric_name, po_list in pending_observations.items():
                k = outcome_index.get(metric_name)
                if k is None:
                    # Metric unknown to the fitted model; nothing to attach it to.
                    continue
                pending_array[k] = [
                    [po.parameters[p] for p in self.parameters] for po in po_list
                ]
        # Generate the candidates
        X, w, gen_metadata = self.model.gen(
            n=n,
            parameter_values=parameter_values,
            objective_weights=objective_weights,
            outcome_constraints=outcome_constraints,
            fixed_features=fixed_features_dict,  # pyre-ignore
            pending_observations=pending_array,
            model_gen_options=model_gen_options,
        )
        observation_features = [
            ObservationFeatures(
                parameters={p: x[i] for i, p in enumerate(self.parameters)}
            )
            for x in X
        ]
        # TODO[drfreund, bletham]: implement best_point identification and
        # return best_point instead of None
        return GenResults(
            observation_features=observation_features,
            weights=w,
            gen_metadata=gen_metadata,
        )

    def _cross_validate(
        self,
        search_space: SearchSpace,
        cv_training_data: List[Observation],
        cv_test_points: List[ObservationFeatures],
    ) -> List[ObservationData]:
        """Make predictions at ``cv_test_points`` using only the data in
        ``cv_training_data``.

        Training data is converted against the outcomes recorded at fit time,
        so every metric in ``cv_training_data`` must be one of ``self.outcomes``.
        """
        observation_features, observation_data = separate_observations(cv_training_data)
        Xs_train, Ys_train, Yvars_train = self._convert_observations(
            observation_data=observation_data,
            observation_features=observation_features,
            outcomes=self.outcomes,
            parameters=self.parameters,
        )
        X_test = [
            [obsf.parameters[param] for param in self.parameters]
            for obsf in cv_test_points
        ]
        # Use the model to do the cross validation
        f_test, cov_test = self.model.cross_validate(
            Xs_train=Xs_train, Ys_train=Ys_train, Yvars_train=Yvars_train, X_test=X_test
        )
        # Convert array back to ObservationData
        return array_to_observation_data(f=f_test, cov=cov_test, outcomes=self.outcomes)

    @classmethod
    def _convert_observations(
        cls,
        observation_data: List[ObservationData],
        observation_features: List[ObservationFeatures],
        outcomes: List[str],
        parameters: List[str],
    ) -> Tuple[List[List[TParamValueList]], List[List[float]], List[List[float]]]:
        """Split observations into per-outcome lists of inputs, means and variances.

        Args:
            observation_data: Data for each observation, aligned index-by-index
                with ``observation_features``.
            observation_features: Parameterization of each observation.
            outcomes: Ordered outcome names; position ``k`` selects ``Xs[k]``,
                ``Ys[k]`` and ``Yvars[k]``.
            parameters: Ordered parameter names; fixes the order of the values
                in each row of ``Xs``.

        Returns:
            ``(Xs, Ys, Yvars)``, each holding one list per outcome. ``Xs[k]``
            holds the parameter-value rows at which ``outcomes[k]`` was
            observed, ``Ys[k]`` the matching means and ``Yvars[k]`` the matching
            diagonal entries of each observation's covariance.

        Raises:
            ValueError: If the two observation lists differ in length, if an
                observation lacks one of ``parameters`` (out-of-design point),
                or if an observation reports a metric not in ``outcomes``.
        """
        if len(observation_features) != len(observation_data):
            raise ValueError(
                f"Got {len(observation_features)} observation features but "
                f"{len(observation_data)} observation data."
            )
        # Map outcome name -> position once, instead of a linear ``index`` scan
        # per metric; ``setdefault`` keeps the first position like ``index``.
        outcome_index: Dict[str, int] = {}
        for k, m in enumerate(outcomes):
            outcome_index.setdefault(m, k)
        Xs: List[List[TParamValueList]] = [[] for _ in outcomes]
        Ys: List[List[float]] = [[] for _ in outcomes]
        Yvars: List[List[float]] = [[] for _ in outcomes]
        for obsf, obsd in zip(observation_features, observation_data):
            try:
                x = [obsf.parameters[param] for param in parameters]
            except (KeyError, TypeError) as e:
                # Out of design point
                raise ValueError("Out of design points cannot be converted.") from e
            for j, m in enumerate(obsd.metric_names):
                k = outcome_index.get(m)
                if k is None:
                    raise ValueError(
                        f"Metric {m!r} is not one of the outcomes {outcomes}."
                    )
                Xs[k].append(x)
                Ys[k].append(obsd.means[j])
                Yvars[k].append(obsd.covariance[j, j])
        return Xs, Ys, Yvars
def _get_parameter_values(
    search_space: SearchSpace, param_names: List[str]
) -> List[TParamValueList]:
    """Extract parameter values from a search space of discrete parameters.

    Returns one list of admissible values per name in ``param_names``, in the
    same order. A ChoiceParameter contributes its ``values`` list; a
    FixedParameter contributes a one-element list holding its ``value``.

    Raises:
        ValueError: If a named parameter is neither a ChoiceParameter nor a
            FixedParameter.
    """

    def _values_for(name: str) -> TParamValueList:
        # Express a single parameter as the list of values it may take.
        param = search_space.parameters[name]
        if isinstance(param, ChoiceParameter):
            return param.values
        if isinstance(param, FixedParameter):
            return [param.value]
        raise ValueError(f"{param} not ChoiceParameter or FixedParameter")

    return [_values_for(name) for name in param_names]