Source code for ax.storage.sqa_store.db

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from contextlib import contextmanager, nullcontext
from os import remove as remove_file
from typing import Any, Callable, ContextManager, Generator, Optional, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, Session, sessionmaker


# some constants for database fields
HASH_FIELD_LENGTH: int = 32
NAME_OR_TYPE_FIELD_LENGTH: int = 100
LONG_STRING_FIELD_LENGTH: int = 255
JSON_FIELD_LENGTH: int = 4096

# by default, Text gets mapped to a TEXT field in MySQL is 2^16 - 1
# we use have MEDIUMTEXT and LONGTEXT in the MySQL db; in this case, use
# Text(MEDIUMTEXT_BYTES) or Text(LONGTEXT_BYTES). This is preferable to
# using MEDIUMTEXT and LONGTEXT directly because those are incompatible with
# SQLite that is used in unit tests.
MEDIUMTEXT_BYTES: int = 2**24 - 1
LONGTEXT_BYTES: int = 2**32 - 1

# global database variables
SESSION_FACTORY: Optional[Session] = None

# set this to false to prevent SQLAlchemy for automatically expiring objects
# on commit, which essentially makes them unusable outside of a session
# see e.g. https://stackoverflow.com/a/50272761
EXPIRE_ON_COMMIT = False

T = TypeVar("T")


