Source code for ax.storage.sqa_store.db

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from collections.abc import Callable, Generator

from contextlib import contextmanager
from typing import Any, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, Session, sessionmaker


# Length constants for database fields.
HASH_FIELD_LENGTH: int = 32
NAME_OR_TYPE_FIELD_LENGTH: int = 100
LONG_STRING_FIELD_LENGTH: int = 255
JSON_FIELD_LENGTH: int = 4096

# By default, Text gets mapped to a TEXT field in MySQL, whose maximum size is
# 2^16 - 1 bytes. We also have MEDIUMTEXT and LONGTEXT fields in the MySQL db;
# for those, use Text(MEDIUMTEXT_BYTES) or Text(LONGTEXT_BYTES). This is
# preferable to using MEDIUMTEXT and LONGTEXT directly because those are
# incompatible with SQLite, which is used in unit tests.
MEDIUMTEXT_BYTES: int = 2**24 - 1
LONGTEXT_BYTES: int = 2**32 - 1

# Global database variables.
# NOTE(review): annotated as `Session | None`, but the init functions in this
# module assign a `scoped_session` here and `get_session` calls it to obtain a
# session -- confirm whether the annotation should be updated.
SESSION_FACTORY: Session | None = None

# Set this to False to prevent SQLAlchemy from automatically expiring objects
# on commit, which essentially makes them unusable outside of a session.
# See e.g. https://stackoverflow.com/a/50272761
EXPIRE_ON_COMMIT = False

# Generic type variable (not referenced by the definitions visible here).
T = TypeVar("T")


class SQABase:
    """Common base for SQLAlchemy classes corresponding to core Ax classes.

    This is a plain class (not a metaclass); it is handed to
    `declarative_base(cls=...)` so that every mapped class inherits the
    attributes below.
    """

    # Opt out of SQLAlchemy 2.0's requirement that annotated attributes on
    # declarative classes use `Mapped[...]`.
    __allow_unmapped__ = True

    # Let a table that is already present in the metadata be extended
    # instead of raising when it is defined again.
    __table_args__ = {"extend_existing": True}
# Declarative base built on top of `SQABase`. `create_all_tables` creates the
# tables registered on `Base.metadata`.
Base = declarative_base(cls=SQABase)
def create_mysql_engine_from_creator(
    creator: Callable,
    echo: bool = False,
    pool_recycle: int = 10,
    **kwargs: Any,
) -> Engine:
    """Build a MySQL-dialect SQLAlchemy engine from a connection creator.

    Args:
        creator: a callable which returns a DBAPI connection.
        echo: if True, set engine to be verbose.
        pool_recycle: number of seconds after which to recycle
            connections. -1 means no timeout. Default is 10 seconds.
        **kwargs: keyword args passed to `create_engine`

    Returns:
        Engine: SQLAlchemy engine with connection to MySQL DB.

    """
    # The URL only selects the dialect; actual connections come from `creator`.
    mysql_engine = create_engine(
        "mysql://",
        creator=creator,
        pool_recycle=pool_recycle,
        echo=echo,
        **kwargs,
    )
    return mysql_engine
def create_mysql_engine_from_url(
    url: str,
    echo: bool = False,
    pool_recycle: int = 10,
    **kwargs: Any,
) -> Engine:
    """Build a SQLAlchemy engine from a database url.

    Args:
        url: a database url that can include username, password, hostname,
            database name as well as optional keyword arguments for
            additional configuration.
            e.g. `dialect+driver://username:password@host:port/database`.
        echo: if True, set engine to be verbose.
        pool_recycle: number of seconds after which to recycle
            connections. -1 means no timeout. Default is 10 seconds.
        **kwargs: keyword args passed to `create_engine`

    Returns:
        Engine: SQLAlchemy engine with connection to MySQL DB.

    """
    url_engine = create_engine(
        url,
        pool_recycle=pool_recycle,
        echo=echo,
        **kwargs,
    )
    return url_engine
def create_test_engine(path: str | None = None, echo: bool = True) -> Engine:
    """Create a SQLite-backed SQLAlchemy engine for use in unit tests.

    Args:
        path: if None, use in-memory SQLite; else attempt to create a
            SQLite DB in the path provided.
        echo: if True, set engine to be verbose.

    Returns:
        Engine: an instance of SQLAlchemy engine.

    """
    # From SQLAlchemy docs:
    # "To use a SQLite :memory: database, specify an empty URL:
    # `engine = create_engine('sqlite://')`"
    # (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlite)
    sqlite_url = "sqlite://" if path is None else f"sqlite:///{path}"
    return create_engine(sqlite_url, echo=echo)
