diff --git a/ax/utils/common/equality.py b/ax/utils/common/equality.py index 1f2868f7e5b..2a0841f68e3 100644 --- a/ax/utils/common/equality.py +++ b/ax/utils/common/equality.py @@ -42,30 +42,45 @@ def same_elements(list1: List[Any], list2: List[Any]) -> bool: -- The lists do not contain duplicates Checking equality is then the same as checking that the lists are the same - length, and that one is a subset of the other. + length, and that both are subsets of the other. """ if len(list1) != len(list2): return False + matched = [False] * len(list2) for item1 in list1: - found = False - for item2 in list2: - if isinstance(item1, np.ndarray) or isinstance(item2, np.ndarray): - if ( - isinstance(item1, np.ndarray) - and isinstance(item2, np.ndarray) - and np.array_equal(item1, item2) - ): - found = True - break - elif item1 == item2: - found = True + for i, item2 in enumerate(list2): + if not matched[i] and is_ax_equal(item1, item2): + matched[i] = True break - if not found: + else: return False + return all(matched) + - return True +# pyre-fixme[2]: Parameter annotation cannot contain `Any`. +def is_ax_equal(one_val: Any, other_val: Any) -> bool: + """Check for equality of two values, handling lists, dicts, dfs, floats, + dates, and numpy arrays. This method and ``same_elements`` function + as a recursive unit. + """ + if isinstance(one_val, list) and isinstance(other_val, list): + return same_elements(one_val, other_val) + elif isinstance(one_val, dict) and isinstance(other_val, dict): + return sorted(one_val.keys()) == sorted(other_val.keys()) and same_elements( + list(one_val.values()), list(other_val.values()) + ) + elif isinstance(one_val, np.ndarray) and isinstance(other_val, np.ndarray): + return np.array_equal(one_val, other_val, equal_nan=True) + elif isinstance(one_val, datetime): + return datetime_equals(one_val, other_val) + elif isinstance(one_val, float) and isinstance(other_val, float): + return np.isclose(one_val, other_val, equal_nan=True) + elif isinstance(one_val, pd.DataFrame) and isinstance(other_val, pd.DataFrame): + return dataframe_equals(one_val, other_val) + else: + return one_val == other_val def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool: @@ -198,25 +213,8 @@ def object_attribute_dicts_find_unequal_fields( and isinstance(one_val.model, type(other_val.model)) ) - elif isinstance(one_val, list): - equal = isinstance(other_val, list) and same_elements(one_val, other_val) - elif isinstance(one_val, dict): - equal = isinstance(other_val, dict) and sorted(one_val.keys()) == sorted( - other_val.keys() - ) - equal = equal and same_elements( - list(one_val.values()), list(other_val.values()) - ) - elif isinstance(one_val, np.ndarray): - equal = np.array_equal(one_val, other_val, equal_nan=True) - elif isinstance(one_val, datetime): - equal = datetime_equals(one_val, other_val) - elif isinstance(one_val, float): - equal = np.isclose(one_val, other_val) - elif isinstance(one_val, pd.DataFrame): - equal = dataframe_equals(one_val, other_val) else: - equal = one_val == other_val + equal = is_ax_equal(one_val, other_val) if not equal: unequal_value[field] = (one_val, other_val) diff --git a/ax/utils/common/tests/test_equality.py b/ax/utils/common/tests/test_equality.py index 4d2df2fc509..e21e49dd966 100644 --- a/ax/utils/common/tests/test_equality.py +++ b/ax/utils/common/tests/test_equality.py @@ -34,8 +34,15 @@ def eq(x, y): def test_ListsEquals(self) -> None: self.assertFalse(same_elements([0], [0, 1])) self.assertFalse(same_elements([1, 0], [0, 2])) + self.assertFalse(same_elements([1, 1], [1, 2])) + self.assertFalse(same_elements([1, 2], [1, 1])) + self.assertFalse(same_elements([1, 1, 2], [1, 2, 2])) self.assertTrue(same_elements([1, 0], [0, 1])) + def test_ListsEquals_floats(self) -> None: + self.assertTrue(same_elements([0.0], [0.000000000000001])) + self.assertTrue(same_elements([float("nan")], [float("nan")])) + def test_DatetimeEquals(self) -> None: now = datetime.now() self.assertTrue(datetime_equals(None, None))