Source code for ax.utils.common.equality

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from datetime import datetime
from functools import wraps
from typing import Any, Callable, List, Optional


def equality_typechecker(eq_func: Callable) -> Callable:
    """A decorator to wrap all __eq__ methods to ensure that the inputs
    are of the right type.

    The wrapped ``__eq__`` returns ``False`` immediately when ``other`` is not
    an instance of ``self``'s class; otherwise it defers to ``eq_func``.

    Args:
        eq_func: The ``__eq__`` implementation to guard.

    Returns:
        A type-checking wrapper around ``eq_func`` that preserves its name,
        docstring and other metadata (via ``functools.wraps``).
    """

    # no type annotation for now; breaks sphinx-autodoc-typehints
    @wraps(eq_func)  # keep __name__/__doc__ so introspection and autodoc see eq_func
    def _type_safe_equals(self, other):
        # NOTE: the check is one-directional -- a subclass instance passed as
        # `other` to a base-class __eq__ still reaches eq_func.
        if not isinstance(other, self.__class__):
            return False
        return eq_func(self, other)

    return _type_safe_equals
def list_equals(list1: List[Any], list2: List[Any]) -> bool:
    """Compare equality of two lists of core Ax objects, ignoring order.

    Two lists are considered equal when they have the same length and every
    element of ``list1`` can be paired with a distinct element of ``list2``
    that it compares equal to (``item1 == item2``). Only ``__eq__`` is used,
    so the contents need not be hashable.

    Duplicates are handled correctly: each element of ``list2`` can be
    matched at most once, so e.g. ``[a, a, b]`` and ``[a, b, b]`` are not
    equal. For lists without duplicates this is equivalent to checking that
    the lists have the same length and one is a subset of the other.

    Args:
        list1: First list; its elements must implement ``__eq__``.
        list2: Second list; its elements must implement ``__eq__``.

    Returns:
        True if the lists contain the same elements (as a multiset).
    """
    if len(list1) != len(list2):
        return False
    # Elements of list2 not yet paired with anything from list1. Consuming a
    # match prevents one list2 element from satisfying several list1 duplicates.
    unmatched = list(list2)
    for item1 in list1:
        for idx, item2 in enumerate(unmatched):
            if item1 == item2:
                del unmatched[idx]
                break
        else:
            # No remaining element of list2 equals item1.
            return False
    return True
def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Compare equality of two datetimes, ignoring microseconds.

    Args:
        dt1: First datetime, or a missing (falsy) value such as None.
        dt2: Second datetime, or a missing (falsy) value such as None.

    Returns:
        True if both are missing, or both are present and equal once their
        microsecond components are zeroed; False otherwise.
    """
    if dt1 and dt2:
        truncated1 = dt1.replace(microsecond=0)
        truncated2 = dt2.replace(microsecond=0)
        return truncated1 == truncated2
    # At least one side is missing: they match only when both are missing.
    return not dt1 and not dt2