def init_engine_and_session_factory(
    url: str | None = None,
    creator: Callable | None = None,
    echo: bool = False,
    force_init: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the global engine and SESSION_FACTORY for SQLAlchemy.

    The initialization only needs to happen once: when SESSION_FACTORY is
    already set and `force_init` is False, this function returns without
    doing anything. Re-initializing with `force_init=True` disposes of the
    currently bound engine first, and should only be used if you are
    absolutely certain that you know what you are doing.

    Args:
        url: a database url that can include username, password, hostname,
            database name as well as optional keyword arguments for
            additional configuration.
            e.g. `dialect+driver://username:password@host:port/database`.
            Either this argument or `creator` argument must be specified.
            Takes precedence over `creator` when both are given.
        creator: a callable which returns a DBAPI connection.
            Either this argument or `url` argument must be specified.
        echo: if True, logging for engine is enabled.
        force_init: if True, allows re-initializing engine and session factory.
        **kwargs: keyword arguments passed to `create_mysql_engine_from_url`
            or `create_mysql_engine_from_creator`, whichever is used.

    Raises:
        ValueError: if neither `url` nor `creator` is specified.

    """
    global SESSION_FACTORY

    if SESSION_FACTORY is not None:
        if not force_init:
            # Already initialized and no re-initialization requested.
            return
        SESSION_FACTORY.bind.dispose()

    if url is not None:
        engine = create_mysql_engine_from_url(url=url, echo=echo, **kwargs)
    elif creator is not None:
        engine = create_mysql_engine_from_creator(
            creator=creator, echo=echo, **kwargs
        )
    else:
        raise ValueError("Must specify either `url` or `creator`.")

    session_maker = sessionmaker(bind=engine, expire_on_commit=EXPIRE_ON_COMMIT)
    SESSION_FACTORY = scoped_session(session_maker)
def init_test_engine_and_session_factory(
    tier_or_path: str | None = None,
    echo: bool = False,
    force_init: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the global engine and SESSION_FACTORY for SQLAlchemy,
    using a SQLite test database, and create all tables on it.

    The initialization only needs to happen once: when SESSION_FACTORY is
    already set and `force_init` is False, this function returns without
    doing anything. Re-initializing with `force_init=True` disposes of the
    currently bound engine first, and should only be used if you are
    absolutely certain that you know what you are doing.

    Args:
        tier_or_path: forwarded to `create_test_engine` as `path`; if None,
            an in-memory SQLite database is used.
        echo: if True, logging for engine is enabled.
        force_init: if True, allows re-initializing engine and session factory.
        **kwargs: accepted but not used by this function.

    """
    global SESSION_FACTORY

    if SESSION_FACTORY is not None:
        if not force_init:
            # Already initialized and no re-initialization requested.
            return
        SESSION_FACTORY.bind.dispose()

    test_engine = create_test_engine(path=tier_or_path, echo=echo)
    create_all_tables(test_engine)

    session_maker = sessionmaker(
        bind=test_engine, expire_on_commit=EXPIRE_ON_COMMIT
    )
    SESSION_FACTORY = scoped_session(session_maker)
def create_all_tables(engine: Engine) -> None:
    """Create all tables that inherit from Base.

    Args:
        engine: a SQLAlchemy engine with a connection to a MySQL or SQLite DB.

    Raises:
        ValueError: if the engine uses the MySQL dialect and its default
            schema is named "ax"; no tables are created in that case.

    Note:
        In order for all tables to be correctly created, all modules that
        define a mapped class that inherits from `Base` must be imported.

    """
    dialect = engine.dialect
    if dialect.name == "mysql" and dialect.default_schema_name == "ax":
        # Adjacent literals are joined with explicit spaces; the previous
        # `+` concatenation rendered as "...in this case,please contact...".
        raise ValueError(
            "The open-source Ax table creation is likely not applicable in this "
            "case, please contact the Adaptive Experimentation team if you need "
            "help."
        )
    Base.metadata.create_all(engine)
def get_session() -> Session:
    """Fetch a SQLAlchemy session with a connection to a DB.

    If SESSION_FACTORY has not been set up yet, this calls
    `init_engine_and_session_factory` with no arguments before producing
    the session.

    Returns:
        Session: an instance of a SQLAlchemy session.

    """
    if SESSION_FACTORY is None:
        init_engine_and_session_factory()

    # Read the module global again: the init call above rebinds it.
    session_factory = SESSION_FACTORY
    assert session_factory is not None
    # pyre-fixme[29]: `Session` is not a function.
    return session_factory()
def get_engine() -> Engine:
    """Fetch a SQLAlchemy engine, if already initialized.

    If not initialized, need to either call `init_engine_and_session_factory`
    or `get_session` explicitly.

    Returns:
        Engine: an instance of a SQLAlchemy engine with a connection to a DB.

    Raises:
        ValueError: if SESSION_FACTORY has not been initialized.

    """
    session_factory = SESSION_FACTORY
    if session_factory is None:
        raise ValueError("Engine must be initialized first.")
    # The engine is whatever the session factory is bound to.
    return session_factory.bind
@contextmanager
def session_scope() -> Generator[Session, None, None]:
    """Provide a transactional scope around a series of operations.

    Yields a session obtained from `get_session`. The session is committed
    when the `with` body finishes normally, rolled back (and the exception
    re-raised) when the body or the commit raises an `Exception`, and closed
    in every case.
    """
    db_session = get_session()
    try:
        yield db_session
        # Commit stays inside the `try` so a failed commit is rolled back too.
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
    finally:
        db_session.close()