From 97b74ee0481baf40a4dc48d1888f73bd8f9add4d Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 11 Apr 2024 10:25:48 +0200 Subject: [PATCH 1/3] Fix pool closing connections too aggressively Whenever a new routing table was fetched, the pool would close all connections to servers that were not part of the routing table. However, it might well be, that a missing server is present still in the routing table for another database. Hence, the pool now checks the routing tables for all databases before deciding which connections are no longer needed.g --- neo4j/io/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 7b8c63ee1..1e44f5484 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -1213,7 +1213,13 @@ def update_routing_table(self, *, database, imp_user, bookmarks, raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): - servers = self.get_or_create_routing_table(database).servers() + with self.refresh_lock: + routing_tables = [self.get_or_create_routing_table(database)] + for db in self.routing_tables.keys(): + if db == database: + continue + routing_tables.append(self.routing_tables[db]) + servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address.unresolved not in servers: super(Neo4jPool, self).deactivate(address) From 839b7ca4b4464fa7dc11031b0168b8a7c5f22f46 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 12 Apr 2024 12:57:06 +0200 Subject: [PATCH 2/3] Backport unit tests --- tests/unit/io/test_neo4j_pool.py | 293 +++++++++++++++++++--------- tests/unit/work/__init__.py | 2 +- tests/unit/work/_fake_connection.py | 169 ++++++++-------- tests/unit/work/test_session.py | 6 +- 4 files changed, 297 insertions(+), 173 deletions(-) diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index 730875349..ffbca0702 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -2,15 +2,13 @@ # -*- encoding: utf-8 -*- # Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. +# Neo4j Sweden AB [https://neo4j.com] # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,62 +17,101 @@ # limitations under the License. -from unittest.mock import Mock +import inspect import pytest -from ..work import FakeConnection - from neo4j import ( READ_ACCESS, WRITE_ACCESS, ) +from neo4j._deadline import Deadline from neo4j.addressing import ResolvedAddress from neo4j.conf import ( PoolConfig, RoutingConfig, - WorkspaceConfig + WorkspaceConfig, ) -from neo4j._deadline import Deadline from neo4j.exceptions import ( ServiceUnavailable, - SessionExpired + SessionExpired, ) from neo4j.io import Neo4jPool - -ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") -READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") -WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") - - -@pytest.fixture() -def opener(): - def open_(addr, timeout): - connection = FakeConnection() - connection.addr = addr - connection.timeout = timeout - route_mock = Mock() - route_mock.return_value = [{ - "ttl": 1000, - "servers": [ - {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, - {"addresses": [str(READER_ADDRESS)], "role": "READ"}, - {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, - ], - }] - connection.attach_mock(route_mock, "route") - opener_.connections.append(connection) - return connection - - opener_ = Mock() - opener_.connections = [] - opener_.side_effect = open_ - return opener_ +from ..work import fake_connection_generator + + +ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") +ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") +READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host") +READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host") +WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") + + +@pytest.fixture +def custom_routing_opener(fake_connection_generator, mocker): + def make_opener(failures=None, get_readers=None): + def routing_side_effect(*args, **kwargs): + nonlocal failures + res = next(failures, None) + if res is None: + if get_readers is not None: + readers = get_readers(kwargs.get("database", args[0])) + else: + readers = [str(READER1_ADDRESS)] + return [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER1_ADDRESS), + str(ROUTER2_ADDRESS), + str(ROUTER3_ADDRESS)], + "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"}, + ], + }] + raise res + + def open_(addr, deadline): + connection = fake_connection_generator() + connection.unresolved_address = addr + connection.deadline = deadline + route_mock = mocker.MagicMock() + + route_mock.side_effect = routing_side_effect + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + failures = iter(failures or []) + opener_ = mocker.MagicMock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + return make_opener + + +@pytest.fixture +def opener(custom_routing_opener): + return custom_routing_opener() + + +def _pool_config(): + pool_config = PoolConfig() + return pool_config + + +def _simple_pool(opener) -> Neo4jPool: + return Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) def test_acquires_new_routing_table_if_deleted(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -87,7 +124,7 @@ def test_acquires_new_routing_table_if_deleted(opener): def test_acquires_new_routing_table_if_stale(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -101,7 +138,7 @@ def test_acquires_new_routing_table_if_stale(opener): def test_removes_old_routing_table(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, 60, "test_db1", None) pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -122,18 +159,18 @@ def test_removes_old_routing_table(opener): @pytest.mark.parametrize("type_", ("r", "w")) def test_chooses_right_connection_type(opener, type_): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, 60, "test_db", None) pool.release(cx1) if type_ == "r": - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS else: - assert cx1.addr == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER1_ADDRESS def test_reuses_connection(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) @@ -143,17 +180,19 @@ def test_reuses_connection(opener): @pytest.mark.parametrize("break_on_close", (True, False)) def test_closes_stale_connections(opener, break_on_close): def break_connection(): - pool.deactivate(cx1.addr) + pool.deactivate(cx1.unresolved_address) if cx_close_mock_side_effect: - cx_close_mock_side_effect() + res = cx_close_mock_side_effect() + if inspect.isawaitable(res): + return res - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) and then breaking when - # the pool tries to close the connection + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. exceeding idle timeout) and then + # breaking when the pool tries to close the connection cx1.stale.return_value = True cx_close_mock = cx1.close if break_on_close: @@ -166,24 +205,25 @@ def break_connection(): else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] def test_does_not_close_stale_connections_in_use(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) while being in use + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. exceeding idle timeout) while being + # in use cx1.stale.return_value = True cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] pool.release(cx1) # now that cx1 is back in the pool and still stale, @@ -194,13 +234,13 @@ def test_does_not_close_stale_connections_in_use(opener): pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx2.addr] + assert cx3.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx2.unresolved_address] def test_release_resets_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() @@ -210,40 +250,41 @@ def test_release_resets_connections(opener): def test_release_does_not_resets_closed_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() pool.release(cx1) cx1.closed.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() def test_release_does_not_resets_defunct_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() pool.release(cx1) cx1.defunct.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() -def test_multiple_broken_connections_on_close(opener): +def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): def close_side_effect(): cx.closed.return_value = True cx.defunct.return_value = True - pool.deactivate(READER_ADDRESS) + pool.deactivate(READER1_ADDRESS) - cx.attach_mock(Mock(side_effect=close_side_effect), "close") + cx.attach_mock(mocker.MagicMock(side_effect=close_side_effect), + "close") # create pool with 2 idle connections - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) @@ -264,37 +305,115 @@ def close_side_effect(): def test_failing_opener_leaves_connections_in_use_alone(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): pool.acquire(READ_ACCESS, 30, 60, "test_db", None) - assert not cx1.closed() def test__acquire_new_later_with_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 1 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 1 assert callable(creator) def test__acquire_new_later_without_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) _ = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) # pool is full now - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 0 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 0 assert creator is None + + +def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): + readers = {"db1": [str(READER1_ADDRESS)]} + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "db1", None) + assert cx1.unresolved_address == READER1_ADDRESS + pool.release(cx1) + + cx1.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + + # force RT refresh, returning a different reader + del pool.routing_tables["db1"] + readers["db1"] = [str(READER2_ADDRESS)] + + cx2 = pool.acquire(READ_ACCESS, 30, 60, "db1", None) + assert cx2.unresolved_address == READER2_ADDRESS + + cx1.close.assert_called_once() + assert len(pool.connections[READER1_ADDRESS]) == 0 + + pool.release(cx2) + assert len(pool.connections[READER2_ADDRESS]) == 1 + + +def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( + custom_routing_opener +): + readers = { + "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], + "db2": [str(READER1_ADDRESS)] + } + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "db1", None) + pool.release(cx1) + assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS) + reader1_connection_count = len(pool.connections[READER1_ADDRESS]) + reader2_connection_count = len(pool.connections[READER2_ADDRESS]) + assert reader1_connection_count + reader2_connection_count == 1 + + cx2 = pool.acquire(READ_ACCESS, 30, 60, "db2", None) + pool.release(cx2) + assert cx2.unresolved_address == READER1_ADDRESS + cx1.close.assert_not_called() + cx2.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + + # force RT refresh, returning a different reader + del pool.routing_tables["db2"] + readers["db2"] = [str(READER3_ADDRESS)] + + cx3 = pool.acquire(READ_ACCESS, 30, 60, "db2", None) + pool.release(cx3) + assert cx3.unresolved_address == READER3_ADDRESS + + cx1.close.assert_not_called() + cx2.close.assert_not_called() + cx3.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + assert len(pool.connections[READER3_ADDRESS]) == 1 diff --git a/tests/unit/work/__init__.py b/tests/unit/work/__init__.py index 238e61d3f..46d89dabe 100644 --- a/tests/unit/work/__init__.py +++ b/tests/unit/work/__init__.py @@ -19,6 +19,6 @@ # limitations under the License. from ._fake_connection import ( - FakeConnection, + fake_connection_generator, fake_connection, ) diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py index 9bc6815ad..1360139f6 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/work/_fake_connection.py @@ -20,7 +20,6 @@ import inspect -from unittest import mock import pytest @@ -28,87 +27,93 @@ from neo4j._deadline import Deadline -class FakeConnection(mock.NonCallableMagicMock): - callbacks = [] - server_info = ServerInfo("127.0.0.1", (4, 3)) - local_port = 1234 - bolt_patches = set() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") - self.attach_mock(mock.Mock(return_value=False), "defunct") - self.attach_mock(mock.Mock(return_value=False), "stale") - self.attach_mock(mock.Mock(return_value=False), "closed") - self.attach_mock(mock.Mock(return_value=False), "socket") - self.socket.attach_mock( - mock.Mock(return_value=None), "get_deadline" - ) - - def set_deadline_side_effect(deadline): - deadline = Deadline.from_timeout_or_deadline(deadline) - self.socket.get_deadline.return_value = deadline - - self.socket.attach_mock( - mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" - ) - - def close_side_effect(): - self.closed.return_value = True - - self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") - - @property - def is_reset(self): - if self.closed.return_value or self.defunct.return_value: - raise AssertionError("is_reset should not be called on a closed or " - "defunct connection.") - return self.is_reset_mock() - - def fetch_message(self, *args, **kwargs): - if self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_message")(*args, **kwargs) - - def fetch_all(self, *args, **kwargs): - while self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_all")(*args, **kwargs) - - def __getattr__(self, name): - parent = super() - - def build_message_handler(name): - def func(*args, **kwargs): - def callback(): - for cb_name, param_count in ( - ("on_success", 1), - ("on_summary", 0) - ): - cb = kwargs.get(cb_name, None) - if callable(cb): - try: - param_count = \ - len(inspect.signature(cb).parameters) - except ValueError: - # e.g. built-in method as cb - pass - if param_count == 1: - cb({}) - else: - cb() - self.callbacks.append(callback) - - return func - - method_mock = parent.__getattr__(name) - if name in ("run", "commit", "pull", "rollback", "discard"): - method_mock.side_effect = build_message_handler(name) - return method_mock +@pytest.fixture +def fake_connection_generator(session_mocker): + mock = session_mocker.mock_module + + class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + local_port = 1234 + bolt_patches = set() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(mock.Mock(return_value=False), "socket") + self.socket.attach_mock( + mock.Mock(return_value=None), "get_deadline" + ) + + def set_deadline_side_effect(deadline): + deadline = Deadline.from_timeout_or_deadline(deadline) + self.socket.get_deadline.return_value = deadline + + self.socket.attach_mock( + mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" + ) + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError("is_reset should not be called on a closed or " + "defunct connection.") + return self.is_reset_mock() + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + cb({}) + else: + cb() + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return FakeConnection @pytest.fixture -def fake_connection(): - return FakeConnection() +def fake_connection(fake_connection_generator): + return fake_connection_generator() diff --git a/tests/unit/work/test_session.py b/tests/unit/work/test_session.py index d43048699..b3004bf6b 100644 --- a/tests/unit/work/test_session.py +++ b/tests/unit/work/test_session.py @@ -30,17 +30,17 @@ ) from neo4j.io import IOPool -from ._fake_connection import FakeConnection +from ._fake_connection import fake_connection_generator @pytest.fixture -def pool(mocker): +def pool(mocker, fake_connection_generator): pool = mocker.Mock(spec=IOPool) assert not hasattr(pool, "acquired_connection_mocks") pool.acquired_connection_mocks = [] def acquire_side_effect(*_, **__): - connection = FakeConnection() + connection = fake_connection_generator() pool.acquired_connection_mocks.append(connection) return connection From 27690b059c949ff93a77185d2f6bbe4d720948f5 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 12 Apr 2024 15:22:56 +0200 Subject: [PATCH 3/3] Tests: fix fixture discovery --- tests/unit/conftest.py | 19 +++++++++++++++ tests/unit/fixtures/__init__.py | 19 +++++++++++++++ .../{work => fixtures}/_fake_connection.py | 9 ++++--- tests/unit/io/test__common.py | 2 -- tests/unit/io/test_neo4j_pool.py | 2 -- tests/unit/work/__init__.py | 24 ------------------- tests/unit/work/test_session.py | 2 -- tests/unit/work/test_transaction.py | 2 -- 8 files changed, 44 insertions(+), 35 deletions(-) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/fixtures/__init__.py rename tests/unit/{work => fixtures}/_fake_connection.py (98%) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..c03179758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,19 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .fixtures import * # necessary for pytest to discover the fixtures diff --git a/tests/unit/fixtures/__init__.py b/tests/unit/fixtures/__init__.py new file mode 100644 index 000000000..f17c825e2 --- /dev/null +++ b/tests/unit/fixtures/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._fake_connection import * diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/fixtures/_fake_connection.py similarity index 98% rename from tests/unit/work/_fake_connection.py rename to tests/unit/fixtures/_fake_connection.py index 1360139f6..36d610c0c 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/fixtures/_fake_connection.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -27,6 +24,12 @@ from neo4j._deadline import Deadline +__all__ = [ + "fake_connection", + "fake_connection_generator", +] + + @pytest.fixture def fake_connection_generator(session_mocker): mock = session_mocker.mock_module diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py index 72f9ee921..b70357c17 100644 --- a/tests/unit/io/test__common.py +++ b/tests/unit/io/test__common.py @@ -25,8 +25,6 @@ ResetResponse, ) -from ..work import fake_connection - @pytest.mark.parametrize(("chunk_size", "data", "result"), ( ( diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index ffbca0702..4d2eff54b 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -38,8 +38,6 @@ ) from neo4j.io import Neo4jPool -from ..work import fake_connection_generator - ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") diff --git a/tests/unit/work/__init__.py b/tests/unit/work/__init__.py index 46d89dabe..e69de29bb 100644 --- a/tests/unit/work/__init__.py +++ b/tests/unit/work/__init__.py @@ -1,24 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._fake_connection import ( - fake_connection_generator, - fake_connection, -) diff --git a/tests/unit/work/test_session.py b/tests/unit/work/test_session.py index b3004bf6b..64a9f494b 100644 --- a/tests/unit/work/test_session.py +++ b/tests/unit/work/test_session.py @@ -30,8 +30,6 @@ ) from neo4j.io import IOPool -from ._fake_connection import fake_connection_generator - @pytest.fixture def pool(mocker, fake_connection_generator): diff --git a/tests/unit/work/test_transaction.py b/tests/unit/work/test_transaction.py index 06e755662..3ac206ce2 100644 --- a/tests/unit/work/test_transaction.py +++ b/tests/unit/work/test_transaction.py @@ -31,8 +31,6 @@ Transaction, ) -from ._fake_connection import fake_connection - @pytest.mark.parametrize(("explicit_commit", "close"), ( (False, False),