Source code for ax.storage.sqa_store.sqa_enum
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import enum
from typing import Any
from ax.storage.sqa_store.db import NAME_OR_TYPE_FIELD_LENGTH
from sqlalchemy import types
[docs]
class BaseNullableEnum(types.TypeDecorator):
    """SQLAlchemy column type that persists Python enum members by value.

    On write, an enum member is converted to the value of the same-named
    member of the enum class given to the constructor; on read, a stored value
    is mapped back to its member. ``None`` passes through unchanged in both
    directions. Subclasses pick the storage type by setting ``impl``.
    """

    # Declares this type safe for SQLAlchemy's compiled-statement cache. The
    # cache key is derived from ``__init__`` arguments that are stored on the
    # instance under the same name, which is why ``self.enum`` is kept below.
    cache_ok = True

    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def __init__(self, enum: Any, *arg: Any, **kw: Any) -> None:
        """Bind this column type to an enum class.

        Args:
            enum: The enum class whose members this column stores.
            *arg: Positional arguments forwarded to ``TypeDecorator.__init__``.
            **kw: Keyword arguments forwarded to ``TypeDecorator.__init__``.
        """
        super().__init__(*arg, **kw)
        # Stored under the constructor argument's name so SQLAlchemy includes
        # the enum class in this type's cache key (and in its repr); otherwise
        # instances bound to different enums would share a cache key.
        # pyre-fixme[4]: Attribute must be annotated.
        self.enum = enum
        # Member name -> member; used to validate and convert on write.
        # pyre-fixme[4]: Attribute must be annotated.
        self._member_map = enum._member_map_
        # Member value -> member; used to convert on read.
        # pyre-fixme[4]: Attribute must be annotated.
        self._value2member_map = enum._value2member_map_

    # pyre-fixme[3]: Return annotation cannot be `Any`.
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        """Convert an enum member into the value written to the database.

        Args:
            value: An enum member, or ``None``.
            dialect: The SQLAlchemy dialect in use (not used here).

        Returns:
            ``None`` if ``value`` is ``None``; otherwise the value of the
            member of this column's enum whose name equals ``value.name``.

        Raises:
            TypeError: If ``value`` is not an ``enum.Enum`` instance.
            ValueError: If this column's enum has no member named
                ``value.name``.
        """
        if value is None:
            return value
        if not isinstance(value, enum.Enum):
            raise TypeError("Value is not an instance of Enum.")
        # Lookup is by member *name*: a member of a different enum class with a
        # matching name is accepted and stored using this enum's value.
        member = self._member_map.get(value.name)
        if member is None:
            raise ValueError(
                f"Member '{value}' is not a supported enum: "
                f"{list(self._member_map.keys())}"
            )
        return member.value

    # pyre-fixme[3]: Return annotation cannot be `Any`.
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def process_result_value(self, value: Any, dialect: Any) -> Any:
        """Convert a value read from the database back into its enum member.

        Args:
            value: The stored value, or ``None``.
            dialect: The SQLAlchemy dialect in use (not used here).

        Returns:
            ``None`` if ``value`` is ``None``; otherwise the member of this
            column's enum whose value equals ``value``.

        Raises:
            ValueError: If no member of this column's enum has that value.
        """
        if value is None:
            return value
        member = self._value2member_map.get(value)
        if member is None:
            raise ValueError(
                f"Value '{value}' is not one of the supported enum values: "
                f"{list(self._value2member_map.keys())}"
            )
        return member
[docs]
class IntEnum(BaseNullableEnum):
    """Nullable enum column stored using SQLAlchemy's ``SmallInteger`` type.

    NOTE(review): presumably the bound enum's member values are small ints;
    confirm against the enums used with this type.
    """

    # ``impl`` is assigned the type class itself, not an instance.
    # pyre-fixme[8]: Attribute has type `SmallInteger`; used as
    # `Type[sqlalchemy.sql.sqltypes.SmallInteger]`.
    impl: types.SmallInteger = types.SmallInteger
[docs]
class StringEnum(BaseNullableEnum):
    """Nullable enum column stored as a ``VARCHAR``.

    The column length is capped at ``NAME_OR_TYPE_FIELD_LENGTH``.
    NOTE(review): presumably the bound enum's member values are strings that
    fit within that length; confirm against the enums used with this type.
    """

    impl = types.VARCHAR(length=NAME_OR_TYPE_FIELD_LENGTH)