#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from ax.utils.common.typeutils import numpy_type_to_python_type
def equality_typechecker(eq_func: Callable) -> Callable:
    """A decorator to wrap all __eq__ methods to ensure that the inputs
    are of the right type.

    The returned wrapper returns ``False`` without calling ``eq_func`` when
    ``other`` is not an instance of ``self``'s class; otherwise it delegates
    to ``eq_func``. ``functools.wraps`` keeps the wrapped method's
    ``__name__``, ``__qualname__`` and ``__doc__`` on the wrapper.

    NOTE: the check is ``isinstance(other, self.__class__)``, so it is
    asymmetric for subclasses: ``base == sub`` reaches ``eq_func`` while
    ``sub == base`` returns ``False``.

    Args:
        eq_func: The ``__eq__(self, other)`` implementation to guard.

    Returns:
        The type-checking wrapper around ``eq_func``.
    """

    # no type annotation for now; breaks sphinx-autodoc-typehints
    @functools.wraps(eq_func)
    def _type_safe_equals(self, other):
        if not isinstance(other, self.__class__):
            return False
        return eq_func(self, other)

    return _type_safe_equals
def _elements_equal(item1: Any, item2: Any) -> bool:
    """Return whether two list elements are equal, handling numpy arrays.

    ``ndarray.__eq__`` is elementwise, so arrays are compared with
    ``np.array_equal``; an array never equals a non-array.
    """
    if isinstance(item1, np.ndarray) or isinstance(item2, np.ndarray):
        return (
            isinstance(item1, np.ndarray)
            and isinstance(item2, np.ndarray)
            and bool(np.array_equal(item1, item2))
        )
    return bool(item1 == item2)


def same_elements(list1: List[Any], list2: List[Any]) -> bool:
    """Compare equality of two lists of core Ax objects, ignoring order.

    The lists are equal when they have the same length and every element of
    ``list1`` can be paired with a distinct, equal element of ``list2``.
    Because each element of ``list2`` is used at most once, duplicates are
    counted (``[1, 1, 2]`` is not equal to ``[1, 2, 2]``).

    Assumptions:
        -- The contents of each list are types that implement __eq__
           (numpy arrays are compared with ``np.array_equal``).

    Elements need not be hashable, so this performs up to
    ``len(list1) * len(list2)`` comparisons.
    """
    if len(list1) != len(list2):
        return False
    # Mark elements of ``list2`` once they are paired, so several equal
    # elements of ``list1`` cannot all match the same element of ``list2``.
    matched = [False] * len(list2)
    for item1 in list1:
        for idx, item2 in enumerate(list2):
            if not matched[idx] and _elements_equal(item1, item2):
                matched[idx] = True
                break
        else:
            # No unpaired element of ``list2`` equals ``item1``.
            return False
    # Equal lengths + one distinct partner per element => all are paired.
    return True
