Source code for ax.utils.common.equality
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from datetime import datetime
from typing import Any, Callable, List, Optional
import numpy as np
def equality_typechecker(eq_func: Callable) -> Callable:
    """Wrap an ``__eq__`` implementation with an operand type guard.

    The returned function yields ``False`` whenever ``other`` is not an
    instance of ``self.__class__``; otherwise it returns whatever
    ``eq_func(self, other)`` returns.
    """

    # no type annotation for now; breaks sphinx-autodoc-typehints
    def _type_safe_equals(self, other):
        same_type = isinstance(other, self.__class__)
        # Only delegate when the operand types are compatible.
        return eq_func(self, other) if same_type else False

    return _type_safe_equals
def same_elements(list1: List[Any], list2: List[Any]) -> bool:
    """Check whether two lists hold the same elements, ignoring order.

    Assumptions:
    -- Every element implements ``__eq__``
    -- Neither list contains duplicates

    Under those assumptions, the lists are equal when they have the same
    length and each element of ``list1`` has a match in ``list2``.
    ``numpy`` arrays are only matched against other arrays, using
    ``np.array_equal``; everything else is compared with ``==``.
    """
    if len(list1) != len(list2):
        return False

    def _matches(left: Any, right: Any) -> bool:
        left_is_array = isinstance(left, np.ndarray)
        right_is_array = isinstance(right, np.ndarray)
        if left_is_array and right_is_array:
            return bool(np.array_equal(left, right))
        if left_is_array or right_is_array:
            # An array never matches a non-array element.
            return False
        return bool(left == right)

    # Each element of list1 needs at least one counterpart in list2.
    return all(
        any(_matches(candidate, other) for other in list2)
        for candidate in list1
    )
def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Return whether two optional datetimes match once microseconds are dropped.

    Two missing values count as equal; a missing value never equals a
    present one.
    """
    if dt1 and dt2:
        # Both present: compare with the microsecond component zeroed.
        return dt1.replace(microsecond=0) == dt2.replace(microsecond=0)
    # At least one is missing: equal only when both are missing.
    return not dt1 and not dt2