Source code for ax.storage.sqa_store.db

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from contextlib import contextmanager
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Generator, List, Optional, TypeVar

import numpy as np
from ax.exceptions.storage import ImmutabilityError
from ax.storage.sqa_store.utils import is_foreign_key_field
from ax.utils.common.equality import datetime_equals, equality_typechecker
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, scoped_session, sessionmaker


# some constants for database fields
HASH_FIELD_LENGTH: int = 32
NAME_OR_TYPE_FIELD_LENGTH: int = 100
LONG_STRING_FIELD_LENGTH: int = 255
JSON_FIELD_LENGTH: int = 4096

# By default, `Text` is mapped to a MySQL TEXT column, whose maximum length is
# 2^16 - 1 bytes. Since the MySQL DB also uses MEDIUMTEXT and LONGTEXT columns,
# use Text(MEDIUMTEXT_BYTES) or Text(LONGTEXT_BYTES) for those. This is
# preferable to using MEDIUMTEXT and LONGTEXT directly because those types are
# incompatible with SQLite, which is used in unit tests.
MEDIUMTEXT_BYTES: int = 2 ** 24 - 1
LONGTEXT_BYTES: int = 2 ** 32 - 1

# global database variables
Ax_PROD_TIER: str = "xdb.adaptive_experiment"
# Populated by `init_engine_and_session_factory` and
# `init_test_engine_and_session_factory`.
# NOTE: those functions store a `scoped_session` here, not a plain `Session`,
# despite the annotation.
SESSION_FACTORY: Optional[Session] = None

# Set this to False to prevent SQLAlchemy from automatically expiring objects
# on commit, which essentially makes them unusable outside of a session.
# It is passed as `expire_on_commit` to each `sessionmaker` built in this module.
# See e.g. https://stackoverflow.com/a/50272761
EXPIRE_ON_COMMIT: bool = False

# Generic element type for the list helpers on `SQABase`.
T = TypeVar("T")


