# Source code for ax.storage.sqa_store.validation

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Any, List

from ax.storage.sqa_store.db import SQABase
from ax.storage.sqa_store.sqa_classes import (
    ONLY_ONE_FIELDS,
    SQAMetric,
    SQAParameter,
    SQAParameterConstraint,
    SQARunner,
)
from sqlalchemy import event
from sqlalchemy.engine import Connection
from sqlalchemy.orm.mapper import Mapper


def consistency_exactly_one(
    instance: SQABase, exactly_one_fields: List[str]
) -> None:
    """Ensure that exactly one of `exactly_one_fields` has a value set.

    A field counts as "set" when ``getattr(instance, field)`` is not ``None``.

    Args:
        instance: The SQA object whose attributes are checked.
        exactly_one_fields: Attribute names on ``instance``; exactly one of
            them must be non-``None``.

    Raises:
        ValueError: If none, or more than one, of the fields is set.
    """
    # Booleans sum as 0/1, so this counts how many of the fields are populated.
    num_set = sum(getattr(instance, field) is not None for field in exactly_one_fields)
    if num_set != 1:
        raise ValueError(
            f"{instance.__class__.__name__} must have exactly one of the following "
            f"fields set: {', '.join(exactly_one_fields)}."
        )
@event.listens_for(SQAParameter, "before_insert")
@event.listens_for(SQAParameter, "before_update")
def validate_parameter(
    mapper: Mapper,
    connection: Connection,
    target: SQABase,
) -> None:
    """Listener registered for SQAParameter "before_insert"/"before_update".

    Delegates to ``consistency_exactly_one``, which raises ``ValueError``
    unless exactly one of ``ONLY_ONE_FIELDS`` is non-``None`` on ``target``.
    ``mapper`` and ``connection`` are required by the listener signature but
    are not used here.
    """
    consistency_exactly_one(instance=target, exactly_one_fields=ONLY_ONE_FIELDS)
@event.listens_for(SQAParameterConstraint, "before_insert")
@event.listens_for(SQAParameterConstraint, "before_update")
def validate_parameter_constraint(
    mapper: Mapper,
    connection: Connection,
    target: SQABase,
) -> None:
    """Listener registered for SQAParameterConstraint "before_insert"/"before_update".

    Delegates to ``consistency_exactly_one``, which raises ``ValueError``
    unless exactly one of ``ONLY_ONE_FIELDS`` is non-``None`` on ``target``.
    ``mapper`` and ``connection`` are required by the listener signature but
    are not used here.
    """
    consistency_exactly_one(instance=target, exactly_one_fields=ONLY_ONE_FIELDS)
@event.listens_for(SQAMetric, "before_insert")
@event.listens_for(SQAMetric, "before_update")
def validate_metric(
    mapper: Mapper,
    connection: Connection,
    target: SQABase,
) -> None:
    """Listener registered for SQAMetric "before_insert"/"before_update".

    Delegates to ``consistency_exactly_one``, which raises ``ValueError``
    unless exactly one of ``ONLY_ONE_FIELDS`` is non-``None`` on ``target``.
    ``mapper`` and ``connection`` are required by the listener signature but
    are not used here.
    """
    consistency_exactly_one(instance=target, exactly_one_fields=ONLY_ONE_FIELDS)
@event.listens_for(SQARunner, "before_insert")
@event.listens_for(SQARunner, "before_update")
def validate_runner(
    mapper: Mapper,
    connection: Connection,
    target: SQABase,
) -> None:
    """Listener registered for SQARunner "before_insert"/"before_update".

    Unlike the other validators in this module, runners are checked against
    ``experiment_id``/``trial_id`` rather than ``ONLY_ONE_FIELDS``.
    ``consistency_exactly_one`` raises ``ValueError`` unless exactly one of
    them is non-``None`` on ``target``. ``mapper`` and ``connection`` are
    required by the listener signature but are not used here.
    """
    consistency_exactly_one(
        instance=target, exactly_one_fields=["experiment_id", "trial_id"]
    )