Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix equality check for floats #2507

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,47 @@ def same_elements(list1: List[Any], list2: List[Any]) -> bool:
-- The lists do not contain duplicates

Checking equality is then the same as checking that the lists are the same
length, and that one is a subset of the other.
length, and that both are subsets of the other.
"""

if len(list1) != len(list2):
return False

matched = [False for _ in list2]
for item1 in list1:
found = False
for item2 in list2:
if isinstance(item1, np.ndarray) or isinstance(item2, np.ndarray):
if (
isinstance(item1, np.ndarray)
and isinstance(item2, np.ndarray)
and np.array_equal(item1, item2)
):
found = True
break
elif item1 == item2:
found = True
matched_this_item = False
for i, item2 in enumerate(list2):
if not matched[i] and is_ax_equal(item1, item2):
matched[i] = True
matched_this_item = True
break
if not found:
if not matched_this_item:
return False
return all(matched)

return True

# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def is_ax_equal(one_val: Any, other_val: Any) -> bool:
"""Check for equality of two values, handling lists, dicts, dfs, floats,
dates, and numpy arrays. This method and ``same_elements`` function
as a recursive unit.
"""
if isinstance(one_val, list) and isinstance(other_val, list):
return same_elements(one_val, other_val)
elif isinstance(one_val, dict) and isinstance(other_val, dict):
return sorted(one_val.keys()) == sorted(other_val.keys()) and same_elements(
list(one_val.values()), list(other_val.values())
)
elif isinstance(one_val, np.ndarray) and isinstance(other_val, np.ndarray):
return np.array_equal(one_val, other_val, equal_nan=True)
elif isinstance(one_val, datetime):
return datetime_equals(one_val, other_val)
elif isinstance(one_val, float) and isinstance(other_val, float):
return np.isclose(one_val, other_val, equal_nan=True)
elif isinstance(one_val, pd.DataFrame) and isinstance(other_val, pd.DataFrame):
return dataframe_equals(one_val, other_val)
else:
return one_val == other_val


def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
Expand Down Expand Up @@ -198,25 +215,8 @@ def object_attribute_dicts_find_unequal_fields(
and isinstance(one_val.model, type(other_val.model))
)

elif isinstance(one_val, list):
equal = isinstance(other_val, list) and same_elements(one_val, other_val)
elif isinstance(one_val, dict):
equal = isinstance(other_val, dict) and sorted(one_val.keys()) == sorted(
other_val.keys()
)
equal = equal and same_elements(
list(one_val.values()), list(other_val.values())
)
elif isinstance(one_val, np.ndarray):
equal = np.array_equal(one_val, other_val, equal_nan=True)
elif isinstance(one_val, datetime):
equal = datetime_equals(one_val, other_val)
elif isinstance(one_val, float):
equal = np.isclose(one_val, other_val)
elif isinstance(one_val, pd.DataFrame):
equal = dataframe_equals(one_val, other_val)
else:
equal = one_val == other_val
equal = is_ax_equal(one_val, other_val)

if not equal:
unequal_value[field] = (one_val, other_val)
Expand Down
7 changes: 7 additions & 0 deletions ax/utils/common/tests/test_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,15 @@ def eq(x, y):
def test_ListsEquals(self) -> None:
self.assertFalse(same_elements([0], [0, 1]))
self.assertFalse(same_elements([1, 0], [0, 2]))
self.assertFalse(same_elements([1, 1], [1, 2]))
self.assertFalse(same_elements([1, 2], [1, 1]))
self.assertFalse(same_elements([1, 1, 2], [1, 2, 2]))
self.assertTrue(same_elements([1, 0], [0, 1]))

def test_ListsEquals_floats(self) -> None:
self.assertTrue(same_elements([0.0], [0.000000000000001]))
self.assertTrue(same_elements([float("nan")], [float("nan")]))

def test_DatetimeEquals(self) -> None:
now = datetime.now()
self.assertTrue(datetime_equals(None, None))
Expand Down