class SQABase:
    """Base class for SQLAlchemy classes corresponding to core Ax classes.

    It is passed as the `cls` argument to `declarative_base` (see `Base`
    below), so every mapped class inherits the generic, field-by-field
    equality helpers (`equals`, `fields_equal`, `list_equals`) and in-place
    merge helpers (`update`, `update_field`, `list_update`) defined here.
    """

    @property
    def attributes(self) -> List[str]:
        """Return the names of the column attributes and relationship fields
        on this SQABase instance.

        Used for iterating over the fields to determine equality, perform
        updates, etc.
        """
        mapper = inspect(self).mapper
        attrs = [c.key for c in mapper.column_attrs]
        # Exclude back-populated relationships; those will be accounted for on
        # the owning class.
        relationships = [c.key for c in mapper.relationships if not c.back_populates]
        return attrs + relationships

    @staticmethod
    def list_equals(l1: List[T], l2: List[T]) -> bool:
        """Compare equality of two lists, ignoring order.

        Assumptions:
            -- The lists do not contain duplicates

        Checking equality is then the same as checking that the lists are the
        same length, and that one is a subset of the other.

        Items of different types never match (consistent with `fields_equal`),
        so e.g. `1`, `1.0` and `True` are not considered equal.

        Raises:
            ValueError: if two same-typed items are neither `SQABase`
                instances nor one of int/float/str/bool/dict/Enum.
        """
        if len(l1) != len(l2):
            return False
        for x in l1:
            for y in l2:
                if type(x) != type(y):
                    # Mismatched types can never be equal; skip the pair.
                    continue
                if isinstance(x, SQABase):
                    # pyre-fixme[6]: Expected `SQABase` for 1st param but got `T`.
                    equal = x.equals(y)
                elif isinstance(x, (int, float, str, bool, dict, Enum)):
                    # pyre-fixme[29]: `Union[Callable[[object], bool],
                    #  Callable[[object], bool], Callable[[object], bool],
                    #  Callable[[object], bool], Callable[[object], bool],
                    #  Callable[[object], bool]]` is not a function.
                    equal = x == y
                else:
                    raise ValueError(
                        f"Calling list_equals on unsupported types: "
                        f"{type(x)} and {type(y)}"
                    )
                if equal:
                    break
            else:
                # No item in `l2` matched `x`.
                return False
        return True

    @staticmethod
    def list_update(l1: List[T], l2: List[T]) -> List[T]:
        """Merge a new list (`l2`) into an existing list (`l1`).

        Given an existing list (`l1`) and a new version (`l2`):
            -- update the existing items in `l1` that have matching items in `l2`
            -- delete existing items in `l1` that don't have matching items in `l2`
            -- add items in `l2` that don't exist in `l1`

        e.g. list_update([1,2,3], [1,5]) => [1,5]
             list_update([Arm(name="0_0")], [Arm(name="0_0"), Arm(name="0_1")])
                 => [Arm(name="0_0"), Arm(name="0_1")]
             where Arm(name="0_0") has been updated, not replaced, so that we
             don't delete/recreate the DB row

        Raises:
            ValueError: if the lists mix item types, or the item type is not
                a primitive (int/float/str/bool/dict), an Enum, or SQABase.
        """
        if not l1 and not l2:
            return l1
        types = [type(x) for x in l1 + l2]
        if not all(x == types[0] for x in types):
            raise ValueError(
                "Cannot call `list_update` on lists that contain "
                "multiple different types."
            )
        type_ = types[0]
        # `issubclass` (not list membership) so concrete Enum subclasses are
        # recognized, matching the `isinstance` check in `list_equals`.
        if issubclass(type_, (int, float, str, bool, dict, Enum)):
            # No need to do a special update here; just return the new list
            return l2
        if not issubclass(type_, SQABase):
            raise ValueError(f"Calling list_update on unsupported type {type_}.")
        unique_id = getattr(type_, "unique_id", None)
        if unique_id is None:
            return SQABase.list_update_without_unique_id(l1, l2)
        l1_dict = {getattr(x, unique_id): x for x in l1}
        l2_dict = {getattr(x, unique_id): x for x in l2}
        if len(l1_dict) != len(l1) or len(l2_dict) != len(l2):
            # If unique_ids aren't actually unique (could happen if all values
            # are None), act as if there are no unique ids at all
            return SQABase.list_update_without_unique_id(l1, l2)  # pragma: no cover
        new_list = []
        for key, new_val in l2_dict.items():
            # For each item in the new list, try to find a match in the old list.
            if key in l1_dict:
                # If there is a matching item in the old list, update it in
                # place so the existing DB row is kept.
                old_val = l1_dict[key]
                # pyre-fixme[16]: `Variable[T]` has no attribute `update`.
                old_val.update(new_val)
                new_list.append(old_val)
            else:
                # If there is no matching item, append the new item.
                new_list.append(new_val)
        return new_list

    @staticmethod
    def list_update_without_unique_id(l1: List[T], l2: List[T]) -> List[T]:
        """Merge a new list (`l2`) into an existing list (`l1`).

        This method works for lists whose items do not have a unique_id field.
        If the lists are equal, return the old one. Else, return the new one.
        """
        if SQABase.list_equals(l1, l2):
            return l1
        return l2

    @equality_typechecker
    def equals(self, other: "SQABase") -> bool:
        """Check if `other` equals `self`, field by field.

        `id`, `_sa_instance_state` and foreign-key fields are skipped.
        """
        for field in self.attributes:
            if field in ["id", "_sa_instance_state"] or is_foreign_key_field(field):
                # We don't want to perform equality checks on foreign key fields,
                # since our equality checks are used to determine whether or not
                # a new object is the same as an existing one. The new object
                # will always have None for its foreign key fields, because it
                # hasn't been inserted into the database yet.
                continue
            if not self.fields_equal(other, field):
                return False
        return True

    def update(self, other: "SQABase") -> None:
        """Merge `other` into `self`, field by field.

        Fields listed in an optional `ignore_during_update_fields` class
        attribute (any iterable of names), `id`, `_sa_instance_state` and
        foreign-key fields are left untouched.

        Raises:
            ImmutabilityError: if a field listed in an optional
                `immutable_fields` class attribute differs between the objects.
        """
        ignore_during_update_fields = {
            "id",
            "_sa_instance_state",
            *getattr(self, "ignore_during_update_fields", []),
        }
        immutable_fields = set(getattr(self, "immutable_fields", []))
        for field in self.attributes:
            if field in immutable_fields:
                if self.fields_equal(other, field):
                    continue
                raise ImmutabilityError(
                    f"Cannot change `{field}` of {self.__class__.__name__}."
                )
            if (
                field in ignore_during_update_fields
                # We don't want to update foreign key fields, e.g. experiment_id.
                # The new object will always have a value of None for this field,
                # but we don't want to overwrite the value on the existing object.
                or is_foreign_key_field(field)
            ):
                continue
            self.update_field(other, field)

    def update_field(self, other: "SQABase", field: str) -> None:
        """Update `field` on `self` to be equal to `field` on `other`.

        Lists are merged with `list_update` and nested SQABase objects are
        updated in place, so existing DB rows are reused. Other values are
        only reassigned when they differ.
        """
        self_val = getattr(self, field)
        other_val = getattr(other, field)
        if isinstance(self_val, list) and isinstance(other_val, list):
            other_val = SQABase.list_update(self_val, other_val)
        elif isinstance(self_val, SQABase) and isinstance(other_val, SQABase):
            self_val.update(other_val)
            other_val = self_val
        elif self.fields_equal(other, field):
            return
        setattr(self, field, other_val)

    def fields_equal(self, other: "SQABase", field: str) -> bool:
        """Check if `field` on `self` is equal to `field` on `other`.

        Values of different types are unequal. Lists use `list_equals`,
        SQABase values use `equals`, datetimes use `datetime_equals`, floats
        use `np.isclose` (default tolerances), everything else uses `==`.
        """
        self_val = getattr(self, field)
        other_val = getattr(other, field)
        if type(self_val) != type(other_val):
            return False
        if isinstance(self_val, list):
            return SQABase.list_equals(self_val, other_val)
        elif isinstance(self_val, SQABase):
            return self_val.equals(other_val)
        elif isinstance(self_val, datetime):
            return datetime_equals(self_val, other_val)
        elif isinstance(self_val, float):
            # Convert `np.bool_` to a plain `bool` to honor the return type.
            return bool(np.isclose(self_val, other_val))
        else:
            return self_val == other_val


