#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig, TGenMetadata
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.modelbridge_utils import (
extract_parameter_constraints,
get_bounds_and_task,
get_fixed_features,
parse_observation_features,
transform_callback,
)
from ax.models.random.base import RandomModel
from ax.utils.common.docutils import copy_doc
# Message template for errors about using a model before it has been fit;
# `{action}` is a str.format placeholder. NOTE(review): not referenced in the
# visible part of this module — confirm it is still used elsewhere.
FIT_MODEL_ERROR = "Model must be fit before {action}."
# pyre-fixme[13]: Attribute `model` is never initialized.
# pyre-fixme[13]: Attribute `parameters` is never initialized.
class RandomModelBridge(ModelBridge):
    """A model bridge for using purely random 'models'.

    Data and optimization configs are not required.

    This model bridge interfaces with RandomModel.

    Attributes:
        model: A RandomModel used to generate candidates
            (note: this is an awkward use of the word 'model').
        parameters: Params found in search space on modelbridge init.
    """

    model: RandomModel
    parameters: List[str]

    def _fit(
        self,
        model: RandomModel,
        search_space: SearchSpace,
        observation_features: Optional[List[ObservationFeatures]] = None,
        observation_data: Optional[List[ObservationData]] = None,
    ) -> None:
        """Store the random model and record the search space's parameter names.

        Args:
            model: The RandomModel that will generate candidates.
            search_space: Search space whose parameter names are recorded.
            observation_features: Accepted for interface compatibility;
                ignored here.
            observation_data: Accepted for interface compatibility; ignored
                here.
        """
        self.model = model
        # Extract and fix parameters from initial search space. This same list
        # is passed to every conversion helper in `_gen`, so it fixes one
        # parameter ordering shared by all of them.
        self.parameters = list(search_space.parameters.keys())

    # pyre-fixme[56]: While applying decorator
    #  `ax.utils.common.docutils.copy_doc(...)`: Argument `experiment` expected.
    @copy_doc(ModelBridge.update)
    def update(self, new_data: Data, experiment: Experiment) -> None:
        # Random generation does not use observed data, so there is nothing to
        # update. (No docstring here: `copy_doc` supplies it.)
        pass  # pragma: no cover

    def _gen(
        self,
        n: int,
        search_space: SearchSpace,
        pending_observations: Dict[str, List[ObservationFeatures]],
        fixed_features: ObservationFeatures,
        optimization_config: Optional[OptimizationConfig],
        model_gen_options: Optional[TConfig],
    ) -> Tuple[
        List[ObservationFeatures],
        List[float],
        Optional[ObservationFeatures],
        TGenMetadata,
    ]:
        """Generate new candidates according to a search_space.

        Args:
            n: Number of candidates to request from the model.
            search_space: Search space supplying bounds and parameter
                constraints for generation.
            pending_observations: Ignored by this bridge.
            fixed_features: Features to hold fixed; converted with
                `get_fixed_features` and passed to the model.
            optimization_config: Ignored by this bridge.
            model_gen_options: Options forwarded unchanged to `self.model.gen`.

        Returns:
            A tuple of (generated observation features, candidate weights as a
            list of floats, ``None`` for the optional third element —
            presumably a best-point suggestion, which random generation does
            not produce — and an empty generation-metadata dict).
        """
        # Extract parameter values
        bounds, _, _ = get_bounds_and_task(search_space, self.parameters)
        # Get fixed features
        fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
        # Extract param constraints
        linear_constraints = extract_parameter_constraints(
            search_space.parameter_constraints, self.parameters
        )
        # Generate the candidates. The rounding callback is built from
        # `self.transforms` (presumably inherited from ModelBridge).
        X, w = self.model.gen(
            n=n,
            bounds=bounds,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features_dict,
            model_gen_options=model_gen_options,
            rounding_func=transform_callback(self.parameters, self.transforms),
        )
        observation_features = parse_observation_features(X, self.parameters)
        return observation_features, w.tolist(), None, {}

    def _predict(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationData]:
        """Prediction is not supported by random models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("RandomModelBridge does not support prediction.")

    def _cross_validate(
        self,
        obs_feats: List[ObservationFeatures],
        obs_data: List[ObservationData],
        cv_test_points: List[ObservationFeatures],
    ) -> List[ObservationData]:
        """Cross-validation is not supported by random models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "RandomModelBridge does not support cross-validation."
        )