#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
from logging import INFO
from typing import List, Optional, Tuple, Type
from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import not_none
# Exception types on which DB writes are retried; left empty unless the
# SQLAlchemy-backed storage imports below all succeed.
RETRY_EXCEPTION_TYPES: Tuple[Type[Exception], ...] = ()
try:  # SQLAlchemy is optional, so storage imports are guarded.
    from ax.storage.sqa_store.db import init_engine_and_session_factory
    from ax.storage.sqa_store.load import (
        _get_experiment_id,
        _get_generation_strategy_id,
        _load_experiment,
        _load_generation_strategy_by_experiment_name,
    )
    from ax.storage.sqa_store.save import (
        _save_experiment,
        _save_generation_strategy,
        _save_new_trials,
        _update_generation_strategy,
        _update_trials,
    )
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm.exc import StaleDataError
    from ax.storage.sqa_store.structs import DBSettings
except ModuleNotFoundError:  # pragma: no cover
    # `None` marks storage as unavailable; `WithDBSettingsBase.__init__`
    # rejects any `db_settings` argument in that case.
    DBSettings = None
else:
    # Only reached when every import above succeeded: DB writes are retried
    # on these SQLAlchemy errors.
    RETRY_EXCEPTION_TYPES = (OperationalError, StaleDataError)