Base = declarative_base(cls=SQABase)
def create_mysql_engine_from_creator(
    creator: Callable, echo: bool = False, pool_recycle: int = 10, **kwargs: Any
) -> Engine:
    """Build a MySQL-dialect SQLAlchemy engine from a connection factory.

    Args:
        creator: a callable which returns a DBAPI connection; it is used
            instead of a full database URL to open connections.
        echo: if True, set engine to be verbose.
        pool_recycle: number of seconds after which to recycle connections.
            -1 means no timeout. Default is 10 seconds.
        **kwargs: extra keyword args forwarded unchanged to `create_engine`.

    Returns:
        Engine: SQLAlchemy engine with connection to MySQL DB.
    """
    # Only the dialect is specified here; connections come from `creator`.
    dialect_only_url = "mysql://"
    return create_engine(
        dialect_only_url,
        echo=echo,
        pool_recycle=pool_recycle,
        creator=creator,
        **kwargs,
    )
def create_mysql_engine_from_url(
    url: str, echo: bool = False, pool_recycle: int = 10, **kwargs: Any
) -> Engine:
    """Create a SQLAlchemy engine given a database url.

    Intended for MySQL, but the dialect is taken from `url` itself.

    Args:
        url: a database url that can include username, password, hostname,
            database name as well as optional keyword arguments for additional
            configuration, e.g. `dialect+driver://username:password@host:port/database`.
        echo: if True, set engine to be verbose.
        pool_recycle: number of seconds after which to recycle connections.
            -1 means no timeout. Default is 10 seconds.
        **kwargs: keyword args passed to `create_engine`.

    Returns:
        Engine: SQLAlchemy engine with a connection to the DB named by `url`.
    """
    # Pass the URL positionally: only older SQLAlchemy releases name this
    # parameter `name_or_url`; in 1.4+ it is `url`, so `name_or_url=` would
    # be rejected there.
    return create_engine(url, pool_recycle=pool_recycle, echo=echo, **kwargs)
def create_test_engine(path: Optional[str] = None, echo: bool = True) -> Engine:
    """Build a SQLite-backed SQLAlchemy engine for unit tests.

    Args:
        path: filesystem location for the SQLite DB file; when None, an
            in-memory SQLite database is used instead.
        echo: if True, set engine to be verbose.

    Returns:
        Engine: an instance of SQLAlchemy engine.
    """
    sqlite_url = "sqlite://" if path is None else f"sqlite:///{path}"
    return create_engine(sqlite_url, echo=echo)
def init_engine_and_session_factory(
    url: Optional[str] = None,
    creator: Optional[Callable] = None,
    echo: bool = False,
    force_init: bool = False,
    **kwargs: Any,
) -> None:
    """Set up the module-global engine and SESSION_FACTORY for SQLAlchemy.

    Only the first call has an effect; later calls return immediately unless
    `force_init` is True, in which case the previous engine is disposed and
    replaced. Re-initializing should only be done if you are absolutely
    certain that you know what you are doing.

    Args:
        url: a database url that can include username, password, hostname,
            database name as well as optional keyword arguments for additional
            configuration, e.g. `dialect+driver://username:password@host:port/database`.
            Either this argument or `creator` must be specified; if both are
            given, `url` is used.
        creator: a callable which returns a DBAPI connection. Either this
            argument or `url` must be specified.
        echo: if True, logging for engine is enabled.
        force_init: if True, allows re-initializing engine and session factory.
        **kwargs: keyword arguments passed to `create_mysql_engine_from_url`
            (when `url` is given) or `create_mysql_engine_from_creator`.

    Raises:
        ValueError: if neither `url` nor `creator` is given.
    """
    global SESSION_FACTORY
    if SESSION_FACTORY is not None:
        if not force_init:
            return  # pragma: no cover
        SESSION_FACTORY.bind.dispose()

    if url is None and creator is None:
        raise ValueError("Must specify either `url` or `creator`.")  # pragma: no cover

    if url is not None:
        engine = create_mysql_engine_from_url(url=url, echo=echo, **kwargs)
    else:
        engine = create_mysql_engine_from_creator(creator=creator, echo=echo, **kwargs)

    session_maker = sessionmaker(bind=engine, expire_on_commit=EXPIRE_ON_COMMIT)
    SESSION_FACTORY = scoped_session(session_maker)
