#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import Any, List
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import LifecycleStage
from ax.core.parameter import ParameterType
from ax.core.types import (
ComparisonOp,
TModelPredict,
TModelPredictArm,
TParameterization,
TParamValue,
)
from ax.storage.sqa_store.db import (
Base,
LONG_STRING_FIELD_LENGTH,
LONGTEXT_BYTES,
NAME_OR_TYPE_FIELD_LENGTH,
)
from ax.storage.sqa_store.json import (
JSONEncodedDict,
JSONEncodedList,
JSONEncodedLongTextDict,
JSONEncodedObject,
JSONEncodedTextDict,
)
from ax.storage.sqa_store.sqa_enum import IntEnum, StringEnum
from ax.storage.sqa_store.timestamp import IntTimestamp
from ax.storage.utils import DataType, DomainType, MetricIntent, ParameterConstraintType
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Float,
ForeignKey,
Integer,
String,
Text,
)
from sqlalchemy.orm import backref, relationship
ONLY_ONE_FIELDS = ["experiment_id", "generator_run_id"]
ONLY_ONE_METRIC_FIELDS = ["scalarized_objective_id", "scalarized_outcome_constraint_id"]
[docs]
class SQAParameter(Base):
__tablename__: str = "parameter_v2"
domain_type: Column[DomainType] = Column(IntEnum(DomainType), nullable=False)
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
id: Column[int] = Column(Integer, primary_key=True)
generator_run_id: Column[int | None] = Column(
Integer, ForeignKey("generator_run_v2.id")
)
name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
parameter_type: Column[ParameterType] = Column(
IntEnum(ParameterType), nullable=False
)
is_fidelity: Column[bool | None] = Column(Boolean)
target_value: Column[TParamValue | None] = Column(JSONEncodedObject)
# Attributes for Range Parameters
digits: Column[int | None] = Column(Integer)
log_scale: Column[bool | None] = Column(Boolean)
lower: Column[Decimal | None] = Column(Float)
upper: Column[Decimal | None] = Column(Float)
# Attributes for Choice Parameters
choice_values: Column[List[TParamValue] | None] = Column(JSONEncodedList)
is_ordered: Column[bool | None] = Column(Boolean)
is_task: Column[bool | None] = Column(Boolean)
dependents: Column[dict[TParamValue, List[str]] | None] = Column(JSONEncodedObject)
# Attributes for Fixed Parameters
fixed_value: Column[TParamValue | None] = Column(JSONEncodedObject)
[docs]
class SQAParameterConstraint(Base):
__tablename__: str = "parameter_constraint_v2"
bound: Column[Decimal] = Column(Float, nullable=False)
constraint_dict: Column[dict[str, float]] = Column(JSONEncodedDict, nullable=False)
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
id: Column[int] = Column(Integer, primary_key=True)
generator_run_id: Column[int | None] = Column(
Integer, ForeignKey("generator_run_v2.id")
)
type: Column[IntEnum] = Column(IntEnum(ParameterConstraintType), nullable=False)
[docs]
class SQAMetric(Base):
__tablename__: str = "metric_v2"
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
generator_run_id: Column[int | None] = Column(
Integer, ForeignKey("generator_run_v2.id")
)
id: Column[int] = Column(Integer, primary_key=True)
lower_is_better: Column[bool | None] = Column(Boolean)
intent: Column[MetricIntent] = Column(StringEnum(MetricIntent), nullable=False)
metric_type: Column[int] = Column(Integer, nullable=False)
name: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False)
properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={})
# Attributes for Objectives
minimize: Column[bool | None] = Column(Boolean)
# Attributes for Outcome Constraints
op: Column[ComparisonOp | None] = Column(IntEnum(ComparisonOp))
bound: Column[Decimal | None] = Column(Float)
relative: Column[bool | None] = Column(Boolean)
# Multi-type Experiment attributes
trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
canonical_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
scalarized_objective_id: Column[int | None] = Column(
Integer, ForeignKey("metric_v2.id")
)
# Relationship containing SQAMetric(s) only defined for the parent metric
# of Multi/Scalarized Objective contains all children of the parent metric
# join_depth argument: used for loading self-referential relationships
# https://docs.sqlalchemy.org/en/13/orm/self_referential.html#configuring-self-referential-eager-loading
scalarized_objective_children_metrics: List["SQAMetric"] = relationship(
"SQAMetric",
cascade="all, delete-orphan",
lazy=True,
foreign_keys=[scalarized_objective_id],
)
# Attribute only defined for the children of Scalarized Objective
scalarized_objective_weight: Column[Decimal | None] = Column(Float)
scalarized_outcome_constraint_id: Column[int | None] = Column(
Integer, ForeignKey("metric_v2.id")
)
scalarized_outcome_constraint_children_metrics: List["SQAMetric"] = relationship(
"SQAMetric",
cascade="all, delete-orphan",
lazy=True,
foreign_keys=[scalarized_outcome_constraint_id],
)
scalarized_outcome_constraint_weight: Column[Decimal | None] = Column(Float)
[docs]
class SQAArm(Base):
__tablename__: str = "arm_v2"
generator_run_id: Column[int] = Column(
Integer, ForeignKey("generator_run_v2.id"), nullable=False
)
id: Column[int] = Column(Integer, primary_key=True)
name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
parameters: Column[TParameterization] = Column(JSONEncodedTextDict, nullable=False)
weight: Column[Decimal] = Column(Float, nullable=False, default=1.0)
[docs]
class SQAAbandonedArm(Base):
__tablename__: str = "abandoned_arm_v2"
abandoned_reason: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH))
id: Column[int] = Column(Integer, primary_key=True)
name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
time_abandoned: Column[datetime] = Column(
IntTimestamp, nullable=False, default=datetime.now
)
trial_id: Column[int] = Column(Integer, ForeignKey("trial_v2.id"), nullable=False)
[docs]
class SQAGeneratorRun(Base):
__tablename__: str = "generator_run_v2"
best_arm_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
best_arm_parameters: Column[TParameterization | None] = Column(JSONEncodedTextDict)
best_arm_predictions: Column[TModelPredictArm | None] = Column(JSONEncodedList)
generator_run_type: Column[int | None] = Column(Integer)
id: Column[int] = Column(Integer, primary_key=True)
index: Column[int | None] = Column(Integer)
model_predictions: Column[TModelPredict | None] = Column(JSONEncodedList)
time_created: Column[datetime] = Column(
IntTimestamp, nullable=False, default=datetime.now
)
trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id"))
weight: Column[Decimal | None] = Column(Float)
fit_time: Column[Decimal | None] = Column(Float)
gen_time: Column[Decimal | None] = Column(Float)
model_key: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
model_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict)
bridge_kwargs: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict)
gen_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict)
model_state_after_gen: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict)
generation_strategy_id: Column[int | None] = Column(
Integer, ForeignKey("generation_strategy.id")
)
generation_step_index: Column[int | None] = Column(Integer)
candidate_metadata_by_arm_signature: Column[dict[str, Any] | None] = Column(
JSONEncodedTextDict
)
generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
# relationships
# Use selectin loading for collections to prevent idle timeout errors
# (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
arms: List[SQAArm] = relationship(
"SQAArm",
cascade="all, delete-orphan",
lazy="selectin",
order_by=lambda: SQAArm.id,
)
metrics: List[SQAMetric] = relationship(
"SQAMetric", cascade="all, delete-orphan", lazy="selectin"
)
parameters: List[SQAParameter] = relationship(
"SQAParameter", cascade="all, delete-orphan", lazy="selectin"
)
parameter_constraints: List[SQAParameterConstraint] = relationship(
"SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin"
)
[docs]
class SQARunner(Base):
__tablename__: str = "runner"
id: Column[int] = Column(Integer, primary_key=True)
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
properties: Column[dict[str, Any] | None] = Column(
JSONEncodedLongTextDict, default={}
)
runner_type: Column[int] = Column(Integer, nullable=False)
trial_id: Column[int | None] = Column(Integer, ForeignKey("trial_v2.id"))
# Multi-type Experiment attributes
trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
[docs]
class SQAData(Base):
__tablename__: str = "data_v2"
id: Column[int] = Column(Integer, primary_key=True)
data_json: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False)
description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH))
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
time_created: Column[int] = Column(BigInteger, nullable=False)
trial_index: Column[int | None] = Column(Integer)
generation_strategy_id: Column[int | None] = Column(
Integer, ForeignKey("generation_strategy.id")
)
structure_metadata_json: Column[str | None] = Column(
Text(LONGTEXT_BYTES), nullable=True
)
[docs]
class SQAGenerationStrategy(Base):
__tablename__: str = "generation_strategy"
id: Column[int] = Column(Integer, primary_key=True)
name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
steps: Column[List[dict[str, Any]]] = Column(JSONEncodedList, nullable=False)
curr_index: Column[int | None] = Column(Integer, nullable=True)
experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id"))
nodes: Column[List[dict[str, Any]]] = Column(JSONEncodedList, nullable=True)
curr_node_name: Column[str | None] = Column(
String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True
)
generator_runs: List[SQAGeneratorRun] = relationship(
"SQAGeneratorRun",
cascade="all, delete-orphan",
lazy="selectin",
order_by=lambda: SQAGeneratorRun.id,
)
[docs]
class SQATrial(Base):
__tablename__: str = "trial_v2"
abandoned_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
failed_reason: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
deployed_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
experiment_id: Column[int] = Column(
Integer, ForeignKey("experiment_v2.id"), nullable=False
)
id: Column[int] = Column(Integer, primary_key=True)
index: Column[int] = Column(Integer, index=True, nullable=False)
is_batch: Column[bool] = Column("is_batched", Boolean, nullable=False, default=True)
lifecycle_stage: Column[LifecycleStage | None] = Column(
IntEnum(LifecycleStage), nullable=True
)
num_arms_created: Column[int] = Column(Integer, nullable=False, default=0)
optimize_for_power: Column[bool | None] = Column(Boolean)
ttl_seconds: Column[int | None] = Column(Integer)
run_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedLongTextDict)
stop_metadata: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict)
status: Column[TrialStatus] = Column(
IntEnum(TrialStatus), nullable=False, default=TrialStatus.CANDIDATE
)
status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
time_completed: Column[datetime | None] = Column(IntTimestamp)
time_created: Column[datetime] = Column(IntTimestamp, nullable=False)
time_staged: Column[datetime | None] = Column(IntTimestamp)
time_run_started: Column[datetime | None] = Column(IntTimestamp)
trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
generation_step_index: Column[int | None] = Column(Integer)
properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={})
# relationships
# Trials and experiments are mutable, so the children relationships need
# cascade="all, delete-orphan", which means if we remove or replace
# a child, the old one will be deleted.
# Use selectin loading for collections to prevent idle timeout errors
# (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
abandoned_arms: List[SQAAbandonedArm] = relationship(
"SQAAbandonedArm", cascade="all, delete-orphan", lazy="selectin"
)
generator_runs: List[SQAGeneratorRun] = relationship(
"SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin"
)
runner: SQARunner = relationship(
"SQARunner", uselist=False, cascade="all, delete-orphan", lazy=False
)
[docs]
class SQAAnalysisCard(Base):
__tablename__: str = "analysis_card"
id: Column[int] = Column(Integer, primary_key=True)
name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
title: Column[str] = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False)
subtitle: Column[str] = Column(Text, nullable=False)
level: Column[int] = Column(Integer, nullable=False)
dataframe_json: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False)
blob: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False)
blob_annotation: Column[str] = Column(
String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False
)
time_created: Column[datetime] = Column(IntTimestamp, nullable=False)
experiment_id: Column[int] = Column(
Integer, ForeignKey("experiment_v2.id"), nullable=False
)
attributes: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False)
[docs]
class SQAExperiment(Base):
__tablename__: str = "experiment_v2"
description: Column[str | None] = Column(String(LONG_STRING_FIELD_LENGTH))
experiment_type: Column[int | None] = Column(Integer)
id: Column[int] = Column(Integer, primary_key=True)
is_test: Column[bool] = Column(Boolean, nullable=False, default=False)
name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
properties: Column[dict[str, Any] | None] = Column(JSONEncodedTextDict, default={})
status_quo_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
status_quo_parameters: Column[TParameterization | None] = Column(
JSONEncodedTextDict
)
time_created: Column[datetime] = Column(IntTimestamp, nullable=False)
default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True)
# pyre-fixme[8]: Incompatible attribute type [8]: Attribute
# `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has
# type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]`
auxiliary_experiments_by_purpose: dict[str, List[str]] | None = Column(
JSONEncodedTextDict, nullable=True, default={}
)
# relationships
# Trials and experiments are mutable, so the children relationships need
# cascade="all, delete-orphan", which means if we remove or replace
# a child, the old one will be deleted.
# Use selectin loading for collections to prevent idle timeout errors
# (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading)
data: List[SQAData] = relationship(
"SQAData", cascade="all, delete-orphan", lazy="selectin"
)
metrics: List[SQAMetric] = relationship(
"SQAMetric", cascade="all, delete-orphan", lazy="selectin"
)
parameters: List[SQAParameter] = relationship(
"SQAParameter", cascade="all, delete-orphan", lazy="selectin"
)
parameter_constraints: List[SQAParameterConstraint] = relationship(
"SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin"
)
runners: List[SQARunner] = relationship(
"SQARunner", cascade="all, delete-orphan", lazy=False
)
trials: List[SQATrial] = relationship(
"SQATrial", cascade="all, delete-orphan", lazy="selectin"
)
generation_strategy: SQAGenerationStrategy | None = relationship(
"SQAGenerationStrategy",
backref=backref("experiment", lazy=True),
uselist=False,
lazy=True,
)
analysis_cards: List[SQAAnalysisCard] = relationship(
"SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin"
)