Source code for ax.core.base
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from datetime import datetime
import numpy as np
import pandas as pd
from ax.utils.common.equality import (
datetime_equals,
equality_typechecker,
same_elements,
)
from ax.utils.common.typeutils import numpy_type_to_python_type
[docs]class Base(object):
"""Metaclass for core Ax classes."""
@equality_typechecker
def __eq__(self, other: "Base") -> bool:
for field in self.__dict__.keys():
self_val = getattr(self, field)
other_val = getattr(other, field)
self_val = numpy_type_to_python_type(self_val)
other_val = numpy_type_to_python_type(other_val)
if type(self_val) != type(other_val):
return False
if field == "_experiment":
# prevent infinite loop when checking equality of Trials
equal = self_val is other_val is None or (
self_val._name == other_val._name
)
elif field == "_model": # pragma: no cover (tested in modelbridge)
# TODO[T52643706]: replace with per-`ModelBridge` method like
# `equivalent_models`, to compare models more meaningfully.
if not hasattr(self_val, "model"):
equal = not hasattr(other_val, "model")
else:
# If model bridges have a `model` attribute, the types of the
# values of those attributes should be equal if the model
# bridge is the same.
equal = isinstance(self_val.model, type(other_val.model))
elif isinstance(self_val, list):
equal = same_elements(self_val, other_val)
elif isinstance(self_val, dict):
equal = sorted(self_val.keys()) == sorted(other_val.keys())
equal = equal and same_elements(
list(self_val.values()), list(other_val.values())
)
elif isinstance(self_val, np.ndarray):
equal = np.array_equal(self_val, other_val)
elif isinstance(self_val, datetime):
equal = datetime_equals(self_val, other_val)
elif isinstance(self_val, float):
equal = np.isclose(self_val, other_val)
elif isinstance(self_val, pd.DataFrame):
try:
if self_val.empty and other_val.empty:
equal = True
else:
pd.testing.assert_frame_equal(
self_val.sort_index(axis=1),
other_val.sort_index(axis=1),
check_exact=False,
)
equal = True
except AssertionError:
equal = False
else:
equal = self_val == other_val
if not equal:
return False
return True