def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Compare equality of two datetimes, ignoring microseconds.

    Two missing (falsy, e.g. ``None``) values are equal, and a missing value
    never equals a present one. Otherwise both datetimes are truncated to
    whole seconds before being compared.
    """
    first_missing, second_missing = not dt1, not dt2
    if first_missing or second_missing:
        return first_missing and second_missing
    truncated_first = dt1.replace(microsecond=0)
    truncated_second = dt2.replace(microsecond=0)
    return truncated_first == truncated_second
def dataframe_equals(df1: pd.DataFrame, df2: pd.DataFrame) -> bool:
    """Compare equality of two pandas dataframes.

    Two empty frames always compare equal. Otherwise the frames are checked
    with ``pd.testing.assert_frame_equal`` after sorting both by column label
    (so column order does not matter), with ``check_exact=False``.

    Returns:
        True if the frames are considered equal, False otherwise.
    """
    if df1.empty and df2.empty:
        return True
    try:
        pd.testing.assert_frame_equal(
            df1.sort_index(axis=1), df2.sort_index(axis=1), check_exact=False
        )
    except AssertionError:
        return False
    return True
def object_attribute_dicts_equal(
    one_dict: Dict[str, Any], other_dict: Dict[str, Any]
) -> bool:
    """Check whether two Ax objects' attribute dicts hold the same items.

    Thin wrapper over ``object_attribute_dicts_find_unequal_fields`` (with its
    default arguments): the dicts are equal when neither a type mismatch nor
    a value mismatch is found.

    NOTE: Special-cases some Ax object attributes, like "_experiment" or
    "_model", where full equality is hard to check.
    """
    # The call returns (unequal_type, unequal_value); both must be empty.
    mismatches = object_attribute_dicts_find_unequal_fields(
        one_dict=one_dict, other_dict=other_dict
    )
    return not any(mismatches)
def _attribute_values_equal(
    field: str, one_val: Any, other_val: Any, skip_db_id_check: bool
) -> bool:
    """Compare two values of attribute ``field`` taken from two Ax objects.

    Some attributes (``_experiment``, ``analysis_scheduler``, ``_db_id``,
    ``_model``) are special-cased because full equality is recursive or hard
    to check; all others are compared according to the type of ``one_val``.
    Callers only pass values of the same type, except for ``_db_id`` when
    ``skip_db_id_check`` is set.
    """
    if field == "_experiment":
        # Compare by name only, to prevent an infinite loop when checking
        # equality of Trials.
        return one_val is other_val is None or one_val._name == other_val._name
    if field == "analysis_scheduler":
        # Compare by DB id only, to prevent an infinite loop when checking
        # equality of analysis runs.
        return one_val is other_val is None or one_val.db_id == other_val.db_id
    if field == "_db_id":
        return skip_db_id_check or one_val == other_val
    if field == "_model":  # pragma: no cover (tested in modelbridge)
        # TODO[T52643706]: replace with per-`ModelBridge` method like
        # `equivalent_models`, to compare models more meaningfully.
        if not hasattr(one_val, "model"):
            return not hasattr(other_val, "model")
        # If model bridges have a `model` attribute, the types of the values
        # of those attributes should be equal if the model bridge is the same.
        return hasattr(other_val, "model") and isinstance(
            one_val.model, type(other_val.model)
        )
    if isinstance(one_val, list):
        return isinstance(other_val, list) and same_elements(one_val, other_val)
    if isinstance(one_val, dict):
        # Compare values key by key; comparing them as an unordered bag would
        # treat {"a": 1, "b": 2} and {"a": 2, "b": 1} as equal. Each pair goes
        # through ``same_elements`` to reuse its ndarray-aware comparison.
        return (
            isinstance(other_val, dict)
            and one_val.keys() == other_val.keys()
            and all(
                same_elements([one_val[key]], [other_val[key]]) for key in one_val
            )
        )
    if isinstance(one_val, np.ndarray):
        return bool(np.array_equal(one_val, other_val))
    if isinstance(one_val, datetime):
        return datetime_equals(one_val, other_val)
    if isinstance(one_val, float):
        return bool(np.isclose(one_val, other_val))
    if isinstance(one_val, pd.DataFrame):
        return dataframe_equals(one_val, other_val)
    return bool(one_val == other_val)


def object_attribute_dicts_find_unequal_fields(
    one_dict: Dict[str, Any],
    other_dict: Dict[str, Any],
    fast_return: bool = True,
    skip_db_id_check: bool = False,
) -> Tuple[Dict[str, Tuple[Any, Any]], Dict[str, Tuple[Any, Any]]]:
    """Utility for finding out what attributes of two objects' attribute dicts
    are unequal.

    Only fields present in ``one_dict`` are checked; a field missing from
    ``other_dict`` is compared as ``None``. Both values are passed through
    ``numpy_type_to_python_type`` first. A field whose values differ in type
    is reported only in the first returned dict: its values are not compared
    further, since the type-specific comparisons (e.g. reading ``_name`` of
    an ``_experiment``) assume both values share a type.

    Args:
        one_dict: First object's attribute dict (`obj.__dict__`).
        other_dict: Second object's attribute dict (`obj.__dict__`).
        fast_return: Boolean representing whether to return as soon as a
            single unequal attribute was found or to iterate over all attributes
            and collect all unequal ones.
        skip_db_id_check: If True, the ``_db_id`` field is treated as equal
            regardless of its values, including when only one of them is
            ``None``.

    Returns:
        Two dictionaries:
            - attribute name to attribute values of unequal type (as a tuple),
            - attribute name to attribute values of unequal value (as a tuple).
    """
    unequal_type: Dict[str, Tuple[Any, Any]] = {}
    unequal_value: Dict[str, Tuple[Any, Any]] = {}
    for field in one_dict:
        one_val = numpy_type_to_python_type(one_dict[field])
        other_val = numpy_type_to_python_type(other_dict.get(field))
        # Without this, a `_db_id` of None vs. an int would be flagged as a
        # type mismatch and `skip_db_id_check` would have no effect.
        skip_type_check = skip_db_id_check and field == "_db_id"
        if not skip_type_check and type(one_val) is not type(other_val):
            unequal_type[field] = (one_val, other_val)
            if fast_return:
                return unequal_type, unequal_value
            # Don't compare values of different types: the type-specific
            # comparisons can raise on them (e.g. `None._name`).
            continue
        if not _attribute_values_equal(field, one_val, other_val, skip_db_id_check):
            unequal_value[field] = (one_val, other_val)
            if fast_return:
                return unequal_type, unequal_value
    return unequal_type, unequal_value