Source code for ax.utils.common.equality

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def equality_typechecker(eq_func: Callable) -> Callable:
    """Decorator that guards an ``__eq__`` implementation with a type check.

    The wrapper it returns short-circuits to ``False`` whenever ``other`` is
    not an instance of ``self.__class__``; only when the check passes is the
    wrapped ``eq_func`` invoked and its result returned.
    """

    # no type annotation for now; breaks sphinx-autodoc-typehints
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def _type_safe_equals(self, other):
        if isinstance(other, self.__class__):
            return eq_func(self, other)
        return False

    return _type_safe_equals
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def same_elements(list1: List[Any], list2: List[Any]) -> bool:
    """Check whether two lists hold the same elements, ignoring order.

    Assumptions:
    -- Every element is of a type that implements __eq__
    -- Neither list contains duplicates

    Under those assumptions, the lists are equal exactly when they have the
    same length and every element of ``list1`` has a match in ``list2``.
    Numpy arrays are only ever matched against other numpy arrays, using
    ``np.array_equal``; all other pairs are compared with ``==``.
    """
    if len(list1) != len(list2):
        return False

    def _pair_matches(left: Any, right: Any) -> bool:
        left_is_array = isinstance(left, np.ndarray)
        right_is_array = isinstance(right, np.ndarray)
        if left_is_array or right_is_array:
            # An array never matches a non-array; two arrays are compared
            # as whole arrays rather than element by element.
            return left_is_array and right_is_array and np.array_equal(left, right)
        return bool(left == right)

    return all(
        any(_pair_matches(candidate, other) for other in list2) for candidate in list1
    )
