Source code for ax.storage.sqa_store.sqa_config

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperimentPurpose
from ax.core.batch_trial import AbandonedArm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.metric import Metric
from ax.core.parameter import Parameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.json_store.registry import (
    CORE_CLASS_DECODER_REGISTRY,
    CORE_CLASS_ENCODER_REGISTRY,
    CORE_DECODER_REGISTRY,
    CORE_ENCODER_REGISTRY,
    TDecoderRegistry,
)
from ax.storage.metric_registry import CORE_METRIC_REGISTRY
from ax.storage.runner_registry import CORE_RUNNER_REGISTRY
from ax.storage.sqa_store.db import SQABase
from ax.storage.sqa_store.sqa_classes import (
    SQAAbandonedArm,
    SQAAnalysisCard,
    SQAArm,
    SQAData,
    SQAExperiment,
    SQAGenerationStrategy,
    SQAGeneratorRun,
    SQAMetric,
    SQAParameter,
    SQAParameterConstraint,
    SQARunner,
    SQATrial,
)
from ax.utils.common.base import Base


@dataclass
class SQAConfig:
    """Settings that control how an experiment is saved to and loaded from
    SQLAlchemy.

    Every field has a default, so ``SQAConfig()`` yields the stock
    configuration. The registry defaults are produced by factories that
    return the module-level ``CORE_*`` objects themselves (not copies).

    Attributes:
        class_to_sqa_class: Mapping of user-facing class to SQLAlchemy class
            that it will be encoded to. This allows overwriting of the
            default classes to provide custom save functionality.
        experiment_type_enum: Enum containing valid Experiment types.
            Defaults to ``None``.
        generator_run_type_enum: Enum containing valid Generator Run types.
            Defaults to ``GeneratorRunType``.
        auxiliary_experiment_purpose_enum: Enum type, defaulting to
            ``AuxiliaryExperimentPurpose``.
        json_encoder_registry: Mapping from user-facing types to their json
            serialization function. Defaults to ``CORE_ENCODER_REGISTRY``.
        json_class_encoder_registry: Mapping with the same shape as
            ``json_encoder_registry``. Defaults to
            ``CORE_CLASS_ENCODER_REGISTRY``.
        json_decoder_registry: A ``TDecoderRegistry``. Defaults to
            ``CORE_DECODER_REGISTRY``.
        json_class_decoder_registry: Mapping from a string key to a callable
            that takes a json dict. Defaults to
            ``CORE_CLASS_DECODER_REGISTRY``.
        metric_registry: Mapping from ``Metric`` subclass to an integer.
            Defaults to ``CORE_METRIC_REGISTRY``.
        runner_registry: Mapping from ``Runner`` subclass to an integer.
            Defaults to ``CORE_RUNNER_REGISTRY``.
    """

    # ``self`` defaults to None so that this function can be invoked with no
    # arguments when it is used as the ``default_factory`` below.
    def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
        # pyre-ignore [7]
        return {
            AbandonedArm: SQAAbandonedArm,
            Arm: SQAArm,
            Data: SQAData,
            Experiment: SQAExperiment,
            GenerationStrategy: SQAGenerationStrategy,
            GeneratorRun: SQAGeneratorRun,
            Parameter: SQAParameter,
            ParameterConstraint: SQAParameterConstraint,
            Metric: SQAMetric,
            Runner: SQARunner,
            Trial: SQATrial,
            AnalysisCard: SQAAnalysisCard,
        }

    # --- Class mapping and enums -------------------------------------------
    class_to_sqa_class: dict[type[Base], type[SQABase]] = field(
        default_factory=_default_class_to_sqa_class
    )
    experiment_type_enum: Enum | type[Enum] | None = None
    generator_run_type_enum: Enum | type[Enum] | None = GeneratorRunType
    auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose

    # --- JSON encoder / decoder registries ---------------------------------
    # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
    #  `typing.Type` to avoid runtime subscripting errors.
    json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field(
        default_factory=lambda: CORE_ENCODER_REGISTRY
    )
    # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
    #  `typing.Type` to avoid runtime subscripting errors.
    json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field(
        default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY
    )

    json_decoder_registry: TDecoderRegistry = field(
        default_factory=lambda: CORE_DECODER_REGISTRY
    )
    # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
    json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]] = field(
        default_factory=lambda: CORE_CLASS_DECODER_REGISTRY
    )

    # --- Metric / runner type registries -----------------------------------
    metric_registry: dict[type[Metric], int] = field(
        default_factory=lambda: CORE_METRIC_REGISTRY
    )
    runner_registry: dict[type[Runner], int] = field(
        default_factory=lambda: CORE_RUNNER_REGISTRY
    )

    @property
    def reverse_metric_registry(self) -> dict[int, type[Metric]]:
        """Return ``metric_registry`` inverted: integer -> ``Metric`` class.

        Rebuilt on every access. If two classes share an integer, the one
        that appears later in ``metric_registry`` wins.
        """
        inverted: dict[int, type[Metric]] = {}
        for metric_cls, type_id in self.metric_registry.items():
            inverted[type_id] = metric_cls
        return inverted

    @property
    def reverse_runner_registry(self) -> dict[int, type[Runner]]:
        """Return ``runner_registry`` inverted: integer -> ``Runner`` class.

        Rebuilt on every access. If two classes share an integer, the one
        that appears later in ``runner_registry`` wins.
        """
        inverted: dict[int, type[Runner]] = {}
        for runner_cls, type_id in self.runner_registry.items():
            inverted[type_id] = runner_cls
        return inverted