Source code for ax.modelbridge.random

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from typing import Any

from ax.core.data import Data

from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.modelbridge.base import GenResults, ModelBridge
from ax.modelbridge.modelbridge_utils import (
    extract_parameter_constraints,
    extract_search_space_digest,
    get_fixed_features,
    parse_observation_features,
    transform_callback,
)
from ax.modelbridge.transforms.base import Transform
from ax.models.random.base import RandomModel
from ax.models.types import TConfig


# Error-message template; the ``{action}`` placeholder is meant to be filled in
# with ``str.format(action=...)``.
# NOTE(review): not referenced by any code visible in this module -- presumably
# kept for external importers; confirm before removing.
FIT_MODEL_ERROR = "Model must be fit before {action}."


class RandomModelBridge(ModelBridge):
    """A model bridge for using purely random 'models'.

    Data and optimization configs are not required. This model bridge
    interfaces with RandomModel. Prediction and cross validation are not
    supported: ``_predict`` and ``_cross_validate`` always raise
    ``NotImplementedError``.

    Attributes:
        model: A RandomModel used to generate candidates (note: this is an
            awkward use of the word 'model').
        parameters: Params found in search space on modelbridge init.

    Args:
        experiment: Is used to get arm parameters. Is not mutated.
        search_space: Search space for fitting the model. Constraints need
            not be the same ones used in gen. Note that this bridge always
            passes ``expand_model_space=False`` to the base class.
        data: Ax Data.
        model: Interface will be specified in subclass. If model requires
            initialization, that should be done prior to its use here.
        transforms: List of uninitialized transform classes. Forward
            transforms will be applied in this order, and untransforms in
            the reverse order.
        transform_configs: A dictionary from transform name to the
            transform config dictionary.
        status_quo_name: Name of the status quo arm. Can only be used if
            Data has a single set of ObservationFeatures corresponding to
            that arm.
        status_quo_features: ObservationFeatures to use as status quo.
            Either this or status_quo_name should be specified, not both.
        optimization_config: Optimization config defining how to optimize
            the model.
        fit_out_of_design: If specified, all training data are used.
            Otherwise, only in design points are used.
        fit_abandoned: Whether data for abandoned arms or trials should be
            included in model training data. If ``False``, only
            non-abandoned points are returned.
        fit_tracking_metrics: Whether to fit a model for tracking metrics.
            Setting this to False will improve runtime at the expense of
            models not being available for predicting tracking metrics.
            NOTE: This can only be set to False when the optimization config
            is provided.
        fit_on_init: Whether to fit the model on initialization. This can
            be used to skip model fitting when a fitted model is not needed.
            To fit the model afterwards, use `_process_and_transform_data`
            to get the transformed inputs and call `_fit_if_implemented`
            with the transformed inputs.
    """

    # pyre-fixme[13]: Attribute `model` is never initialized.
    model: RandomModel
    # pyre-fixme[13]: Attribute `parameters` is never initialized.
    parameters: list[str]

    def __init__(
        self,
        search_space: SearchSpace,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        model: Any,
        transforms: list[type[Transform]] | None = None,
        experiment: Experiment | None = None,
        data: Data | None = None,
        transform_configs: dict[str, TConfig] | None = None,
        status_quo_name: str | None = None,
        status_quo_features: ObservationFeatures | None = None,
        optimization_config: OptimizationConfig | None = None,
        fit_out_of_design: bool = False,
        fit_abandoned: bool = False,
        fit_tracking_metrics: bool = True,
        fit_on_init: bool = True,
    ) -> None:
        """Forward all arguments to ``ModelBridge.__init__``.

        The only value not taken from the caller is ``expand_model_space``,
        which is always passed as ``False``.
        """
        super().__init__(
            search_space=search_space,
            model=model,
            transforms=transforms,
            experiment=experiment,
            data=data,
            transform_configs=transform_configs,
            status_quo_name=status_quo_name,
            status_quo_features=status_quo_features,
            optimization_config=optimization_config,
            expand_model_space=False,
            fit_out_of_design=fit_out_of_design,
            fit_abandoned=fit_abandoned,
            fit_tracking_metrics=fit_tracking_metrics,
            fit_on_init=fit_on_init,
        )

    def _fit(
        self,
        model: RandomModel,
        search_space: SearchSpace,
        observations: list[Observation] | None = None,
    ) -> None:
        """Store the model and record the search space's parameter names.

        Args:
            model: The RandomModel to use for candidate generation.
            search_space: Search space whose parameter names (in the
                iteration order of ``search_space.parameters``) are saved
                to ``self.parameters``.
            observations: Accepted for interface compatibility; not used.
        """
        self.model = model
        # Extract and fix parameters from initial search space.
        self.parameters = list(search_space.parameters.keys())

    def _gen(
        self,
        n: int,
        search_space: SearchSpace,
        pending_observations: dict[str, list[ObservationFeatures]],
        fixed_features: ObservationFeatures | None,
        optimization_config: OptimizationConfig | None,
        model_gen_options: TConfig | None,
    ) -> GenResults:
        """Generate new candidates according to a search_space.

        Args:
            n: Number of candidates requested from the model.
            search_space: Search space supplying the bounds and the
                parameter constraints passed to the model.
            pending_observations: Not used by this implementation.
            fixed_features: Features to hold fixed, converted with
                ``get_fixed_features`` before being passed to the model.
            optimization_config: Not used by this implementation.
            model_gen_options: Passed through to ``self.model.gen``.

        Returns:
            A ``GenResults`` holding the observation features parsed from
            the generated points and the corresponding weights as a list.
        """
        # Extract parameter values
        search_space_digest = extract_search_space_digest(
            search_space, self.parameters
        )
        # Get fixed features
        fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
        # Extract param constraints
        linear_constraints = extract_parameter_constraints(
            search_space.parameter_constraints, self.parameters
        )
        # Generate the candidates
        X, w = self.model.gen(
            n=n,
            bounds=search_space_digest.bounds,
            linear_constraints=linear_constraints,
            fixed_features=fixed_features_dict,
            model_gen_options=model_gen_options,
            rounding_func=transform_callback(self.parameters, self.transforms),
        )
        observation_features = parse_observation_features(X, self.parameters)
        return GenResults(
            observation_features=observation_features,
            weights=w.tolist(),
        )

    def _predict(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationData]:
        """Unsupported for random model bridges.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("RandomModelBridge does not support prediction.")

    def _cross_validate(
        self,
        search_space: SearchSpace,
        cv_training_data: list[Observation],
        cv_test_points: list[ObservationFeatures],
        use_posterior_predictive: bool = False,
    ) -> list[ObservationData]:
        """Unsupported for random model bridges.

        Raises:
            NotImplementedError: Always.
        """
        # Carry an explanatory message, mirroring `_predict`, so callers see
        # why the call failed rather than a bare exception.
        raise NotImplementedError(
            "RandomModelBridge does not support cross validation."
        )

    def _set_status_quo(
        self,
        experiment: Experiment | None,
        status_quo_name: str | None,
        status_quo_features: ObservationFeatures | None,
    ) -> None:
        """Intentionally do nothing; all arguments are ignored."""
        pass