[docs]class SQABase: """Metaclass for SQLAlchemy classes corresponding to core Ax classes.""" pass
Base = declarative_base(cls=SQABase)
[docs]def create_mysql_engine_from_creator( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. creator: Callable, echo: bool = False, pool_recycle: int = 10, **kwargs: Any, ) -> Engine: """Create a SQLAlchemy engine with the MySQL dialect given a creator function. Args: creator: a callable which returns a DBAPI connection. echo: if True, set engine to be verbose. pool_recycle: number of seconds after which to recycle connections. -1 means no timeout. Default is 10 seconds. **kwargs: keyword args passed to `create_engine` Returns: Engine: SQLAlchemy engine with connection to MySQL DB. """ return create_engine( "mysql://", creator=creator, pool_recycle=pool_recycle, echo=echo, **kwargs )
[docs]def create_mysql_engine_from_url( url: str, echo: bool = False, pool_recycle: int = 10, **kwargs: Any ) -> Engine: """Create a SQLAlchemy engine with the MySQL dialect given a database url. Args: url: a database url that can include username, password, hostname, database name as well as optional keyword arguments for additional configuration. e.g. `dialect+driver://username:password@host:port/database`. echo: if True, set engine to be verbose. pool_recycle: number of seconds after which to recycle connections. -1 means no timeout. Default is 10 seconds. **kwargs: keyword args passed to `create_engine` Returns: Engine: SQLAlchemy engine with connection to MySQL DB. """ return create_engine(url, pool_recycle=pool_recycle, echo=echo, **kwargs)
[docs]def create_test_engine(path: Optional[str] = None, echo: bool = True) -> Engine: """Creates a SQLAlchemy engine object for use in unit tests. Args: path: if None, use in-memory SQLite; else attempt to create a SQLite DB in the path provided. echo: if True, set engine to be verbose. Returns: Engine: an instance of SQLAlchemy engine. """ if path is None: # From SQLALchemy docs: # "To use a SQLite :memory: database, specify an empty URL: # `engine = create_engine('sqlite://')`" # (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlite) db_path = "sqlite://" else: db_path = "sqlite:///{path}".format(path=path) return create_engine(db_path, echo=echo)
[docs]def init_engine_and_session_factory( url: Optional[str] = None, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. creator: Optional[Callable] = None, echo: bool = False, force_init: bool = False, **kwargs: Any, ) -> None: """Initialize the global engine and SESSION_FACTORY for SQLAlchemy. The initialization needs to only happen once. Note that it is possible to re-initialize the engine by setting the `force_init` flag to True, but this should only be used if you are absolutely certain that you know what you are doing. Args: url: a database url that can include username, password, hostname, database name as well as optional keyword arguments for additional configuration. e.g. `dialect+driver://username:password@host:port/database`. Either this argument or `creator` argument must be specified. creator: a callable which returns a DBAPI connection. Either this argument or `url` argument must be specified. echo: if True, logging for engine is enabled. force_init: if True, allows re-initializing engine and session factory. **kwargs: keyword arguments passed to `create_mysql_engine_from_creator` """ global SESSION_FACTORY if SESSION_FACTORY is not None: if force_init: SESSION_FACTORY.bind.dispose() else: return if url is not None: engine = create_mysql_engine_from_url(url=url, echo=echo, **kwargs) elif creator is not None: engine = create_mysql_engine_from_creator(creator=creator, echo=echo, **kwargs) else: raise ValueError("Must specify either `url` or `creator`.") SESSION_FACTORY = scoped_session( sessionmaker(bind=engine, expire_on_commit=EXPIRE_ON_COMMIT) )
[docs]def init_test_engine_and_session_factory( tier_or_path: Optional[str] = None, echo: bool = False, force_init: bool = False, **kwargs: Any, ) -> None: """Initialize the global engine and SESSION_FACTORY for SQLAlchemy, using an in-memory SQLite database. The initialization needs to only happen once. Note that it is possible to re-initialize the engine by setting the `force_init` flag to True, but this should only be used if you are absolutely certain that you know what you are doing. Args: tier_or_path: the name of the DB tier. echo: if True, logging for engine is enabled. force_init: if True, allows re-initializing engine and session factory. **kwargs: keyword arguments passed to `create_mysql_engine_from_creator` """ global SESSION_FACTORY if SESSION_FACTORY is not None: if force_init: SESSION_FACTORY.bind.dispose() else: return engine = create_test_engine(path=tier_or_path, echo=echo) create_all_tables(engine) SESSION_FACTORY = scoped_session( sessionmaker(bind=engine, expire_on_commit=EXPIRE_ON_COMMIT) )
[docs]def remove_test_db_file(tier_or_path: str) -> None: """Remove the test DB file from system, useful for cleanup in tests.""" remove_file(tier_or_path)
[docs]def create_all_tables(engine: Engine) -> None: """Create all tables that inherit from Base. Args: engine: a SQLAlchemy engine with a connection to a MySQL or SQLite DB. Note: In order for all tables to be correctly created, all modules that define a mapped class that inherits from `Base` must be imported. """ if ( engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "adaptive_experiment" ): raise Exception("Cannot mutate tables in XDB. Use AOSC.") Base.metadata.create_all(engine)
[docs]def get_session() -> Session: """Fetch a SQLAlchemy session with a connection to a DB. Unless `init_engine_and_session_factory` is called first with custom args, this will automatically initialize a connection to `xdb.adaptive_experiment`. Returns: Session: an instance of a SQLAlchemy session. """ global SESSION_FACTORY if SESSION_FACTORY is None: init_engine_and_session_factory() assert SESSION_FACTORY is not None # pyre-fixme[29]: `Session` is not a function. return SESSION_FACTORY()
[docs]def get_engine() -> Engine: """Fetch a SQLAlchemy engine, if already initialized. If not initialized, need to either call `init_engine_and_session_factory` or `get_session` explicitly. Returns: Engine: an instance of a SQLAlchemy engine with a connection to a DB. """ global SESSION_FACTORY if SESSION_FACTORY is None: raise ValueError("Engine must be initialized first.") return SESSION_FACTORY.bind
[docs]@contextmanager def session_scope() -> Generator[Session, None, None]: """Provide a transactional scope around a series of operations.""" session = get_session() try: yield session session.commit() except Exception: session.rollback() raise finally: session.close()
[docs]def optional_session_scope( session: Optional[Session] = None, ) -> ContextManager[Session]: if session is not None: return nullcontext(session) return session_scope()