From cfaf6e7ae0915df13b748d58d497f834cf3b48ef Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Tue, 7 Dec 2021 11:21:18 +0100 Subject: [PATCH 1/2] =?UTF-8?q?=E2=80=BC=EF=B8=8F=20BREAKING:=20Compare=20?= =?UTF-8?q?`Dict`=20nodes=20by=20content?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently there is an inconsistency in how the base data type node instances compare equality. All base types compare based on the content of the node, whereas `Dict` instances rely on the UUID fallback introduced in #4753. After a long discussion started by #1917, it was finally decided that the best way forward is to make the equality comparison consitent among the base types (see #5187). Here we adapt the `__eq__` method of the `Dict` class to compare equality by content instead of relying on the fallback comparison of the UUIDs. --- aiida/orm/nodes/data/dict.py | 8 +++++--- aiida/orm/nodes/data/list.py | 5 ++--- tests/orm/nodes/data/test_dict.py | 25 ++++++++++--------------- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index 6cd542ca65..d9fc232750 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -71,10 +71,12 @@ def __setitem__(self, key, value): self.set_attribute(key, value) def __eq__(self, other): - if isinstance(other, dict): - return self.get_dict() == other + if isinstance(other, Dict): + return self.get_dict() == other.get_dict() + return self.get_dict() == other - return super().__eq__(other) + def __ne__(self, other): + return not self == other def set_dict(self, dictionary): """ Replace the current dictionary with another one. diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index cb05920a48..b4c033abed 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -52,10 +52,9 @@ def __str__(self): return f'{super().__str__()} value: {self.get_list()}' def __eq__(self, other): - try: + if isinstance(other, List): return self.get_list() == other.get_list() - except AttributeError: - return self.get_list() == other + return self.get_list() == other def __ne__(self, other): return not self == other diff --git a/tests/orm/nodes/data/test_dict.py b/tests/orm/nodes/data/test_dict.py index 7a27b91fe6..bf5eb1505d 100644 --- a/tests/orm/nodes/data/test_dict.py +++ b/tests/orm/nodes/data/test_dict.py @@ -85,27 +85,22 @@ def test_correct_raises(dictionary): def test_eq(dictionary): """Test the ``__eq__`` method. - A node should compare equal to itself and to the plain dictionary that represents its value. However, it should not - compare equal to another node that has the same content. This is a hot issue and is being discussed in the following - ticket: https://github.com/aiidateam/aiida-core/issues/1917 + A node should compare equal to a the plain dictionary that has the same value, as well as any other ``Dict`` node + that has the same content. For context, the discussion on whether to compare nodes by content was started in the + following issue: + + https://github.com/aiidateam/aiida-core/issues/1917 + + A summary and the final conclusion can be found in this discussion: + + https://github.com/aiidateam/aiida-core/discussions/5187 """ node = Dict(dictionary) clone = Dict(dictionary) assert node is node # pylint: disable=comparison-with-itself assert node == dictionary - assert node != clone - - # To test the fallback, where two ``Dict`` nodes are equal if their UUIDs are even if the content is different, we - # create a different node with other content, but artificially give it the same UUID as ``node``. In practice this - # wouldn't happen unless, by accident, two different nodes get the same UUID, the probability of which is minimal. - # Note that we have to set the UUID directly through the database model instance of the backend entity, since it is - # forbidden to change it through the front-end or backend entity instance, for good reasons. - other = Dict({}) - other.backend_entity._dbmodel.uuid = node.uuid # pylint: disable=protected-access - assert other.uuid == node.uuid - assert other.dict != node.dict - assert node == other + assert node == clone @pytest.mark.usefixtures('clear_database_before_test') From 4ff2782863225849f33194341b2c7229d05395ab Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Tue, 7 Dec 2021 15:36:51 +0100 Subject: [PATCH 2/2] Remove `__ne__` and update tests --- aiida/orm/nodes/data/base.py | 5 ---- aiida/orm/nodes/data/dict.py | 3 --- aiida/orm/nodes/data/list.py | 3 --- tests/orm/nodes/data/test_base.py | 22 ++++++++++++++++++ tests/orm/nodes/data/test_dict.py | 13 ++++++++--- tests/orm/nodes/data/test_list.py | 38 +++++++++++++++++++------------ 6 files changed, 55 insertions(+), 29 deletions(-) diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py index 070296ad0d..176b1445d0 100644 --- a/aiida/orm/nodes/data/base.py +++ b/aiida/orm/nodes/data/base.py @@ -50,10 +50,5 @@ def __eq__(self, other): return self.value == other.value return self.value == other - def __ne__(self, other): - if isinstance(other, BaseType): - return self.value != other.value - return self.value != other - def new(self, value=None): return self.__class__(value) diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index d9fc232750..2cdffba3d9 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -75,9 +75,6 @@ def __eq__(self, other): return self.get_dict() == other.get_dict() return self.get_dict() == other - def __ne__(self, other): - return not self == other - def set_dict(self, dictionary): """ Replace the current dictionary with another one. diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index b4c033abed..36bb57ae39 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -56,9 +56,6 @@ def __eq__(self, other): return self.get_list() == other.get_list() return self.get_list() == other - def __ne__(self, other): - return not self == other - def append(self, value): data = self.get_list() data.append(value) diff --git a/tests/orm/nodes/data/test_base.py b/tests/orm/nodes/data/test_base.py index adb564f42e..f70ff8b39f 100644 --- a/tests/orm/nodes/data/test_base.py +++ b/tests/orm/nodes/data/test_base.py @@ -205,3 +205,25 @@ def test_operator(opera): c_val = opera(node_x.value, node_y.value) assert res._type == type(c_val) # pylint: disable=protected-access assert res == opera(node_x.value, node_y.value) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Bool, False, True), + (Int, 2, 5), + (Float, 2.5, 5.5), + (Str, 'a', 'b'), +]) +def test_equality(node_type, a, b): + """Test equality comparison for the base types.""" + node_a = node_type(a) + node_a_clone = node_type(a) + node_b = node_type(b) + + # Test equality comparison with Python base types + assert node_a == a + assert node_a != b + + # Test equality comparison with other `BaseType` nodes + assert node_a == node_a_clone + assert node_a != node_b diff --git a/tests/orm/nodes/data/test_dict.py b/tests/orm/nodes/data/test_dict.py index bf5eb1505d..0aaac9a2c9 100644 --- a/tests/orm/nodes/data/test_dict.py +++ b/tests/orm/nodes/data/test_dict.py @@ -82,8 +82,8 @@ def test_correct_raises(dictionary): @pytest.mark.usefixtures('clear_database_before_test') -def test_eq(dictionary): - """Test the ``__eq__`` method. +def test_equality(dictionary): + """Test the equality comparison for the ``Dict`` type. A node should compare equal to a the plain dictionary that has the same value, as well as any other ``Dict`` node that has the same content. For context, the discussion on whether to compare nodes by content was started in the @@ -95,12 +95,19 @@ def test_eq(dictionary): https://github.com/aiidateam/aiida-core/discussions/5187 """ + different_dict = {'I': {'am': 'different'}} node = Dict(dictionary) + different_node = Dict(different_dict) clone = Dict(dictionary) - assert node is node # pylint: disable=comparison-with-itself + # Test equality comparison with Python base type assert node == dictionary + assert node != different_dict + + # Test equality comparison between `Dict` nodes + assert node is node # pylint: disable=comparison-with-itself assert node == clone + assert node != different_node @pytest.mark.usefixtures('clear_database_before_test') diff --git a/tests/orm/nodes/data/test_list.py b/tests/orm/nodes/data/test_list.py index dd7f2309ce..41ff099d1a 100644 --- a/tests/orm/nodes/data/test_list.py +++ b/tests/orm/nodes/data/test_list.py @@ -71,7 +71,7 @@ def test_store_load(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_special_methods(listing): """Test the special methods of the ``List`` class.""" - node = List(list=listing) + node = List(listing) # __getitem__ for i, value in enumerate(listing): @@ -91,11 +91,19 @@ def test_special_methods(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_equality(listing): - """Test that two ``List`` nodes with equal content compare equal.""" - node1 = List(list=listing) - node2 = List(list=listing) + """Test equality comparison for ``List`` nodes.""" + different_list = ['I', 'am', 'different'] + node = List(listing) + different_node = List(different_list) + clone = List(listing) + + # Test equality comparison with Python base type + assert node == listing + assert node != different_list - assert node1 == node2 + # Test equality comparison with other `BaseType` nodes + assert node == clone + assert node != different_node @pytest.mark.usefixtures('clear_database_before_test') @@ -114,7 +122,7 @@ def do_checks(node): node.store() do_checks(node) - node = List(list=listing) + node = List(listing) node.append('more') assert node[-1] == 'more' @@ -145,7 +153,7 @@ def do_checks(node, lst): @pytest.mark.usefixtures('clear_database_before_test') def test_insert(listing): """Test the ``List.insert()`` method.""" - node = List(list=listing) + node = List(listing) node.insert(1, 'new') assert node[1] == 'new' assert len(node) == 4 @@ -154,7 +162,7 @@ def test_insert(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_remove(listing): """Test the ``List.remove()`` method.""" - node = List(list=listing) + node = List(listing) node.remove(1) listing.remove(1) assert node.get_list() == listing @@ -166,7 +174,7 @@ def test_remove(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_pop(listing): """Test the ``List.pop()`` method.""" - node = List(list=listing) + node = List(listing) node.pop() assert node.get_list() == listing[:-1] @@ -174,7 +182,7 @@ def test_pop(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_index(listing): """Test the ``List.index()`` method.""" - node = List(list=listing) + node = List(listing) assert node.index(True) == listing.index(True) @@ -182,7 +190,7 @@ def test_index(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_count(listing): """Test the ``List.count()`` method.""" - node = List(list=listing) + node = List(listing) for value in listing: assert node.count(value) == listing.count(value) @@ -190,12 +198,12 @@ def test_count(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_sort(listing, int_listing): """Test the ``List.sort()`` method.""" - node = List(list=int_listing) + node = List(int_listing) node.sort() int_listing.sort() assert node.get_list() == int_listing - node = List(list=listing) + node = List(listing) with pytest.raises(TypeError, match=r"'<' not supported between instances of 'int' and 'str'"): node.sort() @@ -203,7 +211,7 @@ def test_sort(listing, int_listing): @pytest.mark.usefixtures('clear_database_before_test') def test_reverse(listing): """Test the ``List.reverse()`` method.""" - node = List(list=listing) + node = List(listing) node.reverse() listing.reverse() assert node.get_list() == listing @@ -212,5 +220,5 @@ def test_reverse(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_initialise_with_list_kwarg(listing): """Test that the ``List`` node can be initialized with the ``list`` keyword argument for backwards compatibility.""" - node = List(list=listing) + node = List(listing) assert node.get_list() == listing