def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Compare two optional datetimes at one-second resolution.

    Microseconds are zeroed out on both sides before comparing. Two missing
    (falsy) values count as equal; a missing value never equals a present one.
    """
    if dt1 and dt2:
        return dt1.replace(microsecond=0) == dt2.replace(microsecond=0)
    # At least one side is missing: equal only if both are.
    return not dt1 and not dt2
def dataframe_equals(df1: pd.DataFrame, df2: pd.DataFrame) -> bool:
    """Report whether two pandas dataframes are equal.

    Two empty dataframes are always treated as equal. Otherwise the frames
    are compared with ``pd.testing.assert_frame_equal`` after sorting their
    columns by label, with ``check_exact=False``.
    """
    if df1.empty and df2.empty:
        return True

    left = df1.sort_index(axis=1)
    right = df2.sort_index(axis=1)
    try:
        pd.testing.assert_frame_equal(left, right, check_exact=False)
    except AssertionError:
        # The pandas helper signals any mismatch by raising.
        return False
    return True
def object_attribute_dicts_equal(
    one_dict: Dict[str, Any], other_dict: Dict[str, Any], skip_db_id_check: bool = False
) -> bool:
    """Tell whether the attribute dicts of two Ax objects match item by item.

    NOTE: Some Ax object attributes, like "_experiment" or "_model", are
    special-cased because full equality is hard to check for them.

    Args:
        one_dict: First object's attribute dict (``obj.__dict__``).
        other_dict: Second object's attribute dict (``obj.__dict__``).
        skip_db_id_check: If ``True``, the ``db_id`` attributes are left out
            of the equality check. Useful for confirming that two objects
            agree on everything except the ids with which one or both were
            saved to the database (e.g. comparing an object before it was
            saved to the version reloaded from the DB).

    Returns:
        ``True`` if no attribute differs in type or in value, else ``False``.
    """
    type_mismatches, value_mismatches = object_attribute_dicts_find_unequal_fields(
        one_dict=one_dict,
        other_dict=other_dict,
        skip_db_id_check=skip_db_id_check,
    )
    if type_mismatches:
        return False
    return not value_mismatches
# pyre-fixme[3]: Return annotation cannot contain `Any`.
def object_attribute_dicts_find_unequal_fields(
    one_dict: Dict[str, Any],
    other_dict: Dict[str, Any],
    fast_return: bool = True,
    skip_db_id_check: bool = False,
) -> Tuple[Dict[str, Tuple[Any, Any]], Dict[str, Tuple[Any, Any]]]:
    """Utility for finding out what attributes of two objects' attribute dicts
    are unequal.

    NOTE: Only the keys of ``one_dict`` are iterated; a key that is missing
    from ``other_dict`` is looked up with ``dict.get`` and so compares as
    ``None``. Keys present only in ``other_dict`` are not examined.

    Args:
        one_dict: First object's attribute dict (``obj.__dict__``).
        other_dict: Second object's attribute dict (``obj.__dict__``).
        fast_return: Boolean representing whether to return as soon as a
            single unequal attribute was found or to iterate over all
            attributes and collect all unequal ones.
        skip_db_id_check: If ``True``, will exclude the ``db_id`` attributes
            from the equality check. Useful for ensuring that all attributes
            of an object are equal except the ids, with which one or both of
            them are saved to the database (e.g. if confirming an object
            before it was saved, to the version reloaded from the DB).

    Returns:
        Two dictionaries:
        - attribute name to attribute values of unequal type (as a tuple),
        - attribute name to attribute values of unequal value (as a tuple).
    """
    unequal_type, unequal_value = {}, {}
    for field in one_dict:
        one_val = one_dict.get(field)
        other_val = other_dict.get(field)
        one_val = numpy_type_to_python_type(one_val)
        other_val = numpy_type_to_python_type(other_val)
        skip_type_check = skip_db_id_check and field == "_db_id"
        if not skip_type_check and (type(one_val) is not type(other_val)):
            unequal_type[field] = (one_val, other_val)
            if fast_return:
                return unequal_type, unequal_value

        if field == "_experiment":
            # Prevent infinite loop when checking equality of Trials (on
            # Experiment, with back-pointer), GenSteps (on GenerationStrategy),
            # AnalysisRun-s (on AnalysisScheduler).
            if one_val is None or other_val is None:
                equal = one_val is None and other_val is None
            else:
                # We compare `_name` because `name` attribute errors if not set.
                equal = one_val._name == other_val._name
        elif field == "_generation_strategy":
            # Prevent infinite loop when checking equality of Trials (on
            # Experiment, with back-pointer), GenSteps (on GenerationStrategy),
            # AnalysisRun-s (on AnalysisScheduler).
            if one_val is None or other_val is None:
                equal = one_val is None and other_val is None
            else:
                # We compare `name` because it is set dynamically in
                # some cases (see `GenerationStrategy.name` attribute).
                equal = one_val.name == other_val.name
        elif field == "analysis_scheduler":
            # Prevent infinite loop when checking equality of analysis runs.
            # A scheduler set on only one side means "unequal"; handling it
            # explicitly avoids dereferencing `db_id` on `None`.
            if one_val is None or other_val is None:
                equal = one_val is None and other_val is None
            else:
                equal = one_val.db_id == other_val.db_id
        elif field == "_db_id":
            equal = skip_db_id_check or one_val == other_val
        elif field == "_model":
            # TODO[T52643706]: replace with per-`ModelBridge` method like
            # `equivalent_models`, to compare models more meaningfully.
            one_has_model = hasattr(one_val, "model")
            other_has_model = hasattr(other_val, "model")
            if not one_has_model or not other_has_model:
                # Equal only when *neither* side carries a `model` attribute;
                # both sides must be inspected for this to be symmetric.
                equal = not one_has_model and not other_has_model
            else:
                # If model bridges have a `model` attribute, the types of the
                # values of those attributes should be equal if the model
                # bridge is the same.
                equal = isinstance(one_val.model, type(other_val.model))
        elif isinstance(one_val, list):
            equal = isinstance(other_val, list) and same_elements(one_val, other_val)
        elif isinstance(one_val, dict):
            # NOTE: keys must match, but values are compared as an unordered
            # collection rather than key by key.
            equal = isinstance(other_val, dict) and sorted(one_val.keys()) == sorted(
                other_val.keys()
            )
            equal = equal and same_elements(
                list(one_val.values()), list(other_val.values())
            )
        elif isinstance(one_val, np.ndarray):
            equal = np.array_equal(one_val, other_val, equal_nan=True)
        elif isinstance(one_val, datetime):
            equal = datetime_equals(one_val, other_val)
        elif isinstance(one_val, float):
            equal = np.isclose(one_val, other_val)
        elif isinstance(one_val, pd.DataFrame):
            equal = dataframe_equals(one_val, other_val)
        else:
            equal = one_val == other_val

        if not equal:
            unequal_value[field] = (one_val, other_val)
            if fast_return:
                return unequal_type, unequal_value
    return unequal_type, unequal_value