Source code for ax.utils.common.base
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import abc
from typing import Optional
from ax.utils.common.equality import equality_typechecker, object_attribute_dicts_equal
[docs]class Base:
"""Metaclass for core Ax classes. Provides an equality check and `db_id`
property for SQA storage.
"""
_db_id: Optional[int] = None
@property
def db_id(self) -> Optional[int]:
return self._db_id
@db_id.setter
def db_id(self, db_id: int) -> None:
self._db_id = db_id
@equality_typechecker
def __eq__(self, other: Base) -> bool:
return object_attribute_dicts_equal(
one_dict=self.__dict__, other_dict=other.__dict__
)
@equality_typechecker
def _eq_skip_db_id_check(self, other: Base) -> bool:
return object_attribute_dicts_equal(
one_dict=self.__dict__, other_dict=other.__dict__, skip_db_id_check=True
)
[docs]class SortableBase(Base, metaclass=abc.ABCMeta):
"""Extension to the base class that also provides an inequality check."""
@property
@abc.abstractmethod
def _unique_id(self) -> str:
"""Returns an identification string that can be used to uniquely
identify this instance from others attached to the same parent
object. For example, for ``Trials`` this can be their index,
since that is unique w.r.t. to parent ``Experiment`` object.
For ``GenerationNode``-s attached to a ``GenerationStrategy``,
this can be their name since we ensure uniqueness of it upon
``GenerationStrategy`` instantiation.
This method is needed to correctly update SQLAlchemy objects
that appear as children of other objects, in lists or other
sortable collections or containers.
"""
pass
def __lt__(self, other: SortableBase) -> bool:
return self._unique_id < other._unique_id