# Module-level logger shared by every `WithDBSettingsBase` instance; its level is
# (re)set in `WithDBSettingsBase.__init__` via `logging_level`.
logger = get_logger(__name__)
class WithDBSettingsBase:
    """Helper class providing methods for saving changes made to an experiment
    if `db_settings` property is set to a non-None value on the instance.

    All `_*_to_db_if_possible` / `_*_in_db_if_possible` helpers return `False`
    without touching the DB when no DB settings are set.
    """

    # DB settings of this instance; `None` disables storage functionality.
    _db_settings: Optional[DBSettings] = None

    def __init__(
        self, db_settings: Optional[DBSettings] = None, logging_level: int = INFO
    ) -> None:
        """Store `db_settings` and, if given, initialize the DB connection.

        Args:
            db_settings: Optional DB settings. When provided, must be an
                instance of `ax.storage.sqa_store.structs.DBSettings` (only
                importable with SQLAlchemy installed). Their `creator` and
                `url` are used to initialize the engine and session factory.
            logging_level: Level to set on this module's logger.

        Raises:
            ValueError: If `db_settings` is given but is not a `DBSettings`
                instance, including when `DBSettings` could not be imported.
        """
        if db_settings and (not DBSettings or not isinstance(db_settings, DBSettings)):
            raise ValueError(
                "`db_settings` argument should be of type ax.storage.sqa_store."
                "structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy "
                "installed in your environment (can be installed through pip)."
            )
        self._db_settings = db_settings
        if self.db_settings_set:
            init_engine_and_session_factory(
                creator=self.db_settings.creator, url=self.db_settings.url
            )
        # NOTE: This sets the level of the shared module-level logger, so it
        # affects all instances of this class, not only this one.
        logger.setLevel(logging_level)

    @property
    def db_settings_set(self) -> bool:
        """Whether non-None DB settings are set on this instance."""
        return self._db_settings is not None

    @property
    def db_settings(self) -> DBSettings:
        """DB settings set on this instance; guaranteed to be non-None.

        Raises:
            ValueError: If no DB settings are set on this instance.
        """
        if self._db_settings is None:
            raise ValueError("No DB settings are set on this instance.")
        return not_none(self._db_settings)

    def _get_experiment_and_generation_strategy_db_id(
        self, experiment_name: str
    ) -> Tuple[Optional[int], Optional[int]]:
        """Retrieve DB ids of experiment by the given name and the associated
        generation strategy. Each ID is None if corresponding object is not
        found.

        Args:
            experiment_name: Name of the experiment to look up.

        Returns:
            `(None, None)` if no DB settings are set or no experiment by that
            name is in the DB; otherwise the experiment's DB id and the id of
            its generation strategy (which may itself be `None`).
        """
        if not self.db_settings_set:
            return None, None
        exp_id = _get_experiment_id(
            experiment_name=experiment_name, decoder=self.db_settings.decoder
        )
        if not exp_id:
            return None, None
        gs_id = _get_generation_strategy_id(
            experiment_name=experiment_name, decoder=self.db_settings.decoder
        )
        return exp_id, gs_id

    def _maybe_save_experiment_and_generation_strategy(
        self, experiment: Experiment, generation_strategy: GenerationStrategy
    ) -> Tuple[bool, bool]:
        """If DB settings are set on this `WithDBSettingsBase` instance, checks
        whether given experiment and generation strategy are already saved and
        saves them, if not. The experiment is (re-)saved in full even if it is
        already in the DB.

        Args:
            experiment: Experiment to save; must have a name.
            generation_strategy: Generation strategy to save alongside it.

        Returns:
            Tuple of two booleans: whether experiment was saved in the course of
            this function's execution and whether generation strategy was
            saved.

        Raises:
            ValueError: If the experiment has no name.
            UnsupportedError: If the experiment is associated in the DB with a
                generation strategy other than `generation_strategy`.
        """
        saved_exp, saved_gs = False, False
        if not self.db_settings_set:
            return saved_exp, saved_gs

        if experiment._name is None:
            raise ValueError(
                "Experiment must specify a name to use storage functionality."
            )
        exp_name = not_none(experiment.name)
        exp_id, gs_id = self._get_experiment_and_generation_strategy_db_id(
            experiment_name=exp_name
        )

        if exp_id:  # Experiment in DB.
            # TODO: Switch to just updating experiment when selective-field
            # update is available.
            logger.info(f"Experiment {exp_name} is in DB, updating it.")
        else:  # Experiment not yet in DB.
            logger.info(f"Experiment {exp_name} is not yet in DB, storing it.")
        # Both cases currently perform the same full save.
        saved_exp = self._save_experiment_to_db_if_possible(experiment=experiment)

        if gs_id and generation_strategy._db_id != gs_id:
            raise UnsupportedError(
                "Experiment was associated with generation strategy in DB, "
                f"but a new generation strategy {generation_strategy.name} "
                "was provided. To use the generation strategy currently in DB,"
                " instantiate scheduler via: `Scheduler.with_stored_experiment`."
            )
        if not gs_id or generation_strategy._db_id is None:
            # No generation strategy is associated with the experiment in DB
            # yet (a mismatching one was rejected above), so store this one.
            logger.info(
                f"Generation strategy {generation_strategy.name} is not yet in DB, "
                "storing it."
            )
            saved_gs = self._save_generation_strategy_to_db_if_possible(
                generation_strategy=generation_strategy
            )
        return saved_exp, saved_gs

    def _load_experiment_and_generation_strategy(
        self, experiment_name: str
    ) -> Tuple[Optional[Experiment], Optional[GenerationStrategy]]:
        """Loads experiment and its corresponding generation strategy from database
        if DB settings are set on this `WithDBSettingsBase` instance.

        Args:
            experiment_name: Name of the experiment to load, used as unique
                identifier by which to find the experiment.

        Returns:
            - Tuple of `Experiment` and `None` if experiment exists but does not
              have a generation strategy attached to it.
            - Tuple of `Experiment` and `GenerationStrategy` if experiment exists
              and has a generation strategy attached to it.

            NOTE(review): this method itself never returns `None` for the
            experiment; a missing experiment surfaces as whatever
            `_load_experiment` raises (or as the `ValueError` below if it
            returns a non-`Experiment`). Confirm callers handle that.

        Raises:
            ValueError: If no DB settings are set, if the loaded object is not
                an `Experiment` (or is a simple experiment), or if loading the
                generation strategy fails with a `ValueError` other than the
                "does not have a generation strategy" one.
        """
        if not self.db_settings_set:
            raise ValueError("Cannot load from DB in absence of DB settings.")

        start_time = time.time()
        experiment = _load_experiment(experiment_name, decoder=self.db_settings.decoder)
        if not isinstance(experiment, Experiment) or experiment.is_simple_experiment:
            raise ValueError("Service API only supports `Experiment`.")
        logger.debug(
            f"Loaded experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds."
        )

        start_time = time.time()
        try:
            generation_strategy = _load_generation_strategy_by_experiment_name(
                experiment_name=experiment_name, decoder=self.db_settings.decoder
            )
        except ValueError as err:
            if "does not have a generation strategy" in str(err):
                return experiment, None
            raise  # `ValueError` here could signify more than just absence of GS.
        logger.debug(
            f"Loaded generation strategy for experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds."
        )
        return experiment, generation_strategy

    @retry_on_exception(
        retries=3,
        default_return_on_suppression=False,
        exception_types=RETRY_EXCEPTION_TYPES,
    )
    def _save_experiment_to_db_if_possible(
        self, experiment: Experiment, suppress_all_errors: bool = False
    ) -> bool:
        """Saves given experiment if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment to save in DB.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the experiment was saved.
        """
        if self.db_settings_set:
            start_time = time.time()
            _save_experiment(experiment, encoder=self.db_settings.encoder)
            logger.debug(
                f"Saved experiment {experiment.name} in "
                f"{_round_floats_for_logging(time.time() - start_time)} seconds."
            )
            return True
        return False

    def _save_new_trial_to_db_if_possible(
        self,
        experiment: Experiment,
        trial: BaseTrial,
        suppress_all_errors: bool = False,
    ) -> bool:
        """Saves new trial on given experiment if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save new trial in DB.
            trial: Newly added trial to save.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the trial was saved.
        """
        # Not decorated: the delegate already retries, and decorating this
        # wrapper too would nest the retry loops. Arguments are forwarded by
        # keyword so the delegate's decorator sees `suppress_all_errors` even
        # if it only inspects keyword arguments.
        return self._save_new_trials_to_db_if_possible(
            experiment=experiment,
            trials=[trial],
            suppress_all_errors=suppress_all_errors,
        )

    @retry_on_exception(
        retries=3,
        default_return_on_suppression=False,
        exception_types=RETRY_EXCEPTION_TYPES,
    )
    def _save_new_trials_to_db_if_possible(
        self,
        experiment: Experiment,
        trials: List[BaseTrial],
        suppress_all_errors: bool = False,
    ) -> bool:
        """Saves new trials on given experiment if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save new trials in DB.
            trials: Newly added trials to save.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the trials were saved.
        """
        if self.db_settings_set:
            start_time = time.time()
            _save_new_trials(
                experiment=experiment, trials=trials, encoder=self.db_settings.encoder
            )
            logger.debug(
                f"Saved trials {[trial.index for trial in trials]} in "
                f"{_round_floats_for_logging(time.time() - start_time)} seconds."
            )
            return True
        return False

    def _save_updated_trial_to_db_if_possible(
        self,
        experiment: Experiment,
        trial: BaseTrial,
        suppress_all_errors: bool = False,
    ) -> bool:
        """Saves updated trial on given experiment if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save updated trial in DB.
            trial: Newly updated trial to save.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the trial was saved.
        """
        # Not decorated: the delegate already retries, and decorating this
        # wrapper too would nest the retry loops. Arguments are forwarded by
        # keyword so the delegate's decorator sees `suppress_all_errors` even
        # if it only inspects keyword arguments.
        return self._save_updated_trials_to_db_if_possible(
            experiment=experiment,
            trials=[trial],
            suppress_all_errors=suppress_all_errors,
        )

    @retry_on_exception(
        retries=3,
        default_return_on_suppression=False,
        exception_types=RETRY_EXCEPTION_TYPES,
    )
    def _save_updated_trials_to_db_if_possible(
        self,
        experiment: Experiment,
        trials: List[BaseTrial],
        suppress_all_errors: bool = False,
    ) -> bool:
        """Saves updated trials on given experiment if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save updated trials in DB.
            trials: Newly updated trials to save.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the trials were saved.
        """
        if self.db_settings_set:
            start_time = time.time()
            _update_trials(
                experiment=experiment, trials=trials, encoder=self.db_settings.encoder
            )
            logger.debug(
                f"Updated trials {[trial.index for trial in trials]} in "
                f"{_round_floats_for_logging(time.time() - start_time)} seconds."
            )
            return True
        return False

    @retry_on_exception(
        retries=3,
        default_return_on_suppression=False,
        exception_types=RETRY_EXCEPTION_TYPES,
    )
    def _save_generation_strategy_to_db_if_possible(
        self, generation_strategy: GenerationStrategy, suppress_all_errors: bool = False
    ) -> bool:
        """Saves given generation strategy if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            generation_strategy: Generation strategy to save in DB.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the generation strategy was saved.
        """
        if self.db_settings_set:
            start_time = time.time()
            _save_generation_strategy(
                generation_strategy=generation_strategy,
                encoder=self.db_settings.encoder,
            )
            logger.debug(
                f"Saved generation strategy {generation_strategy.name} in "
                f"{_round_floats_for_logging(time.time() - start_time)} seconds."
            )
            return True
        return False

    @retry_on_exception(
        retries=3,
        default_return_on_suppression=False,
        exception_types=RETRY_EXCEPTION_TYPES,
    )
    def _update_generation_strategy_in_db_if_possible(
        self,
        generation_strategy: GenerationStrategy,
        new_generator_runs: List[GeneratorRun],
        suppress_all_errors: bool = False,
    ) -> bool:
        """Updates the given generation strategy with new generator runs (and with
        new current generation step if applicable) if DB settings are set
        on this `WithDBSettingsBase` instance.

        Args:
            generation_strategy: Generation strategy to update in DB.
            new_generator_runs: New generator runs of this generation strategy
                since its last save.
            suppress_all_errors: Flag for `retry_on_exception` that makes
                the decorator suppress the thrown exception even if it
                occurred in all the retries (exception is still logged).

        Returns:
            bool: Whether the generation strategy was updated.
        """
        if self.db_settings_set:
            start_time = time.time()
            _update_generation_strategy(
                generation_strategy=generation_strategy,
                generator_runs=new_generator_runs,
                encoder=self.db_settings.encoder,
            )
            logger.debug(
                f"Updated generation strategy {generation_strategy.name} in "
                f"{_round_floats_for_logging(time.time() - start_time)} seconds."
            )
            return True
        return False