def init_test_engine_and_session_factory(
    tier_or_path: Optional[str] = None,
    echo: bool = False,
    force_init: bool = False,
    **kwargs: Any,
) -> None:
    """Set up the module-global engine and SESSION_FACTORY against SQLite.

    All tables are created on the new engine. Only the first call has an
    effect; later calls return immediately unless `force_init` is True, in
    which case the previous engine is disposed and replaced. Re-initializing
    should only be done if you are absolutely certain that you know what you
    are doing.

    Args:
        tier_or_path: filesystem path of the SQLite DB to use; if None, an
            in-memory SQLite database is used.
        echo: if True, logging for engine is enabled.
        force_init: if True, allows re-initializing engine and session factory.
        **kwargs: accepted but not used by this function.
    """
    global SESSION_FACTORY
    if SESSION_FACTORY is not None:
        if not force_init:
            return
        SESSION_FACTORY.bind.dispose()

    test_engine = create_test_engine(path=tier_or_path, echo=echo)
    create_all_tables(test_engine)
    session_maker = sessionmaker(bind=test_engine, expire_on_commit=EXPIRE_ON_COMMIT)
    SESSION_FACTORY = scoped_session(session_maker)
def create_all_tables(engine: Engine) -> None:
    """Create every table whose mapped class inherits from `Base`.

    Args:
        engine: a SQLAlchemy engine with a connection to a MySQL or SQLite DB.

    Raises:
        Exception: if `engine` is a MySQL engine whose default schema is
            `adaptive_experiment`; tables there must not be mutated from here.

    Note:
        In order for all tables to be correctly created, all modules that
        define a mapped class that inherits from `Base` must be imported.
    """
    dialect = engine.dialect
    targets_xdb = (
        dialect.name == "mysql"
        and dialect.default_schema_name == "adaptive_experiment"
    )
    if targets_xdb:
        raise Exception("Cannot mutate tables in XDB. Use AOSC.")  # pragma: no cover
    Base.metadata.create_all(engine)
def get_session() -> Session:
    """Return a new SQLAlchemy session from the global SESSION_FACTORY.

    If the factory has not been set up yet, this calls
    `init_engine_and_session_factory()` with no arguments. As implemented in
    this module, that raises a ValueError because neither `url` nor `creator`
    is given. Call `init_engine_and_session_factory` (or
    `init_test_engine_and_session_factory`) first.

    Returns:
        Session: an instance of a SQLAlchemy session.
    """
    if SESSION_FACTORY is None:
        init_engine_and_session_factory()  # pragma: no cover
    factory = SESSION_FACTORY
    assert factory is not None
    # pyre-fixme[29]: `Session` is not a function.
    return factory()
def get_engine() -> Engine:
    """Return the engine bound to the global SESSION_FACTORY.

    The factory must already exist; create it with
    `init_engine_and_session_factory` or
    `init_test_engine_and_session_factory`.

    Returns:
        Engine: an instance of a SQLAlchemy engine with a connection to a DB.

    Raises:
        ValueError: if the session factory has not been initialized.
    """
    factory = SESSION_FACTORY
    if factory is None:
        raise ValueError("Engine must be initialized first.")  # pragma: no cover
    return factory.bind
@contextmanager
def session_scope() -> Generator[Session, None, None]:
    """Provide a transactional scope around a series of operations.

    The yielded session is committed when the `with` body finishes without
    error. It is rolled back (and the exception re-raised) if the body or
    the commit raises. It is always closed afterwards.
    """
    active_session = get_session()
    try:
        yield active_session
        # Commit inside the `try` so a failing commit is also rolled back.
        active_session.commit()
    except Exception:  # pragma: no cover
        active_session.rollback()  # pragma: no cover
        raise  # pragma: no cover
    finally:
        active_session.close()