# Source code for ax.storage.sqa_store.sqa_classes

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from datetime import datetime
from typing import Any, Dict, List, Optional

from ax.core.base_trial import TrialStatus
from ax.core.parameter import ParameterType
from ax.core.types import (
    ComparisonOp,
    TModelPredict,
    TModelPredictArm,
    TParameterization,
    TParamValue,
)
from ax.storage.sqa_store.db import (
    LONG_STRING_FIELD_LENGTH,
    LONGTEXT_BYTES,
    NAME_OR_TYPE_FIELD_LENGTH,
    Base,
)
from ax.storage.sqa_store.json import (
    JSONEncodedDict,
    JSONEncodedList,
    JSONEncodedObject,
    JSONEncodedTextDict,
)
from ax.storage.sqa_store.sqa_enum import IntEnum, StringEnum
from ax.storage.sqa_store.timestamp import IntTimestamp
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from sqlalchemy import (
    BigInteger,
    Boolean,
    Column,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
)
from sqlalchemy.orm import backref, relationship


# Parent foreign-key column names that several tables in this module declare
# side by side as nullable columns (SQAParameter, SQAParameterConstraint,
# SQAMetric). NOTE(review): presumably a row is expected to have only one of
# them populated -- confirm against the code that reads this list.
ONLY_ONE_FIELDS = [
    "experiment_id",
    "generator_run_id",
]


class SQAParameter(Base):
    """ORM row for a search-space parameter (table ``parameter_v2``).

    Range, choice and fixed parameters share this one table. The columns
    that apply to only one kind of parameter are nullable and are left empty
    for the other kinds. Both parent foreign keys (experiment and generator
    run) are nullable; presumably only one is set per row.
    """

    __tablename__: str = "parameter_v2"

    # --- Columns common to every parameter kind ---
    domain_type: DomainType = Column(IntEnum(DomainType), nullable=False)
    experiment_id: Optional[int] = Column(
        Integer, ForeignKey("experiment_v2.id")
    )
    id: int = Column(Integer, primary_key=True)
    generator_run_id: Optional[int] = Column(
        Integer, ForeignKey("generator_run_v2.id")
    )
    name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
    parameter_type: ParameterType = Column(
        IntEnum(ParameterType), nullable=False
    )
    is_fidelity: Optional[bool] = Column(Boolean)
    target_value: Optional[TParamValue] = Column(JSONEncodedObject)

    # --- Range-parameter columns ---
    digits: Optional[int] = Column(Integer)
    log_scale: Optional[bool] = Column(Boolean)
    lower: Optional[float] = Column(Float)
    upper: Optional[float] = Column(Float)

    # --- Choice-parameter columns ---
    choice_values: Optional[List[TParamValue]] = Column(JSONEncodedList)
    is_ordered: Optional[bool] = Column(Boolean)
    is_task: Optional[bool] = Column(Boolean)

    # --- Fixed-parameter columns ---
    fixed_value: Optional[TParamValue] = Column(JSONEncodedObject)

    # Plain class attributes (not columns) read by the storage layer.
    # NOTE(review): presumably the fields that may not change on update and
    # the attribute used to match existing rows -- confirm in the encoder.
    immutable_fields = ["name", "parameter_type"]
    unique_id = "name"
class SQAParameterConstraint(Base):
    """ORM row for a parameter constraint (table ``parameter_constraint_v2``).

    The constraint is stored as a ``str -> float`` mapping
    (``constraint_dict``, presumably parameter name to coefficient) plus a
    scalar ``bound``. Its kind is stored in ``type`` as a
    ``ParameterConstraintType``.
    """

    __tablename__: str = "parameter_constraint_v2"

    bound: float = Column(Float, nullable=False)
    constraint_dict: Dict[str, float] = Column(JSONEncodedDict, nullable=False)
    experiment_id: Optional[int] = Column(
        Integer, ForeignKey("experiment_v2.id")
    )
    id: int = Column(Integer, primary_key=True)
    generator_run_id: Optional[int] = Column(
        Integer, ForeignKey("generator_run_v2.id")
    )
    # The Python-side value is a ParameterConstraintType member. IntEnum is
    # only the SQLAlchemy column type that encodes it, so it is not the right
    # annotation. This now matches how the other enum-valued columns in this
    # module are annotated (e.g. ``parameter_type: ParameterType``).
    type: ParameterConstraintType = Column(
        IntEnum(ParameterConstraintType), nullable=False
    )

    # ParameterConstraints should never be updated; since they don't have
    # a field that can be used for a UID, if anything changes,
    # we should just throw them out and recreate them
    immutable_fields = ["type", "constraint_dict", "bound"]
