Skip to content

Commit

Permalink
Fix equality check for floats
Browse files Browse the repository at this point in the history
Summary:
`same_elements()` wasn't working for `float('nan')`, and it wasn't treating floats with `np.is_close()` like in `object_attribute_dicts_find_unequal_fields()`.

Also, `same_elements` was generally broken.  Example:
{F1676730235}

Differential Revision: D58289519
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jun 7, 2024
1 parent 40ae984 commit 4fa4ff8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
49 changes: 26 additions & 23 deletions ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ def _type_safe_equals(self, other):
return _type_safe_equals


# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def contains_element(container: List[Any], element: Any) -> bool:
"""Check if a list contains an element in a type safe way."""
for item in container:
if isinstance(element, np.ndarray) or isinstance(item, np.ndarray):
if (
isinstance(element, np.ndarray)
and isinstance(item, np.ndarray)
and np.array_equal(element, item, equal_nan=True)
):
return True
elif element == item:
return True
elif isinstance(element, float) or isinstance(item, float):
if np.isclose(element, item, equal_nan=True):
return True
return False


# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def same_elements(list1: List[Any], list2: List[Any]) -> bool:
"""Compare equality of two lists of core Ax objects.
Expand All @@ -42,30 +61,14 @@ def same_elements(list1: List[Any], list2: List[Any]) -> bool:
-- The lists do not contain duplicates
Checking equality is then the same as checking that the lists are the same
length, and that one is a subset of the other.
length, and that both are subsets of the other.
"""

if len(list1) != len(list2):
return False

for item1 in list1:
found = False
for item2 in list2:
if isinstance(item1, np.ndarray) or isinstance(item2, np.ndarray):
if (
isinstance(item1, np.ndarray)
and isinstance(item2, np.ndarray)
and np.array_equal(item1, item2)
):
found = True
break
elif item1 == item2:
found = True
break
if not found:
return False

return True
return (
len(list1) == len(list2)
and all(contains_element(container=list2, element=item1) for item1 in list1)
and all(contains_element(container=list1, element=item2) for item2 in list2)
)


def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
Expand Down Expand Up @@ -212,7 +215,7 @@ def object_attribute_dicts_find_unequal_fields(
elif isinstance(one_val, datetime):
equal = datetime_equals(one_val, other_val)
elif isinstance(one_val, float):
equal = np.isclose(one_val, other_val)
equal = np.isclose(one_val, other_val, equal_nan=True)
elif isinstance(one_val, pd.DataFrame):
equal = dataframe_equals(one_val, other_val)
else:
Expand Down
6 changes: 6 additions & 0 deletions ax/utils/common/tests/test_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@ def eq(x, y):
def test_ListsEquals(self) -> None:
self.assertFalse(same_elements([0], [0, 1]))
self.assertFalse(same_elements([1, 0], [0, 2]))
self.assertFalse(same_elements([1, 1], [1, 2]))
self.assertFalse(same_elements([1, 2], [1, 1]))
self.assertTrue(same_elements([1, 0], [0, 1]))

def test_ListsEquals_floats(self) -> None:
self.assertTrue(same_elements([0.0], [0.000000000000001]))
self.assertTrue(same_elements([float("nan")], [float("nan")]))

def test_DatetimeEquals(self) -> None:
now = datetime.now()
self.assertTrue(datetime_equals(None, None))
Expand Down

0 comments on commit 4fa4ff8

Please sign in to comment.