# Source code for ax.storage.json_store.encoders

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, Type

from ax.benchmark.problems.hpo.torchvision import (
    PyTorchCNNTorchvisionBenchmarkProblem,
)
from ax.core import ObservationFeatures
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.optimization_config import (
    MultiObjectiveOptimizationConfig,
    OptimizationConfig,
)
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter
from ax.core.parameter_constraint import (
    OrderConstraint,
    ParameterConstraint,
    SumConstraint,
)
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.core.trial import Trial
from ax.early_stopping.strategies import (
    PercentileEarlyStoppingStrategy,
    ThresholdEarlyStoppingStrategy,
    LogicalEarlyStoppingStrategy,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import _encode_callables_as_references
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.winsorize import WinsorizationConfig
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.storage.botorch_modular_registry import CLASS_TO_REGISTRY
from ax.storage.transform_registry import TRANSFORM_REGISTRY
from ax.utils.common.serialization import serialize_init_args
from ax.utils.common.typeutils import not_none


def experiment_to_dict(experiment: Experiment) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax experiment.

    Args:
        experiment: The experiment whose fields are read.

    Returns:
        A dictionary whose ``"__type"`` entry is the experiment's class name
        and whose other entries hold the experiment's attributes as-is; no
        nested conversion is performed here.
    """
    exp = experiment
    encoded: Dict[str, Any] = {
        "__type": exp.__class__.__name__,
        "name": exp._name,
        "description": exp.description,
        "experiment_type": exp.experiment_type,
        "search_space": exp.search_space,
        "optimization_config": exp.optimization_config,
        # Only the values of the tracking-metric mapping are kept, as a list.
        "tracking_metrics": list(exp._tracking_metrics.values()),
        "runner": exp.runner,
        "status_quo": exp.status_quo,
        "time_created": exp.time_created,
        "trials": exp.trials,
        "is_test": exp.is_test,
        "data_by_trial": exp.data_by_trial,
        "properties": exp._properties,
        "default_data_type": exp._default_data_type,
    }
    return encoded
def multi_type_experiment_to_dict(experiment: MultiTypeExperiment) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax multi-type experiment.

    The multi-type-specific fields come first, followed by everything that
    ``experiment_to_dict`` produces for the same object. On a key collision
    the value from ``experiment_to_dict`` wins.

    Args:
        experiment: The multi-type experiment whose fields are read.

    Returns:
        The merged dictionary.
    """
    multi_type_fields = {
        "default_trial_type": experiment._default_trial_type,
        "_metric_to_canonical_name": experiment._metric_to_canonical_name,
        "_metric_to_trial_type": experiment._metric_to_trial_type,
        "_trial_type_to_runner": experiment._trial_type_to_runner,
    }
    return {**multi_type_fields, **experiment_to_dict(experiment)}
def batch_to_dict(batch: BatchTrial) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax batch trial.

    Args:
        batch: The batch trial whose fields are read.

    Returns:
        A dictionary whose ``"__type"`` entry is the batch's class name and
        whose other entries hold the batch's attributes as-is.
    """
    encoded: Dict[str, Any] = {
        "__type": batch.__class__.__name__,
        "index": batch.index,
        "trial_type": batch.trial_type,
        "ttl_seconds": batch.ttl_seconds,
        "status": batch.status,
        "status_quo": batch.status_quo,
        "status_quo_weight_override": batch._status_quo_weight_override,
        # Lifecycle timestamps.
        "time_created": batch.time_created,
        "time_completed": batch.time_completed,
        "time_staged": batch.time_staged,
        "time_run_started": batch.time_run_started,
        "abandoned_reason": batch.abandoned_reason,
        "run_metadata": batch.run_metadata,
        "stop_metadata": batch.stop_metadata,
        "generator_run_structs": batch.generator_run_structs,
        "runner": batch.runner,
        "abandoned_arms_metadata": batch._abandoned_arms_metadata,
        "num_arms_created": batch._num_arms_created,
        "optimize_for_power": batch.optimize_for_power,
        "generation_step_index": batch._generation_step_index,
        "properties": batch._properties,
    }
    return encoded
def trial_to_dict(trial: Trial) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax trial.

    Args:
        trial: The trial whose fields are read.

    Returns:
        A dictionary whose ``"__type"`` entry is the trial's class name and
        whose other entries hold the trial's attributes as-is.
    """
    encoded: Dict[str, Any] = {
        "__type": trial.__class__.__name__,
        "index": trial.index,
        "trial_type": trial.trial_type,
        "ttl_seconds": trial.ttl_seconds,
        "status": trial.status,
        # Lifecycle timestamps.
        "time_created": trial.time_created,
        "time_completed": trial.time_completed,
        "time_staged": trial.time_staged,
        "time_run_started": trial.time_run_started,
        "abandoned_reason": trial.abandoned_reason,
        "run_metadata": trial.run_metadata,
        "stop_metadata": trial.stop_metadata,
        "generator_run": trial.generator_run,
        "runner": trial.runner,
        "num_arms_created": trial._num_arms_created,
        "generation_step_index": trial._generation_step_index,
        "properties": trial._properties,
    }
    return encoded
def range_parameter_to_dict(parameter: RangeParameter) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax range parameter.

    Args:
        parameter: The range parameter whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the
        parameter's name, type, bounds, scaling flags, digits and fidelity
        settings.
    """
    param = parameter
    encoded: Dict[str, Any] = {
        "__type": param.__class__.__name__,
        "name": param.name,
        "parameter_type": param.parameter_type,
        "lower": param.lower,
        "upper": param.upper,
        "log_scale": param.log_scale,
        "logit_scale": param.logit_scale,
        "digits": param.digits,
        "is_fidelity": param.is_fidelity,
        "target_value": param.target_value,
    }
    return encoded
def choice_parameter_to_dict(parameter: ChoiceParameter) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax choice parameter.

    Args:
        parameter: The choice parameter whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the
        parameter's fields. ``"dependents"`` is ``None`` unless the parameter
        reports itself as hierarchical.
    """
    param = parameter
    encoded: Dict[str, Any] = {
        "__type": param.__class__.__name__,
        "is_ordered": param.is_ordered,
        "is_task": param.is_task,
        "name": param.name,
        "parameter_type": param.parameter_type,
        "values": param.values,
        "is_fidelity": param.is_fidelity,
        "target_value": param.target_value,
        # `dependents` is only read for hierarchical parameters.
        "dependents": param.dependents if param.is_hierarchical else None,
    }
    return encoded
def fixed_parameter_to_dict(parameter: FixedParameter) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax fixed parameter.

    Args:
        parameter: The fixed parameter whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the
        parameter's fields. ``"dependents"`` is ``None`` unless the parameter
        reports itself as hierarchical.
    """
    param = parameter
    encoded: Dict[str, Any] = {
        "__type": param.__class__.__name__,
        "name": param.name,
        "parameter_type": param.parameter_type,
        "value": param.value,
        "is_fidelity": param.is_fidelity,
        "target_value": param.target_value,
        # `dependents` is only read for hierarchical parameters.
        "dependents": param.dependents if param.is_hierarchical else None,
    }
    return encoded
def order_parameter_constraint_to_dict(
    parameter_constraint: OrderConstraint,
) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax order constraint.

    Only the names of the two parameters are recorded, not the parameter
    objects themselves.

    Args:
        parameter_constraint: The order constraint whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` and the names of
        the lower and upper parameters.
    """
    constraint = parameter_constraint
    encoded: Dict[str, Any] = {
        "__type": constraint.__class__.__name__,
        "lower_name": constraint.lower_parameter.name,
        "upper_name": constraint.upper_parameter.name,
    }
    return encoded
def sum_parameter_constraint_to_dict(
    parameter_constraint: SumConstraint,
) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax sum constraint.

    Args:
        parameter_constraint: The sum constraint whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the parameter
        names, the upper-bound flag and the absolute value of the bound.
    """
    constraint = parameter_constraint
    encoded: Dict[str, Any] = {
        "__type": constraint.__class__.__name__,
        "parameter_names": constraint._parameter_names,
        "is_upper_bound": constraint._is_upper_bound,
        # The bound is written out as an absolute value: per the original
        # author's note, the SumParameterConstraint constructor takes the
        # absolute value and transforms it based on is_upper_bound.
        "bound": abs(constraint._bound),
    }
    return encoded
def parameter_constraint_to_dict(
    parameter_constraint: ParameterConstraint,
) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax parameter constraint.

    Args:
        parameter_constraint: The parameter constraint whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the constraint
        dictionary and the bound.
    """
    constraint = parameter_constraint
    encoded: Dict[str, Any] = {
        "__type": constraint.__class__.__name__,
        "constraint_dict": constraint.constraint_dict,
        "bound": constraint.bound,
    }
    return encoded
def arm_to_dict(arm: Arm) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax arm.

    Args:
        arm: The arm whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the arm's
        parameters and its private ``_name`` attribute.
    """
    encoded: Dict[str, Any] = {
        "__type": arm.__class__.__name__,
        "parameters": arm.parameters,
        "name": arm._name,
    }
    return encoded
def search_space_to_dict(search_space: SearchSpace) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax search space.

    Args:
        search_space: The search space whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the parameters
        as a list and the parameter constraints.
    """
    space = search_space
    encoded: Dict[str, Any] = {
        "__type": space.__class__.__name__,
        # Only the values of the parameter mapping are kept, as a list.
        "parameters": list(space.parameters.values()),
        "parameter_constraints": space.parameter_constraints,
    }
    return encoded
def metric_to_dict(metric: Metric) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax metric.

    Args:
        metric: The metric to convert.

    Returns:
        Whatever ``serialize_init_args`` returns for the metric, with a
        ``"__type"`` entry set to the metric's class name.
    """
    encoded = serialize_init_args(object=metric)
    # Tag the result with the class name after the init args are collected.
    encoded["__type"] = metric.__class__.__name__
    return encoded
def objective_to_dict(objective: Objective) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax objective.

    Args:
        objective: The objective whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the metric and
        the ``minimize`` flag.
    """
    encoded: Dict[str, Any] = {
        "__type": objective.__class__.__name__,
        "metric": objective.metric,
        "minimize": objective.minimize,
    }
    return encoded
def multi_objective_to_dict(objective: MultiObjective) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax multi-objective.

    Args:
        objective: The multi-objective whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the
        sub-objectives and their weights.
    """
    encoded: Dict[str, Any] = {
        "__type": objective.__class__.__name__,
        "objectives": objective.objectives,
        "weights": objective.weights,
    }
    return encoded
def scalarized_objective_to_dict(objective: ScalarizedObjective) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax scalarized objective.

    Args:
        objective: The scalarized objective whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the metrics,
        their weights and the ``minimize`` flag.
    """
    encoded: Dict[str, Any] = {
        "__type": objective.__class__.__name__,
        "metrics": objective.metrics,
        "weights": objective.weights,
        "minimize": objective.minimize,
    }
    return encoded
def outcome_constraint_to_dict(outcome_constraint: OutcomeConstraint) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax outcome constraint.

    Args:
        outcome_constraint: The outcome constraint whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the metric, the
        comparison operator, the bound and the ``relative`` flag.
    """
    constraint = outcome_constraint
    encoded: Dict[str, Any] = {
        "__type": constraint.__class__.__name__,
        "metric": constraint.metric,
        "op": constraint.op,
        "bound": constraint.bound,
        "relative": constraint.relative,
    }
    return encoded
def optimization_config_to_dict(
    optimization_config: OptimizationConfig,
) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax optimization config.

    Args:
        optimization_config: The optimization config whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the objective
        and the outcome constraints.
    """
    config = optimization_config
    encoded: Dict[str, Any] = {
        "__type": config.__class__.__name__,
        "objective": config.objective,
        "outcome_constraints": config.outcome_constraints,
    }
    return encoded
def multi_objective_optimization_config_to_dict(
    multi_objective_optimization_config: MultiObjectiveOptimizationConfig,
) -> Dict[str, Any]:
    """Build a dictionary representation of a multi-objective optimization config.

    Args:
        multi_objective_optimization_config: The config whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the objective,
        the outcome constraints and the objective thresholds.
    """
    # A short alias keeps the entries below within the line-length limit.
    config = multi_objective_optimization_config
    encoded: Dict[str, Any] = {
        "__type": config.__class__.__name__,
        "objective": config.objective,
        "outcome_constraints": config.outcome_constraints,
        "objective_thresholds": config.objective_thresholds,
    }
    return encoded
def generator_run_to_dict(generator_run: GeneratorRun) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax generator run.

    Args:
        generator_run: The generator run whose fields are read.

    Returns:
        A dictionary whose ``"__type"`` entry is the generator run's class
        name and whose other entries hold its attributes as-is.
    """
    run = generator_run
    # Read up front, as in the original, so the long attribute name does not
    # have to fit inside the dictionary literal.
    candidate_metadata = run.candidate_metadata_by_arm_signature
    encoded: Dict[str, Any] = {
        "__type": run.__class__.__name__,
        "arms": run.arms,
        "weights": run.weights,
        "optimization_config": run.optimization_config,
        "search_space": run.search_space,
        "time_created": run.time_created,
        "model_predictions": run.model_predictions,
        "best_arm_predictions": run.best_arm_predictions,
        "generator_run_type": run.generator_run_type,
        "index": run.index,
        "fit_time": run.fit_time,
        "gen_time": run.gen_time,
        "model_key": run._model_key,
        "model_kwargs": run._model_kwargs,
        "bridge_kwargs": run._bridge_kwargs,
        "gen_metadata": run._gen_metadata,
        "model_state_after_gen": run._model_state_after_gen,
        "generation_step_index": run._generation_step_index,
        "candidate_metadata_by_arm_signature": candidate_metadata,
    }
    return encoded
def runner_to_dict(runner: Runner) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax runner.

    Args:
        runner: The runner to convert.

    Returns:
        Whatever the runner class's ``serialize_init_args`` returns for this
        runner, with a ``"__type"`` entry set to the runner's class name.
    """
    # The serialization hook is looked up on the runner's concrete class.
    encoded = type(runner).serialize_init_args(runner=runner)
    encoded["__type"] = runner.__class__.__name__
    return encoded
def data_to_dict(data: Data) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax data object.

    Args:
        data: The data object whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the ``df``
        attribute and the description.
    """
    encoded: Dict[str, Any] = {
        "__type": data.__class__.__name__,
        "df": data.df,
        "description": data.description,
    }
    return encoded
def map_data_to_dict(map_data: MapData) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax map data object.

    Args:
        map_data: The map data object whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the map
        dataframe, the map key infos and the description.
    """
    encoded: Dict[str, Any] = {
        "__type": map_data.__class__.__name__,
        # The "df" entry is filled from `map_df`, not from `df`.
        "df": map_data.map_df,
        "map_key_infos": map_data.map_key_infos,
        "description": map_data.description,
    }
    return encoded
def map_key_info_to_dict(mki: MapKeyInfo) -> Dict[str, Any]:
    """Build a dictionary representation of Ax map key info.

    Args:
        mki: The map key info whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the key and its
        default value.
    """
    encoded: Dict[str, Any] = {
        "__type": mki.__class__.__name__,
        "key": mki.key,
        "default_value": mki.default_value,
    }
    return encoded
def transform_type_to_dict(transform_type: Type[Transform]) -> Dict[str, Any]:
    """Build a dictionary representation of a transform class.

    Args:
        transform_type: The transform class (not an instance) to encode. It
            is looked up directly in ``TRANSFORM_REGISTRY``.

    Returns:
        A dictionary with the fixed ``"__type"`` marker ``"Type[Transform]"``,
        the class's value in ``TRANSFORM_REGISTRY`` and the string form of
        the class.
    """
    registry_index = TRANSFORM_REGISTRY[transform_type]
    encoded: Dict[str, Any] = {
        "__type": "Type[Transform]",
        "index_in_registry": registry_index,
        "transform_type": f"{transform_type}",
    }
    return encoded
def generation_step_to_dict(generation_step: GenerationStep) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax generation step.

    Args:
        generation_step: The generation step whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the step's
        fields. Both kwargs dictionaries are passed through
        ``_encode_callables_as_references``.
    """
    step = generation_step
    encoded: Dict[str, Any] = {
        "__type": step.__class__.__name__,
        "model": step.model,
        "num_trials": step.num_trials,
        "min_trials_observed": step.min_trials_observed,
        "max_parallelism": step.max_parallelism,
        "use_update": step.use_update,
        "enforce_num_trials": step.enforce_num_trials,
        # Falsy kwargs (e.g. None) are replaced by an empty dict first.
        "model_kwargs": _encode_callables_as_references(step.model_kwargs or {}),
        "model_gen_kwargs": _encode_callables_as_references(
            step.model_gen_kwargs or {}
        ),
        "index": step.index,
        "should_deduplicate": step.should_deduplicate,
    }
    return encoded
def generation_strategy_to_dict(
    generation_strategy: GenerationStrategy,
) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax generation strategy.

    Args:
        generation_strategy: The generation strategy whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the
        strategy's fields. ``"had_initialized_model"`` records whether the
        strategy's ``model`` is not ``None``.

    Raises:
        ValueError: If the strategy reports ``uses_non_registered_models``.
    """
    strategy = generation_strategy
    if strategy.uses_non_registered_models:
        raise ValueError(  # pragma: no cover
            "Generation strategies that use custom models provided through "
            "callables cannot be serialized and stored."
        )
    encoded: Dict[str, Any] = {
        "__type": strategy.__class__.__name__,
        "db_id": strategy._db_id,
        "name": strategy.name,
        "steps": strategy._steps,
        # Only the index of the current step is stored.
        "curr_index": strategy._curr.index,
        "generator_runs": strategy._generator_runs,
        "had_initialized_model": strategy.model is not None,
        "experiment": strategy._experiment,
    }
    return encoded
def observation_features_to_dict(obs_features: ObservationFeatures) -> Dict[str, Any]:
    """Build a dictionary representation of Ax observation features.

    Args:
        obs_features: The observation features whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the parameters,
        trial index, start and end times, random split and metadata.
    """
    features = obs_features
    encoded: Dict[str, Any] = {
        "__type": features.__class__.__name__,
        "parameters": features.parameters,
        "trial_index": features.trial_index,
        "start_time": features.start_time,
        "end_time": features.end_time,
        "random_split": features.random_split,
        "metadata": features.metadata,
    }
    return encoded
def botorch_model_to_dict(model: BoTorchModel) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax BoTorch model.

    Args:
        model: The model whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the model's
        surrogate, acquisition settings and refit flags.
    """
    encoded: Dict[str, Any] = {
        "__type": model.__class__.__name__,
        "surrogate": model.surrogate,
        "surrogate_options": model.surrogate_options,
        "acquisition_class": model.acquisition_class,
        "botorch_acqf_class": model._botorch_acqf_class,
        # Falsy acquisition options (e.g. None) become an empty dict.
        "acquisition_options": model.acquisition_options or {},
        "refit_on_update": model.refit_on_update,
        "refit_on_cv": model.refit_on_cv,
        "warm_start_refit": model.warm_start_refit,
    }
    return encoded
def surrogate_to_dict(surrogate: Surrogate) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax surrogate.

    Args:
        surrogate: The surrogate to convert.

    Returns:
        A dictionary that starts with the class name under ``"__type"`` and
        is then updated with the result of the surrogate's
        ``_serialize_attributes_as_kwargs()``.
    """
    encoded = {"__type": surrogate.__class__.__name__}
    serialized_attributes = surrogate._serialize_attributes_as_kwargs()
    encoded.update(serialized_attributes)
    return encoded
def botorch_modular_to_dict(class_type: Type[Any]) -> Dict[str, Any]:
    """Build a dictionary representation of a registered BoTorch-modular class.

    ``CLASS_TO_REGISTRY`` is scanned in iteration order; the first key that
    ``class_type`` is a subclass of decides which registry is consulted.

    Args:
        class_type: The class (not an instance) to encode.

    Returns:
        A dictionary with ``"__type"`` set to ``"Type[<parent name>]"``, the
        class's value in the matching registry under ``"index"`` and the
        string form of the parent class under ``"class"``.

    Raises:
        ValueError: If ``class_type`` is missing from the matching registry,
            or if no key of ``CLASS_TO_REGISTRY`` is a parent of it.
    """
    for parent_class, registry in CLASS_TO_REGISTRY.items():
        if not issubclass(class_type, parent_class):
            continue
        if class_type not in registry:
            raise ValueError(  # pragma: no cover
                f"Class `{class_type.__name__}` not in Type[{parent_class.__name__}] "
                "registry, please add it. BoTorch object registries are "
                "located in `ax/storage/botorch_modular_registry.py`."
            )
        return {
            "__type": f"Type[{parent_class.__name__}]",
            "index": registry[class_type],
            "class": f"{parent_class}",
        }
    # No registry key is a parent of `class_type`.
    raise ValueError(
        f"{class_type} does not have a corresponding parent class in "
        "CLASS_TO_REGISTRY."
    )
def botorch_component_to_dict(input_obj: Any) -> Dict[str, Any]:
    """Build a dictionary representation of a BoTorch component instance.

    Args:
        input_obj: An *instance* (not a class) that exposes a ``state_dict()``
            method returning a mapping. The previous ``Type[Any]`` annotation
            was misleading, because the body calls an instance method and
            reads ``__class__``.

    Returns:
        A dictionary with the instance's class name under ``"__type"``, the
        class object itself under ``"index"``, the string form of the class
        under ``"class"`` and the float-converted state under
        ``"state_dict"``.

    Raises:
        TypeError, ValueError: Propagated from ``float()`` if a state value
            cannot be converted to a single float.
    """
    class_type = input_obj.__class__
    # Cast dict values to float to avoid errors with Tensors.
    # NOTE(review): float() only accepts values convertible to one number;
    # presumably multi-element tensors would raise here. Confirm with callers.
    state_dict = {key: float(value) for key, value in input_obj.state_dict().items()}
    return {
        "__type": f"{class_type.__name__}",
        "index": class_type,
        "class": f"{class_type}",
        "state_dict": state_dict,
    }
def percentile_early_stopping_strategy_to_dict(
    strategy: PercentileEarlyStoppingStrategy,
) -> Dict[str, Any]:
    """Build a dictionary representation of a percentile early stopping strategy.

    Args:
        strategy: The strategy whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the
        strategy's metric names, percentile threshold, minimum progression,
        minimum curves, ignored trial indices, true objective metric name and
        polling interval.
    """
    encoded: Dict[str, Any] = {
        "__type": strategy.__class__.__name__,
        "metric_names": strategy.metric_names,
        "percentile_threshold": strategy.percentile_threshold,
        "min_progression": strategy.min_progression,
        "min_curves": strategy.min_curves,
        "trial_indices_to_ignore": strategy.trial_indices_to_ignore,
        "true_objective_metric_name": strategy.true_objective_metric_name,
        "seconds_between_polls": strategy.seconds_between_polls,
    }
    return encoded
def threshold_early_stopping_strategy_to_dict(
    strategy: ThresholdEarlyStoppingStrategy,
) -> Dict[str, Any]:
    """Build a dictionary representation of a metric-threshold early stopping strategy.

    Args:
        strategy: The strategy whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` plus the
        strategy's metric names, metric threshold, minimum progression,
        ignored trial indices and true objective metric name.
    """
    encoded: Dict[str, Any] = {
        "__type": strategy.__class__.__name__,
        "metric_names": strategy.metric_names,
        "metric_threshold": strategy.metric_threshold,
        "min_progression": strategy.min_progression,
        "trial_indices_to_ignore": strategy.trial_indices_to_ignore,
        "true_objective_metric_name": strategy.true_objective_metric_name,
    }
    return encoded
def logical_early_stopping_strategy_to_dict(
    strategy: LogicalEarlyStoppingStrategy,
) -> Dict[str, Any]:
    """Build a dictionary representation of a logical early stopping strategy.

    Args:
        strategy: The strategy whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"`` and the
        strategy's ``left`` and ``right`` attributes.
    """
    encoded: Dict[str, Any] = {
        "__type": strategy.__class__.__name__,
        "left": strategy.left,
        "right": strategy.right,
    }
    return encoded
def winsorization_config_to_dict(config: WinsorizationConfig) -> Dict[str, Any]:
    """Build a dictionary representation of an Ax winsorization config.

    Args:
        config: The winsorization config whose fields are read.

    Returns:
        A dictionary with the class name under ``"__type"``, the lower and
        upper quantile margins and the lower and upper boundaries.
    """
    encoded: Dict[str, Any] = {
        "__type": config.__class__.__name__,
        "lower_quantile_margin": config.lower_quantile_margin,
        "upper_quantile_margin": config.upper_quantile_margin,
        "lower_boundary": config.lower_boundary,
        "upper_boundary": config.upper_boundary,
    }
    return encoded
def pytorch_cnn_torchvision_benchmark_problem_to_dict(
    problem: PyTorchCNNTorchvisionBenchmarkProblem,
) -> Dict[str, Any]:
    """Build a dictionary representation of a torchvision CNN benchmark problem.

    Only the part of ``problem.name`` that follows the first ``"::"`` is
    stored (up to the end of that line, since ``.`` does not match newlines).

    Args:
        problem: The benchmark problem whose name is read.

    Returns:
        A dictionary with the class name under ``"__type"`` and the extracted
        name suffix under ``"name"``.
    """
    # unit tests for this in benchmark suite
    return {  # pragma: no cover
        "__type": problem.__class__.__name__,
        # The lookbehind anchors the match just after "::"; `not_none` is
        # applied to the search result before `.group()` is called.
        "name": not_none(re.search("(?<=::).*", problem.name)).group(),
    }