Source code for ax.storage.sqa_store.load
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.sqa_store.db import session_scope
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.sqa_classes import SQAExperiment, SQAGenerationStrategy
from ax.storage.sqa_store.sqa_config import SQAConfig
# ---------------------------- Loading `Experiment`. ---------------------------
def load_experiment(
    experiment_name: str, config: Optional[SQAConfig] = None
) -> Experiment:
    """Load an experiment by name.

    Args:
        experiment_name: Name of the experiment to load.
        config: SQA configuration used to build the ``Decoder``. When
            ``None``, a default ``SQAConfig()`` is used.

    Returns:
        The ``Experiment`` decoded from its stored SQLAlchemy representation.

    Raises:
        ValueError: If no experiment named ``experiment_name`` is stored
            (raised by the underlying lookup).
    """
    # Explicit None check instead of `config or ...`: a caller-supplied config
    # must never be swapped for the default just because it is falsy.
    if config is None:
        config = SQAConfig()
    decoder = Decoder(config=config)
    return _load_experiment(experiment_name=experiment_name, decoder=decoder)
def _load_experiment(experiment_name: str, decoder: Decoder) -> Experiment:
    """Load an experiment by name with the supplied ``Decoder``.

    The SQLAlchemy object is fetched from the DB first, then converted
    into the corresponding Ax ``Experiment``.
    """
    stored_experiment = _get_experiment_sqa(experiment_name=experiment_name)
    # Decoding runs only after the lookup has returned, i.e. outside of the
    # session scope, to avoid timeouts.
    return decoder.experiment_from_sqa(experiment_sqa=stored_experiment)
def _get_experiment_sqa(experiment_name: str) -> SQAExperiment:
    """Fetch the SQLAlchemy experiment object named ``experiment_name``.

    Raises:
        ValueError: If no stored experiment has that name.
    """
    with session_scope() as session:
        by_name = session.query(SQAExperiment).filter_by(name=experiment_name)
        # `one_or_none` returns None for zero matches and raises for more
        # than one match.
        sqa_experiment = by_name.one_or_none()
        if sqa_experiment is None:
            raise ValueError(f"Experiment `{experiment_name}` not found.")
    return sqa_experiment
# ------------------------ Loading `GenerationStrategy`. -----------------------
def load_generation_strategy_by_experiment_name(
    experiment_name: str, config: Optional[SQAConfig] = None
) -> GenerationStrategy:
    """Find the generation strategy attached to a named experiment and
    restore it from its corresponding SQA object.

    Args:
        experiment_name: Name of the experiment whose generation strategy
            should be loaded.
        config: SQA configuration used to build the ``Decoder``. When
            ``None``, a default ``SQAConfig()`` is used.

    Returns:
        The decoded ``GenerationStrategy``.

    Raises:
        ValueError: If no generation strategy is attached to the experiment
            (raised by the underlying lookup).
    """
    # Explicit None check instead of `config or ...`: a caller-supplied config
    # must never be swapped for the default just because it is falsy.
    if config is None:
        config = SQAConfig()
    decoder = Decoder(config=config)
    return _load_generation_strategy_by_experiment_name(
        experiment_name=experiment_name, decoder=decoder
    )
def load_generation_strategy_by_id(
    gs_id: int, config: Optional[SQAConfig] = None
) -> GenerationStrategy:
    """Find a generation strategy stored under a given ID and restore it.

    Args:
        gs_id: Database ID of the stored generation strategy.
        config: SQA configuration used to build the ``Decoder``. When
            ``None``, a default ``SQAConfig()`` is used.

    Returns:
        The decoded ``GenerationStrategy``.

    Raises:
        ValueError: If no generation strategy with ID ``gs_id`` is stored
            (raised by the underlying lookup).
    """
    # Explicit None check instead of `config or ...`: a caller-supplied config
    # must never be swapped for the default just because it is falsy.
    if config is None:
        config = SQAConfig()
    decoder = Decoder(config=config)
    return _load_generation_strategy_by_id(gs_id=gs_id, decoder=decoder)
def _load_generation_strategy_by_id(gs_id: int, decoder: Decoder) -> GenerationStrategy:
    """Restore the generation strategy stored under ``gs_id`` using ``decoder``.

    Raises:
        ValueError: If no generation strategy with that ID is stored.
    """
    with session_scope() as session:
        by_id = session.query(SQAGenerationStrategy).filter_by(id=gs_id)
        gs_sqa = by_id.one_or_none()
        if gs_sqa is not None:
            return decoder.generation_strategy_from_sqa(gs_sqa=gs_sqa)
        raise ValueError(f"Generation strategy with ID #{gs_id} not found.")
def _load_generation_strategy_by_experiment_name(
    experiment_name: str, decoder: Decoder
) -> GenerationStrategy:
    """Restore the generation strategy attached to a named experiment.

    The SQLAlchemy generation-strategy object is located by joining through
    ``SQAExperiment.generation_strategy`` and filtering on the experiment
    name. It is then converted into an Ax ``GenerationStrategy`` with the
    supplied ``decoder``.

    Raises:
        ValueError: If no generation strategy is found for the experiment.
    """
    with session_scope() as session:
        attached_gs = (
            session.query(SQAGenerationStrategy)
            .join(SQAExperiment.generation_strategy)
            .filter(SQAExperiment.name == experiment_name)
        )
        gs_sqa = attached_gs.one_or_none()
        if gs_sqa is not None:
            return decoder.generation_strategy_from_sqa(gs_sqa=gs_sqa)
        raise ValueError(
            f"Experiment {experiment_name} does not have a generation strategy "
            "attached to it."
        )