# Source code for ax.storage.sqa_store.load

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.sqa_store.db import session_scope
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.sqa_classes import SQAExperiment
from ax.storage.sqa_store.sqa_config import SQAConfig


# ---------------------------- Loading `Experiment`. ---------------------------


def load_experiment(
    experiment_name: str, config: Optional[SQAConfig] = None
) -> Experiment:
    """Load an experiment from the database by its name.

    Args:
        experiment_name: Name of the stored experiment to load.
        config: SQA configuration used to build the ``Decoder``. When it is
            omitted (or falsy), a default ``SQAConfig()`` is used.

    Returns:
        The decoded ``Experiment``.

    Raises:
        ValueError: If no experiment with the given name is stored.
    """
    effective_config = SQAConfig() if not config else config
    return _load_experiment(
        experiment_name=experiment_name, decoder=Decoder(config=effective_config)
    )
def _load_experiment(experiment_name: str, decoder: Decoder) -> Experiment:
    """Load an experiment by name using a caller-supplied ``Decoder``.

    Steps:
        1) Fetch the SQLAlchemy experiment object from the DB.
        2) Decode it into the corresponding user-facing Ax object.
    """
    experiment_sqa = _get_experiment_sqa(
        experiment_name=experiment_name, decoder=decoder
    )
    # Decoding runs only after `_get_experiment_sqa` has left its session
    # scope, so the DB session is not held open during conversion (avoids
    # timeouts).
    return decoder.experiment_from_sqa(experiment_sqa=experiment_sqa)


def _get_experiment_sqa(experiment_name: str, decoder: Decoder) -> SQAExperiment:
    """Fetch the SQLAlchemy experiment object with the given name from the DB.

    Raises:
        ValueError: If no experiment with that name is stored.
    """
    experiment_sqa_class = decoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        found_experiment = (
            session.query(experiment_sqa_class)
            .filter_by(name=experiment_name)
            .one_or_none()
        )
    if found_experiment is None:
        raise ValueError(f"Experiment '{experiment_name}' not found.")
    return found_experiment  # pyre-ignore[7]


def _get_experiment_id(experiment_name: str, decoder: Decoder) -> Optional[int]:
    """Return the DB ID of the experiment with the given name.

    Returns ``None`` if no experiment with that name is stored.
    """
    experiment_sqa_class = decoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        id_row = (
            session.query(experiment_sqa_class.id)  # pyre-ignore
            .filter_by(name=experiment_name)
            .one_or_none()
        )
    # A single-column query yields a one-element row; unwrap it.
    return None if id_row is None else id_row[0]


# ------------------------ Loading `GenerationStrategy`. -----------------------
def load_generation_strategy_by_experiment_name(
    experiment_name: str, config: Optional[SQAConfig] = None
) -> GenerationStrategy:
    """Restore the generation strategy attached to the named experiment.

    Args:
        experiment_name: Name of the stored experiment whose generation
            strategy should be restored.
        config: SQA configuration used to build the ``Decoder``. When it is
            omitted (or falsy), a default ``SQAConfig()`` is used.

    Returns:
        The decoded ``GenerationStrategy``.

    Raises:
        ValueError: If no generation strategy is attached to that experiment.
    """
    gs_decoder = Decoder(config=config if config else SQAConfig())
    return _load_generation_strategy_by_experiment_name(
        experiment_name=experiment_name, decoder=gs_decoder
    )
def load_generation_strategy_by_id(
    gs_id: int, config: Optional[SQAConfig] = None
) -> GenerationStrategy:
    """Restore the generation strategy stored under the given DB ID.

    Args:
        gs_id: DB ID of the stored generation strategy.
        config: SQA configuration used to build the ``Decoder``. When it is
            omitted (or falsy), a default ``SQAConfig()`` is used.

    Returns:
        The decoded ``GenerationStrategy``.

    Raises:
        ValueError: If no generation strategy with that ID is stored.
    """
    gs_decoder = Decoder(config=config if config else SQAConfig())
    return _load_generation_strategy_by_id(gs_id=gs_id, decoder=gs_decoder)
def _load_generation_strategy_by_id(gs_id: int, decoder: Decoder) -> GenerationStrategy:
    """Restore the generation strategy stored under ``gs_id``.

    Raises:
        ValueError: If no generation strategy with that ID is stored.
    """
    gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy]
    with session_scope() as session:
        stored_gs = (
            session.query(gs_sqa_class).filter_by(id=gs_id).one_or_none()
        )
    if stored_gs is None:
        raise ValueError(f"Generation strategy with ID #{gs_id} not found.")
    return decoder.generation_strategy_from_sqa(gs_sqa=stored_gs)  # pyre-ignore[6]


def _load_generation_strategy_by_experiment_name(
    experiment_name: str, decoder: Decoder
) -> GenerationStrategy:
    """Restore the generation strategy attached to the named experiment.

    Steps:
        1) Fetch the SQLAlchemy generation strategy object from the DB by
           joining it against the experiment with the given name.
        2) Decode it into the corresponding Ax object.

    Raises:
        ValueError: If the query finds no generation strategy for that
            experiment name (including when no such experiment is stored).
    """
    sqa_classes = decoder.config.class_to_sqa_class
    exp_sqa_class = sqa_classes[Experiment]
    gs_sqa_class = sqa_classes[GenerationStrategy]
    with session_scope() as session:
        attached_gs = (
            session.query(gs_sqa_class)
            .join(exp_sqa_class.generation_strategy)  # pyre-ignore[16]
            # pyre-fixme[16]: `SQABase` has no attribute `name`.
            .filter(exp_sqa_class.name == experiment_name)
            .one_or_none()
        )
    if attached_gs is None:
        raise ValueError(
            f"Experiment {experiment_name} does not have a generation strategy "
            "attached to it."
        )
    # pyre-fixme[6]: Expected `SQAGenerationStrategy` for 1st param but got `SQABase`.
    return decoder.generation_strategy_from_sqa(gs_sqa=attached_gs)


def _get_generation_strategy_id(
    experiment_name: str, decoder: Decoder
) -> Optional[int]:
    """Return the DB ID of the generation strategy attached to the named
    experiment, or ``None`` if the query finds none.
    """
    sqa_classes = decoder.config.class_to_sqa_class
    exp_sqa_class = sqa_classes[Experiment]
    gs_sqa_class = sqa_classes[GenerationStrategy]
    with session_scope() as session:
        gs_id_row = (
            session.query(gs_sqa_class.id)  # pyre-ignore[16]
            .join(exp_sqa_class.generation_strategy)  # pyre-ignore[16]
            # pyre-fixme[16]: `SQABase` has no attribute `name`.
            .filter(exp_sqa_class.name == experiment_name)
            .one_or_none()
        )
    # A single-column query yields a one-element row; unwrap it.
    return None if gs_id_row is None else gs_id_row[0]