diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py index 070296ad0d..176b1445d0 100644 --- a/aiida/orm/nodes/data/base.py +++ b/aiida/orm/nodes/data/base.py @@ -50,10 +50,5 @@ def __eq__(self, other): return self.value == other.value return self.value == other - def __ne__(self, other): - if isinstance(other, BaseType): - return self.value != other.value - return self.value != other - def new(self, value=None): return self.__class__(value) diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index 6cd542ca65..2cdffba3d9 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -71,10 +71,9 @@ def __setitem__(self, key, value): self.set_attribute(key, value) def __eq__(self, other): - if isinstance(other, dict): - return self.get_dict() == other - - return super().__eq__(other) + if isinstance(other, Dict): + return self.get_dict() == other.get_dict() + return self.get_dict() == other def set_dict(self, dictionary): """ Replace the current dictionary with another one. diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index cb05920a48..36bb57ae39 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -52,13 +52,9 @@ def __str__(self): return f'{super().__str__()} value: {self.get_list()}' def __eq__(self, other): - try: + if isinstance(other, List): return self.get_list() == other.get_list() - except AttributeError: - return self.get_list() == other - - def __ne__(self, other): - return not self == other + return self.get_list() == other def append(self, value): data = self.get_list() diff --git a/tests/orm/nodes/data/test_base.py b/tests/orm/nodes/data/test_base.py index adb564f42e..f70ff8b39f 100644 --- a/tests/orm/nodes/data/test_base.py +++ b/tests/orm/nodes/data/test_base.py @@ -205,3 +205,25 @@ def test_operator(opera): c_val = opera(node_x.value, node_y.value) assert res._type == type(c_val) # pylint: disable=protected-access assert res == opera(node_x.value, node_y.value) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('node_type, a, b', [ + (Bool, False, True), + (Int, 2, 5), + (Float, 2.5, 5.5), + (Str, 'a', 'b'), +]) +def test_equality(node_type, a, b): + """Test equality comparison for the base types.""" + node_a = node_type(a) + node_a_clone = node_type(a) + node_b = node_type(b) + + # Test equality comparison with Python base types + assert node_a == a + assert node_a != b + + # Test equality comparison with other `BaseType` nodes + assert node_a == node_a_clone + assert node_a != node_b diff --git a/tests/orm/nodes/data/test_dict.py b/tests/orm/nodes/data/test_dict.py index 7a27b91fe6..0aaac9a2c9 100644 --- a/tests/orm/nodes/data/test_dict.py +++ b/tests/orm/nodes/data/test_dict.py @@ -82,30 +82,32 @@ def test_correct_raises(dictionary): @pytest.mark.usefixtures('clear_database_before_test') -def test_eq(dictionary): - """Test the ``__eq__`` method. +def test_equality(dictionary): + """Test the equality comparison for the ``Dict`` type. - A node should compare equal to itself and to the plain dictionary that represents its value. However, it should not - compare equal to another node that has the same content. This is a hot issue and is being discussed in the following - ticket: https://github.com/aiidateam/aiida-core/issues/1917 + A node should compare equal to a the plain dictionary that has the same value, as well as any other ``Dict`` node + that has the same content. For context, the discussion on whether to compare nodes by content was started in the + following issue: + + https://github.com/aiidateam/aiida-core/issues/1917 + + A summary and the final conclusion can be found in this discussion: + + https://github.com/aiidateam/aiida-core/discussions/5187 """ + different_dict = {'I': {'am': 'different'}} node = Dict(dictionary) + different_node = Dict(different_dict) clone = Dict(dictionary) - assert node is node # pylint: disable=comparison-with-itself + # Test equality comparison with Python base type assert node == dictionary - assert node != clone - - # To test the fallback, where two ``Dict`` nodes are equal if their UUIDs are even if the content is different, we - # create a different node with other content, but artificially give it the same UUID as ``node``. In practice this - # wouldn't happen unless, by accident, two different nodes get the same UUID, the probability of which is minimal. - # Note that we have to set the UUID directly through the database model instance of the backend entity, since it is - # forbidden to change it through the front-end or backend entity instance, for good reasons. - other = Dict({}) - other.backend_entity._dbmodel.uuid = node.uuid # pylint: disable=protected-access - assert other.uuid == node.uuid - assert other.dict != node.dict - assert node == other + assert node != different_dict + + # Test equality comparison between `Dict` nodes + assert node is node # pylint: disable=comparison-with-itself + assert node == clone + assert node != different_node @pytest.mark.usefixtures('clear_database_before_test') diff --git a/tests/orm/nodes/data/test_list.py b/tests/orm/nodes/data/test_list.py index dd7f2309ce..41ff099d1a 100644 --- a/tests/orm/nodes/data/test_list.py +++ b/tests/orm/nodes/data/test_list.py @@ -71,7 +71,7 @@ def test_store_load(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_special_methods(listing): """Test the special methods of the ``List`` class.""" - node = List(list=listing) + node = List(listing) # __getitem__ for i, value in enumerate(listing): @@ -91,11 +91,19 @@ def test_special_methods(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_equality(listing): - """Test that two ``List`` nodes with equal content compare equal.""" - node1 = List(list=listing) - node2 = List(list=listing) + """Test equality comparison for ``List`` nodes.""" + different_list = ['I', 'am', 'different'] + node = List(listing) + different_node = List(different_list) + clone = List(listing) + + # Test equality comparison with Python base type + assert node == listing + assert node != different_list - assert node1 == node2 + # Test equality comparison with other `BaseType` nodes + assert node == clone + assert node != different_node @pytest.mark.usefixtures('clear_database_before_test') @@ -114,7 +122,7 @@ def do_checks(node): node.store() do_checks(node) - node = List(list=listing) + node = List(listing) node.append('more') assert node[-1] == 'more' @@ -145,7 +153,7 @@ def do_checks(node, lst): @pytest.mark.usefixtures('clear_database_before_test') def test_insert(listing): """Test the ``List.insert()`` method.""" - node = List(list=listing) + node = List(listing) node.insert(1, 'new') assert node[1] == 'new' assert len(node) == 4 @@ -154,7 +162,7 @@ def test_insert(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_remove(listing): """Test the ``List.remove()`` method.""" - node = List(list=listing) + node = List(listing) node.remove(1) listing.remove(1) assert node.get_list() == listing @@ -166,7 +174,7 @@ def test_remove(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_pop(listing): """Test the ``List.pop()`` method.""" - node = List(list=listing) + node = List(listing) node.pop() assert node.get_list() == listing[:-1] @@ -174,7 +182,7 @@ def test_pop(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_index(listing): """Test the ``List.index()`` method.""" - node = List(list=listing) + node = List(listing) assert node.index(True) == listing.index(True) @@ -182,7 +190,7 @@ def test_index(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_count(listing): """Test the ``List.count()`` method.""" - node = List(list=listing) + node = List(listing) for value in listing: assert node.count(value) == listing.count(value) @@ -190,12 +198,12 @@ def test_count(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_sort(listing, int_listing): """Test the ``List.sort()`` method.""" - node = List(list=int_listing) + node = List(int_listing) node.sort() int_listing.sort() assert node.get_list() == int_listing - node = List(list=listing) + node = List(listing) with pytest.raises(TypeError, match=r"'<' not supported between instances of 'int' and 'str'"): node.sort() @@ -203,7 +211,7 @@ def test_sort(listing, int_listing): @pytest.mark.usefixtures('clear_database_before_test') def test_reverse(listing): """Test the ``List.reverse()`` method.""" - node = List(list=listing) + node = List(listing) node.reverse() listing.reverse() assert node.get_list() == listing @@ -212,5 +220,5 @@ def test_reverse(listing): @pytest.mark.usefixtures('clear_database_before_test') def test_initialise_with_list_kwarg(listing): """Test that the ``List`` node can be initialized with the ``list`` keyword argument for backwards compatibility.""" - node = List(list=listing) + node = List(listing) assert node.get_list() == listing