class SQAMetric(Base):
    """ORM row for a metric (table ``metric_v2``).

    Plain metrics, objectives and outcome constraints share this one table.
    The role-specific columns are nullable: ``minimize`` for objectives and
    ``op`` / ``bound`` / ``relative`` for outcome constraints. ``intent``
    (a ``MetricIntent``) presumably records which role the row plays.
    """

    __tablename__: str = "metric_v2"

    experiment_id: Optional[int] = Column(
        Integer, ForeignKey("experiment_v2.id")
    )
    generator_run_id: Optional[int] = Column(
        Integer, ForeignKey("generator_run_v2.id")
    )
    id: int = Column(Integer, primary_key=True)
    lower_is_better: Optional[bool] = Column(Boolean)
    intent: MetricIntent = Column(StringEnum(MetricIntent), nullable=False)
    metric_type: int = Column(Integer, nullable=False)
    name: str = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False)
    # ``default=dict`` builds a fresh empty dict for every INSERT. The previous
    # literal ``default={}`` was a single dict object shared by every insert,
    # so any in-place mutation of it would silently change the default for
    # all later rows. The stored value (an empty dict) is unchanged.
    properties: Optional[Dict[str, Any]] = Column(
        JSONEncodedTextDict, default=dict
    )

    # Attributes for Objectives
    minimize: Optional[bool] = Column(Boolean)

    # Attributes for Outcome Constraints
    op: Optional[ComparisonOp] = Column(IntEnum(ComparisonOp))
    bound: Optional[float] = Column(Float)
    relative: Optional[bool] = Column(Boolean)

    # Multi-type Experiment attributes
    trial_type: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    canonical_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))

    immutable_fields = ["name", "metric_type"]
    unique_id = "name"
class SQAArm(Base):
    """ORM row for an arm of a generator run (table ``arm_v2``).

    Holds the arm's parameterization (JSON-encoded), an optional name, and
    a weight that defaults to 1.0 when none is given on insert.
    """

    __tablename__: str = "arm_v2"

    generator_run_id: int = Column(
        Integer, ForeignKey("generator_run_v2.id")
    )
    id: int = Column(Integer, primary_key=True)
    name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    parameters: TParameterization = Column(
        JSONEncodedTextDict, nullable=False
    )
    weight: float = Column(Float, nullable=False, default=1.0)

    # Storage-layer metadata (plain class attributes, not columns).
    immutable_fields = ["parameters"]
    unique_id = "name"
class SQAAbandonedArm(Base):
    """ORM row recording that an arm of a trial was abandoned.

    Table ``abandoned_arm_v2``. ``time_abandoned`` is filled from
    ``datetime.now`` on insert when no value is given.
    """

    __tablename__: str = "abandoned_arm_v2"

    abandoned_reason: Optional[str] = Column(
        String(LONG_STRING_FIELD_LENGTH)
    )
    id: int = Column(Integer, primary_key=True)
    name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
    time_abandoned: datetime = Column(
        IntTimestamp,
        nullable=False,
        default=datetime.now,
    )
    trial_id: int = Column(Integer, ForeignKey("trial_v2.id"))

    # Storage-layer metadata (plain class attributes, not columns).
    immutable_fields = ["name", "time_abandoned"]
    unique_id = "name"
class SQAGeneratorRun(Base):
    """ORM row for a generator run (table ``generator_run_v2``).

    Besides its own columns (best-arm info, model predictions, timing and
    model/bridge kwargs), a generator run owns child rows for its arms,
    metrics, parameters and parameter constraints. All of these children are
    loaded with ``selectin`` and removed with the parent
    (``delete-orphan``).
    """

    __tablename__: str = "generator_run_v2"

    best_arm_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    best_arm_parameters: Optional[TParameterization] = Column(
        JSONEncodedTextDict
    )
    best_arm_predictions: Optional[TModelPredictArm] = Column(JSONEncodedList)
    generator_run_type: Optional[int] = Column(Integer)
    id: int = Column(Integer, primary_key=True)
    index: Optional[int] = Column(Integer)
    model_predictions: Optional[TModelPredict] = Column(JSONEncodedList)
    time_created: datetime = Column(
        IntTimestamp, nullable=False, default=datetime.now
    )
    trial_id: Optional[int] = Column(Integer, ForeignKey("trial_v2.id"))
    weight: Optional[float] = Column(Float)
    fit_time: Optional[float] = Column(Float)
    gen_time: Optional[float] = Column(Float)
    model_key: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    model_kwargs: Optional[Dict[str, Any]] = Column(JSONEncodedTextDict)
    bridge_kwargs: Optional[Dict[str, Any]] = Column(JSONEncodedTextDict)
    generation_strategy_id: Optional[int] = Column(
        Integer,
        ForeignKey("generation_strategy.id"),
    )

    # Child relationships. Collections use selectin eager loading to avoid
    # idle-timeout errors; see
    # https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading
    arms: List[SQAArm] = relationship(
        "SQAArm",
        cascade="all, delete-orphan",
        lazy="selectin",
        order_by=lambda: SQAArm.id,
    )
    metrics: List[SQAMetric] = relationship(
        "SQAMetric",
        cascade="all, delete-orphan",
        lazy="selectin",
    )
    parameters: List[SQAParameter] = relationship(
        "SQAParameter",
        cascade="all, delete-orphan",
        lazy="selectin",
    )
    parameter_constraints: List[SQAParameterConstraint] = relationship(
        "SQAParameterConstraint",
        cascade="all, delete-orphan",
        lazy="selectin",
    )

    # Storage-layer metadata (plain class attributes, not columns).
    ignore_during_update_fields = ["time_created"]
    unique_id = "index"
