Source code for ax.storage.sqa_store.validation

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Callable
from logging import Logger
from typing import Any, TypeVar

from ax.storage.sqa_store.db import SQABase
from ax.storage.sqa_store.reduced_state import GR_LARGE_MODEL_ATTRS
from ax.storage.sqa_store.sqa_classes import (
    ONLY_ONE_FIELDS,
    ONLY_ONE_METRIC_FIELDS,
    SQAMetric,
    SQAParameter,
    SQAParameterConstraint,
    SQARunner,
)
from ax.utils.common.logger import get_logger
from sqlalchemy import event
from sqlalchemy.engine import Connection
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.mapper import Mapper

T = TypeVar("T")


logger: Logger = get_logger(__name__)


def listens_for_multiple(
    targets: list[InstrumentedAttribute],
    identifier: str,
    *args: Any,
    **kwargs: Any,
) -> Callable:
    """Attach one handler to the same SQLAlchemy event on several attributes.

    Works like SQLAlchemy's ``event.listens_for`` decorator, except that the
    decorated function is registered via ``event.listen`` on every attribute in
    ``targets`` rather than on a single target.

    Args:
        targets: Instrumented attributes to register the handler on.
        identifier: Name of the event to listen for (e.g. ``"set"``).
        *args: Extra positional arguments forwarded to ``event.listen``.
        **kwargs: Extra keyword arguments forwarded to ``event.listen``
            (e.g. ``retval``, ``propagate``).

    Returns:
        A decorator that registers the function and returns it unchanged.
    """

    def register(handler: Callable) -> Callable:
        # Registration happens once, at decoration time, in `targets` order.
        for attribute in targets:
            event.listen(attribute, identifier, handler, *args, **kwargs)
        return handler

    return register
def consistency_exactly_one(
    instance: SQABase, exactly_one_fields: list[str]
) -> None:
    """Ensure that exactly one of ``exactly_one_fields`` has a value set.

    A field counts as set when ``getattr(instance, field)`` is not ``None``.

    Args:
        instance: The SQA object to validate.
        exactly_one_fields: Names of attributes on ``instance``; exactly one of
            them must be non-``None``.

    Raises:
        ValueError: If none, or more than one, of the fields are set.
    """
    # Booleans sum as 0/1, so this counts how many fields hold a value.
    num_set = sum(
        getattr(instance, field) is not None for field in exactly_one_fields
    )
    if num_set != 1:
        raise ValueError(
            f"{instance.__class__.__name__} must have exactly one of the following "
            f"fields set: {', '.join(exactly_one_fields)}."
        )
@listens_for_multiple(
    targets=GR_LARGE_MODEL_ATTRS,
    identifier="set",
    # `retval=True` instructs the operation ('set' on attributes in `targets`) to
    # use the return value of the decorated function to set the attribute.
    retval=True,
    # `propagate=True` ensures that targets with subclasses of SQA classes used by
    # default Ax OSS encoder inherit the event listeners.
    propagate=True,
)
def do_not_set_existing_value_to_null(
    instance: SQABase, new_value: T, old_value: T, initiator_event: event.Events
) -> T:
    """Keep an attribute's existing value instead of overwriting it with nothing.

    Registered (with ``retval=True``) as a "set" listener on every attribute in
    ``GR_LARGE_MODEL_ATTRS``, so the value returned here is the one assigned.

    Args:
        instance: Object whose attribute is being set (unused).
        new_value: Value being assigned.
        old_value: Value the attribute currently holds; may be the SQLAlchemy
            ``NO_VALUE`` sentinel when there is no prior value.
        initiator_event: SQLAlchemy event token for this set operation (unused).

    Returns:
        ``old_value`` when ``new_value`` is ``None``/``NO_VALUE`` but
        ``old_value`` is not; otherwise ``new_value``.
    """
    # `None` and `NO_VALUE` are sentinels, so compare by identity. `x in [...]`
    # would fall back to `==`, which is ambiguous (raises) for array-like values.
    new_is_unset = new_value is None or new_value is NO_VALUE
    old_is_unset = old_value is None or old_value is NO_VALUE
    if new_is_unset and not old_is_unset:
        # Lazy %-formatting: `old_value` may be a large model payload, so only
        # stringify it when DEBUG logging is actually enabled.
        logger.debug(
            "New value for attribute is `None` or has no value, but old value "
            "was set, so keeping the old value (%s).",
            old_value,
        )
        return old_value
    return new_value
@event.listens_for(SQAParameter, "before_insert")
@event.listens_for(SQAParameter, "before_update")
# pyre-fixme[11]: Annotation `Mapper` is not defined as a type.
def validate_parameter(mapper: Mapper, connection: Connection, target: SQABase) -> None:
    """Before a parameter row is inserted or updated, require exactly one of
    the ``ONLY_ONE_FIELDS`` attributes to be set (raises ``ValueError``)."""
    consistency_exactly_one(instance=target, exactly_one_fields=ONLY_ONE_FIELDS)


@event.listens_for(SQAParameterConstraint, "before_insert")
@event.listens_for(SQAParameterConstraint, "before_update")
def validate_parameter_constraint(
    mapper: Mapper, connection: Connection, target: SQABase
) -> None:
    """Before a parameter-constraint row is inserted or updated, require exactly
    one of the ``ONLY_ONE_FIELDS`` attributes to be set (raises ``ValueError``)."""
    consistency_exactly_one(instance=target, exactly_one_fields=ONLY_ONE_FIELDS)


@event.listens_for(SQAMetric, "before_insert")
@event.listens_for(SQAMetric, "before_update")
def validate_metric(mapper: Mapper, connection: Connection, target: SQABase) -> None:
    """Before a metric row is inserted or updated, require exactly one of the
    ``ONLY_ONE_FIELDS`` + ``ONLY_ONE_METRIC_FIELDS`` attributes to be set
    (raises ``ValueError``)."""
    metric_fields = ONLY_ONE_FIELDS + ONLY_ONE_METRIC_FIELDS
    consistency_exactly_one(instance=target, exactly_one_fields=metric_fields)


@event.listens_for(SQARunner, "before_insert")
@event.listens_for(SQARunner, "before_update")
def validate_runner(mapper: Mapper, connection: Connection, target: SQABase) -> None:
    """Before a runner row is inserted or updated, require exactly one of
    ``experiment_id`` / ``trial_id`` to be set (raises ``ValueError``)."""
    runner_fields = ["experiment_id", "trial_id"]
    consistency_exactly_one(instance=target, exactly_one_fields=runner_fields)