#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import json
import warnings
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.exceptions.core import AxWarning, UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import (
compute_diagnostics,
cross_validate,
CVDiagnostics,
CVResult,
get_fit_and_std_quality_and_generalization_dict,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import SortableBase
from ax.utils.common.kwargs import (
consolidate_kwargs,
filter_kwargs,
get_function_argument_names,
)
from ax.utils.common.serialization import SerializationMixin
from pyre_extensions import none_throws
TModelFactory = Callable[..., ModelBridge]
class ModelSpecJSONEncoder(json.JSONEncoder):
"""Generic encoder to avoid JSON errors in ModelSpec.__repr__"""
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def default(self, o: Any) -> str:
return repr(o)
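# A minimal sketch (not part of this module) of the fallback behavior: objects
# that `json` cannot serialize, e.g. a lambda, are encoded via `repr` instead
# of raising a `TypeError`:
#
#     json.dumps({"callback": lambda: None}, cls=ModelSpecJSONEncoder)
#     # -> '{"callback": "<function <lambda> at 0x...>"}'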
@dataclass
class ModelSpec(SortableBase, SerializationMixin):
    """Specification of a model from the registry, plus the kwargs used to
    construct it, generate from it, and cross-validate it. Also caches the
    fitted model and its cross-validation results.
    """

model_enum: ModelRegistryBase
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
# `ModelRegistryBase.__call__`.
model_kwargs: dict[str, Any] = field(default_factory=dict)
# Kwargs to pass to `ModelBridge.gen`.
model_gen_kwargs: dict[str, Any] = field(default_factory=dict)
# Kwargs to pass to `cross_validate`.
model_cv_kwargs: dict[str, Any] = field(default_factory=dict)
# An optional override for the model key. Each `ModelSpec` in a
# `GenerationNode` must have a unique key to ensure identifiability.
model_key_override: str | None = None
    # Fitted model, constructed using the specified `model_kwargs` and `Data`
    # on `ModelSpec.fit`.
    _fitted_model: ModelBridge | None = None
    # Stored cross-validation results, set in `cross_validate`.
    _cv_results: list[CVResult] | None = None
    # Stored cross-validation diagnostics, set in `cross_validate`.
    _diagnostics: CVDiagnostics | None = None
    # Stored to check whether the cached CV results & diagnostics can be reused.
    _last_cv_kwargs: dict[str, Any] | None = None
    # Stored to check whether the model can be safely updated in `fit`.
    _last_fit_arg_ids: dict[str, int] | None = None
def __post_init__(self) -> None:
self.model_kwargs = self.model_kwargs or {}
self.model_gen_kwargs = self.model_gen_kwargs or {}
self.model_cv_kwargs = self.model_cv_kwargs or {}
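    # Example (an illustrative sketch, not part of this module): constructing
    # a spec from the `Models` registry enum. `Models.SOBOL` is a standard
    # registry member; the kwargs shown are assumptions for illustration:
    #
    #     from ax.modelbridge.registry import Models
    #     spec = ModelSpec(
    #         model_enum=Models.SOBOL,
    #         model_kwargs={"seed": 0},
    #         model_gen_kwargs={"n": 5},
    #     )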
@property
def fitted_model(self) -> ModelBridge:
"""Returns the fitted Ax model, asserting fit() was called"""
self._assert_fitted()
return none_throws(self._fitted_model)
@property
def fixed_features(self) -> ObservationFeatures | None:
"""
Fixed generation features to pass into the Model's `.gen` function.
"""
return self.model_gen_kwargs.get("fixed_features", None)
@fixed_features.setter
def fixed_features(self, value: ObservationFeatures | None) -> None:
"""
Fixed generation features to pass into the Model's `.gen` function.
"""
self.model_gen_kwargs["fixed_features"] = value
@property
def model_key(self) -> str:
"""Key string to identify the model used by this ``ModelSpec``."""
if self.model_key_override is not None:
return self.model_key_override
else:
return self.model_enum.value
def fit(
self,
experiment: Experiment,
data: Data,
**model_kwargs: Any,
) -> None:
"""Fits the specified model on the given experiment + data using the
model kwargs set on the model spec, alongside any passed down as
kwargs to this function (local kwargs take precedent)
"""
        # Unset any cached cross-validation results.
self._cv_results, self._diagnostics = None, None
# NOTE: It's important to copy `self.model_kwargs` here to avoid actually
# adding contents of `model_kwargs` passed to this method, to
# `self.model_kwargs`.
combined_model_kwargs = {**self.model_kwargs, **model_kwargs}
if self._fitted_model is not None and self._safe_to_update(
experiment=experiment, combined_model_kwargs=combined_model_kwargs
):
# Update the data on the modelbridge and call `_fit`.
# This will skip model fitting if the data has not changed.
observations, search_space = self.fitted_model._process_and_transform_data(
experiment=experiment, data=data
)
self.fitted_model._fit_if_implemented(
search_space=search_space, observations=observations, time_so_far=0.0
)
else:
# Fit from scratch.
self._fitted_model = self.model_enum(
experiment=experiment,
data=data,
**combined_model_kwargs,
)
self._last_fit_arg_ids = self._get_fit_arg_ids(
experiment=experiment, combined_model_kwargs=combined_model_kwargs
)
def cross_validate(
self,
model_cv_kwargs: dict[str, Any] | None = None,
) -> tuple[list[CVResult] | None, CVDiagnostics | None]:
"""
Call cross_validate, compute_diagnostics and cache the results.
If the model cannot be cross validated, warn and return None.
NOTE: If there are cached results, and the cache was computed using
the same kwargs, this will return the cached results.
Args:
model_cv_kwargs: Optional kwargs to pass into `cross_validate` call.
These are combined with `self.model_cv_kwargs`, with the
`model_cv_kwargs` taking precedence over `self.model_cv_kwargs`.
Returns:
A tuple of CV results (observed vs predicted values) and the
corresponding diagnostics.
"""
cv_kwargs = {**self.model_cv_kwargs, **(model_cv_kwargs or {})}
if (
self._cv_results is not None
and self._diagnostics is not None
and cv_kwargs == self._last_cv_kwargs
):
return self._cv_results, self._diagnostics
self._assert_fitted()
try:
self._cv_results = cross_validate(model=self.fitted_model, **cv_kwargs)
except NotImplementedError:
            warnings.warn(
                f"{self.model_enum.value} cannot be cross-validated", stacklevel=2
            )
return None, None
self._diagnostics = compute_diagnostics(self._cv_results)
self._last_cv_kwargs = cv_kwargs
return self._cv_results, self._diagnostics
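    # Sketch of the caching behavior (assumes `spec` has been fit): a repeat
    # call with identical kwargs is served from the cache.
    #
    #     cv_results, diagnostics = spec.cross_validate()
    #     cv_results_again, _ = spec.cross_validate()  # cached, no re-compute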
@property
def cv_results(self) -> list[CVResult] | None:
"""
Cached CV results from `self.cross_validate()`
if it has been successfully called
"""
return self._cv_results
@property
def diagnostics(self) -> CVDiagnostics | None:
"""
Cached CV diagnostics from `self.cross_validate()`
if it has been successfully called
"""
return self._diagnostics
def gen(self, **model_gen_kwargs: Any) -> GeneratorRun:
"""Generates candidates from the fitted model, using the model gen
kwargs set on the model spec, alongside any passed as kwargs
to this function (local kwargs take precedent)
NOTE: Model must have been fit prior to calling gen()
Args:
n: Integer representing how many arms should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from ``n``.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
"""
fitted_model = self.fitted_model
model_gen_kwargs = consolidate_kwargs(
kwargs_iterable=[self.model_gen_kwargs, model_gen_kwargs],
keywords=get_function_argument_names(fitted_model.gen),
)
# copy to ensure there is no in-place modification
model_gen_kwargs = deepcopy(model_gen_kwargs)
generator_run = fitted_model.gen(**model_gen_kwargs)
fit_and_std_quality_and_generalization_dict = (
get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=self.fitted_model,
)
)
generator_run._gen_metadata = (
{} if generator_run.gen_metadata is None else generator_run.gen_metadata
)
generator_run._gen_metadata.update(
**fit_and_std_quality_and_generalization_dict
)
return generator_run
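    # Sketch of candidate generation (assumes `spec` has been fit); as noted
    # above, some models may override the requested ``n``:
    #
    #     generator_run = spec.gen(n=3)
    #     trial = experiment.new_batch_trial(generator_run=generator_run)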
def copy(self) -> ModelSpec:
"""`ModelSpec` is both a spec and an object that performs actions.
Copying is useful to avoid changes to a singleton model spec.
"""
return self.__class__(
model_enum=self.model_enum,
model_kwargs=deepcopy(self.model_kwargs),
model_gen_kwargs=deepcopy(self.model_gen_kwargs),
model_cv_kwargs=deepcopy(self.model_cv_kwargs),
model_key_override=self.model_key_override,
)
def _safe_to_update(
self,
experiment: Experiment,
combined_model_kwargs: dict[str, Any],
) -> bool:
"""Checks if the object id of any of the non-data fit arguments has changed.
This is a cheap way of checking that we're attempting to re-fit the same
model for the same experiment, which is a very reasonable expectation
since this all happens on the same `ModelSpec` instance.
"""
if self.model_key == "TRBO":
# Temporary hack to unblock TRBO.
# TODO[T167756515] Remove when TRBO revamp diff lands.
return True
return self._last_fit_arg_ids == self._get_fit_arg_ids(
experiment=experiment, combined_model_kwargs=combined_model_kwargs
)
def _get_fit_arg_ids(
self,
experiment: Experiment,
combined_model_kwargs: dict[str, Any],
) -> dict[str, int]:
"""Construct a dictionary mapping arg name to object id."""
return {
"experiment": id(experiment),
**{k: id(v) for k, v in combined_model_kwargs.items()},
}
def _assert_fitted(self) -> None:
"""Helper that verifies a model was fitted, raising an error if not"""
if self._fitted_model is None:
raise UserInputError("No fitted model found. Call fit() to generate one")
def __repr__(self) -> str:
model_kwargs = json.dumps(
self.model_kwargs, sort_keys=True, cls=ModelSpecJSONEncoder
)
model_gen_kwargs = json.dumps(
self.model_gen_kwargs, sort_keys=True, cls=ModelSpecJSONEncoder
)
model_cv_kwargs = json.dumps(
self.model_cv_kwargs, sort_keys=True, cls=ModelSpecJSONEncoder
)
        return (
            "ModelSpec(\n"
            f"\tmodel_enum={self.model_enum.value},\n"
f"\tmodel_kwargs={model_kwargs},\n"
f"\tmodel_gen_kwargs={model_gen_kwargs},\n"
f"\tmodel_cv_kwargs={model_cv_kwargs},\n"
")"
)
def __hash__(self) -> int:
return hash(repr(self))
def __eq__(self, other: ModelSpec) -> bool:
return repr(self) == repr(other)
@property
def _unique_id(self) -> str:
"""Returns the unique ID of this model spec"""
# TODO @mgarrard verify that this is unique enough
return str(hash(self))
@dataclass
class FactoryFunctionModelSpec(ModelSpec):
factory_function: TModelFactory | None = None
# pyre-ignore[15]: `ModelSpec` has this as non-optional
model_enum: ModelRegistryBase | None = None
def __post_init__(self) -> None:
super().__post_init__()
if self.model_enum is not None:
            raise UserInputError(
                "Use regular `ModelSpec` when it's possible to describe the "
                "model as a `ModelRegistryBase` subclass enum member."
            )
if self.factory_function is None:
raise UserInputError(
"Please specify a valid function returning a `ModelBridge` instance "
"as the required `factory_function` argument to "
"`FactoryFunctionModelSpec`."
)
if self.model_key_override is None:
try:
# `model` is defined via a factory function.
# pyre-ignore[16]: Anonymous callable has no attribute `__name__`.
self.model_key_override = none_throws(self.factory_function).__name__
except Exception:
raise TypeError(
f"{self.factory_function} is not a valid function, cannot extract "
"name. Please provide the model name using `model_key_override`."
)
warnings.warn(
"Using a factory function to describe the model, so optimization state "
"cannot be stored and optimization is not resumable if interrupted.",
AxWarning,
stacklevel=3,
)
def fit(
self,
experiment: Experiment,
data: Data,
search_space: SearchSpace | None = None,
optimization_config: OptimizationConfig | None = None,
**model_kwargs: Any,
) -> None:
"""Fits the specified model on the given experiment + data using the
model kwargs set on the model spec, alongside any passed down as
kwargs to this function (local kwargs take precedent)
"""
factory_function = none_throws(self.factory_function)
all_kwargs = deepcopy(self.model_kwargs)
all_kwargs.update(model_kwargs)
self._fitted_model = factory_function(
# Factory functions do not have a unified signature; e.g. some factory
# functions (like `get_sobol`) require search space instead of experiment.
# Therefore, we filter kwargs to remove unnecessary ones and add additional
# arguments like `search_space` and `optimization_config`.
**filter_kwargs(
factory_function,
experiment=experiment,
data=data,
search_space=search_space or experiment.search_space,
optimization_config=optimization_config
or experiment.optimization_config,
**all_kwargs,
)
)
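# Example usage of `FactoryFunctionModelSpec` (a sketch, assuming the
# `get_sobol` factory mentioned in the comment above, importable from
# `ax.modelbridge.factory`):
#
#     from ax.modelbridge.factory import get_sobol
#
#     spec = FactoryFunctionModelSpec(factory_function=get_sobol)
#     spec.fit(experiment=experiment, data=experiment.lookup_data())
#     generator_run = spec.gen(n=2)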