#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Dict, List, Optional
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.modelbridge.base import GenResults, ModelBridge
from ax.modelbridge.modelbridge_utils import (
extract_parameter_constraints,
extract_search_space_digest,
get_fixed_features,
parse_observation_features,
transform_callback,
)
from ax.models.random.base import RandomModel
from ax.models.types import TConfig
# Error-message template; the ``{action}`` placeholder is meant to be filled
# in with the operation that was attempted (e.g. via ``str.format``).
# NOTE(review): not referenced anywhere in the visible part of this file.
FIT_MODEL_ERROR = "Model must be fit before {action}."
# pyre-fixme[13]: Attribute `model` is never initialized.
# pyre-fixme[13]: Attribute `parameters` is never initialized.
class RandomModelBridge(ModelBridge):
    """A model bridge for using purely random 'models'.

    Data and optimization configs are not required.
    This model bridge interfaces with RandomModel.

    Attributes:
        model: A RandomModel used to generate candidates
            (note: this is an awkward use of the word 'model').
        parameters: Params found in search space on modelbridge init.
    """

    model: RandomModel
    parameters: List[str]

    def _fit(
        self,
        model: RandomModel,
        search_space: SearchSpace,
        observations: Optional[List[Observation]] = None,
    ) -> None:
        """Store the random model and record the search space's parameter names.

        Args:
            model: The RandomModel that ``_gen`` will later draw candidates from.
            search_space: Search space whose parameter names are captured, in
                the iteration order of ``search_space.parameters``.
            observations: Accepted for interface compatibility; not used here.
        """
        self.model = model
        # Extract and fix parameters from initial search space.
        self.parameters = list(search_space.parameters.keys())

    def _gen(
        self,
        n: int,
        search_space: SearchSpace,
        pending_observations: Dict[str, List[ObservationFeatures]],
        fixed_features: Optional[ObservationFeatures],
        optimization_config: Optional[OptimizationConfig],
        model_gen_options: Optional[TConfig],
    ) -> GenResults:
        """Generate new candidates according to a search_space.

        Args:
            n: Number of candidates requested from the model.
            search_space: Search space supplying the bounds and parameter
                constraints handed to the model.
            pending_observations: Accepted for interface compatibility; not
                used by this bridge.
            fixed_features: Optional features to hold fixed during generation.
            optimization_config: Accepted for interface compatibility; not
                used by this bridge.
            model_gen_options: Options forwarded unchanged to the model's
                ``gen`` call.

        Returns:
            A GenResults holding the generated observation features and the
            weights returned by the model, converted with ``tolist()``.
        """
        # Extract parameter values
        search_space_digest = extract_search_space_digest(search_space, self.parameters)
        # Get fixed features
        fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
        # Extract param constraints
        linear_constraints = extract_parameter_constraints(
            search_space.parameter_constraints, self.parameters
        )
        # Generate the candidates
        X, w = self.model.gen(
            n=n,
            bounds=search_space_digest.bounds,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features_dict,
            model_gen_options=model_gen_options,
            # Callback built from this bridge's parameters and transforms,
            # handed to the model as its rounding function.
            rounding_func=transform_callback(self.parameters, self.transforms),
        )
        observation_features = parse_observation_features(X, self.parameters)
        return GenResults(
            observation_features=observation_features,
            weights=w.tolist(),
        )

    def _predict(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationData]:
        """Prediction is not supported by this bridge.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("RandomModelBridge does not support prediction.")

    def _cross_validate(
        self,
        search_space: SearchSpace,
        cv_training_data: List[Observation],
        cv_test_points: List[ObservationFeatures],
    ) -> List[ObservationData]:
        """Cross-validation is not supported by this bridge.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "RandomModelBridge does not support cross-validation."
        )

    def _set_status_quo(
        self,
        experiment: Optional[Experiment],
        status_quo_name: Optional[str],
        status_quo_features: Optional[ObservationFeatures],
    ) -> None:
        """Intentionally do nothing; all arguments are ignored."""
        pass