class SQARunner(Base):
    """ORM row for a runner (table ``runner``).

    A runner row can point at an experiment and/or a trial through nullable
    foreign keys. ``runner_type`` is an integer code and ``properties`` is a
    JSON-encoded dict of runner-specific settings.
    """

    __tablename__: str = "runner"

    id: int = Column(Integer, primary_key=True)
    experiment_id: Optional[int] = Column(
        Integer, ForeignKey("experiment_v2.id")
    )
    # ``default=dict`` builds a fresh empty dict for every INSERT. The previous
    # literal ``default={}`` was a single dict object shared by every insert,
    # so any in-place mutation of it would silently change the default for
    # all later rows. The stored value (an empty dict) is unchanged.
    properties: Optional[Dict[str, Any]] = Column(
        JSONEncodedTextDict, default=dict
    )
    runner_type: int = Column(Integer, nullable=False)
    trial_id: Optional[int] = Column(Integer, ForeignKey("trial_v2.id"))

    # Multi-type Experiment attributes
    trial_type: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
class SQAData(Base):
    """ORM row for a serialized data payload (table ``data_v2``).

    The payload lives in ``data_json``, a large text column. ``time_created``
    is stored as a raw ``BigInteger``, not as an ``IntTimestamp`` like the
    other timestamps in this module, and it also serves as the row's
    ``unique_id``.
    """

    __tablename__: str = "data_v2"

    id: int = Column(Integer, primary_key=True)
    data_json: str = Column(Text(LONGTEXT_BYTES), nullable=False)
    description: Optional[str] = Column(String(LONG_STRING_FIELD_LENGTH))
    experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"))
    time_created: int = Column(BigInteger, nullable=False)
    trial_index: Optional[int] = Column(Integer)
    generation_strategy_id: Optional[int] = Column(
        Integer,
        ForeignKey("generation_strategy.id"),
    )

    # Storage-layer metadata (plain class attribute, not a column).
    unique_id = "time_created"
class SQAGenerationStrategy(Base):
    """ORM row for a generation strategy (table ``generation_strategy``).

    Stores the strategy's steps and the ``generated`` / ``observed`` lists
    as JSON-encoded lists, plus the current step index. It owns the
    generator runs it produced (ordered by id) and at most one data row
    (``uselist=False``).
    """

    __tablename__: str = "generation_strategy"

    id: int = Column(Integer, primary_key=True)
    name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
    steps: List[Dict[str, Any]] = Column(JSONEncodedList, nullable=False)
    generated: List[str] = Column(JSONEncodedList, nullable=False)
    observed: List[str] = Column(JSONEncodedList, nullable=False)
    curr_index: int = Column(Integer, nullable=False)
    experiment_id: Optional[int] = Column(
        Integer, ForeignKey("experiment_v2.id")
    )

    # Child relationships: runs use selectin loading; the single data row
    # uses lazy=False (eager) loading.
    generator_runs: List[SQAGeneratorRun] = relationship(
        "SQAGeneratorRun",
        cascade="all, delete-orphan",
        lazy="selectin",
        order_by=lambda: SQAGeneratorRun.id,
    )
    data: SQAData = relationship(
        "SQAData",
        cascade="all, delete-orphan",
        lazy=False,
        uselist=False,
    )
class SQATrial(Base):
    """ORM row for a trial (table ``trial_v2``).

    Stores the trial's lifecycle status and timestamps, and owns its
    abandoned arms, generator runs and (single) runner through
    ``delete-orphan`` relationships.
    """

    __tablename__: str = "trial_v2"

    # Sized with LONG_STRING_FIELD_LENGTH to match
    # ``SQAAbandonedArm.abandoned_reason``: both columns store an abandonment
    # reason, so they should share one length limit. Previously this used the
    # short NAME_OR_TYPE_FIELD_LENGTH.
    # NOTE(review): databases whose ``trial_v2`` table was created with the
    # shorter length need a column migration for the new limit to apply.
    abandoned_reason: Optional[str] = Column(String(LONG_STRING_FIELD_LENGTH))
    deployed_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"))
    id: int = Column(Integer, primary_key=True)
    index: int = Column(Integer, index=True, nullable=False)
    # Python attribute ``is_batch`` maps to the DB column ``is_batched``.
    is_batch: bool = Column("is_batched", Boolean, nullable=False, default=True)
    num_arms_created: int = Column(Integer, nullable=False, default=0)
    run_metadata: Optional[Dict[str, Any]] = Column(JSONEncodedTextDict)
    status: TrialStatus = Column(
        IntEnum(TrialStatus), nullable=False, default=TrialStatus.CANDIDATE
    )
    status_quo_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    time_completed: Optional[datetime] = Column(IntTimestamp)
    time_created: datetime = Column(IntTimestamp, nullable=False)
    time_staged: Optional[datetime] = Column(IntTimestamp)
    time_run_started: Optional[datetime] = Column(IntTimestamp)
    trial_type: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))

    # relationships
    # Trials and experiments are mutable, so the children relationships need
    # cascade="all, delete-orphan", which means if we remove or replace
    # a child, the old one will be deleted.
    # Use selectin loading for collections to prevent idle timeout errors
    # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
    abandoned_arms: List[SQAAbandonedArm] = relationship(
        "SQAAbandonedArm", cascade="all, delete-orphan", lazy="selectin"
    )
    generator_runs: List[SQAGeneratorRun] = relationship(
        "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin"
    )
    runner: SQARunner = relationship(
        "SQARunner", uselist=False, cascade="all, delete-orphan", lazy=False
    )

    unique_id = "index"
    ignore_during_update_fields = ["time_created"]
    immutable_fields = ["is_batch"]
class SQAExperiment(Base):
    """ORM row for an experiment (table ``experiment_v2``), the root object.

    An experiment owns its data, metrics, parameters, parameter constraints,
    runners and trials through ``delete-orphan`` relationships, plus an
    optional generation strategy. The generation strategy relationship also
    defines an ``experiment`` backref on ``SQAGenerationStrategy``.
    """

    __tablename__: str = "experiment_v2"

    description: Optional[str] = Column(String(LONG_STRING_FIELD_LENGTH))
    experiment_type: Optional[int] = Column(Integer)
    id: int = Column(Integer, primary_key=True)
    is_test: bool = Column(Boolean, nullable=False, default=False)
    name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
    # ``default=dict`` builds a fresh empty dict for every INSERT. The previous
    # literal ``default={}`` was a single dict object shared by every insert,
    # so any in-place mutation of it would silently change the default for
    # all later rows. The stored value (an empty dict) is unchanged.
    properties: Optional[Dict[str, Any]] = Column(
        JSONEncodedTextDict, default=dict
    )
    status_quo_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
    status_quo_parameters: Optional[TParameterization] = Column(
        JSONEncodedTextDict
    )
    time_created: datetime = Column(IntTimestamp, nullable=False)
    default_trial_type: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))

    # relationships
    # Trials and experiments are mutable, so the children relationships need
    # cascade="all, delete-orphan", which means if we remove or replace
    # a child, the old one will be deleted.
    # Use selectin loading for collections to prevent idle timeout errors
    # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
    data: List[SQAData] = relationship(
        "SQAData", cascade="all, delete-orphan", lazy="selectin"
    )
    metrics: List[SQAMetric] = relationship(
        "SQAMetric", cascade="all, delete-orphan", lazy="selectin"
    )
    parameters: List[SQAParameter] = relationship(
        "SQAParameter", cascade="all, delete-orphan", lazy="selectin"
    )
    parameter_constraints: List[SQAParameterConstraint] = relationship(
        "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin"
    )
    # NOTE(review): unlike the other collections here, runners use
    # lazy=False (joined eager loading) rather than selectin.
    runners: List[SQARunner] = relationship(
        "SQARunner", cascade="all, delete-orphan", lazy=False
    )
    trials: List[SQATrial] = relationship(
        "SQATrial", cascade="all, delete-orphan", lazy="selectin"
    )
    generation_strategy: Optional[SQAGenerationStrategy] = relationship(
        "SQAGenerationStrategy",
        backref=backref("experiment", lazy=False),
        uselist=False,
        lazy="selectin",
    )

    immutable_fields = ["name"]
    ignore_during_update_fields = ["time_created"]