Skip to content

Commit

Permalink
[4.4] Invalidate writers per database (#1039)
Browse files Browse the repository at this point in the history
* Invalidate writers per database

This should improve the performance of the driver in multi database use-cases.
The driver now only removes a server as a writer for a single database (before
for all databases) if that server returns an error that notifies the driver that
the server is no longer a writer (`Neo.ClientError.Cluster.NotALeader` or
`Neo.ClientError.General.ForbiddenOnReadOnlyDatabase`).

* Minor code clean-up

Co-authored-by: Antonio Barcélos <[email protected]>
  • Loading branch information
robsdedude and bigmontz authored Apr 12, 2024
1 parent 9f5c495 commit 0823655
Show file tree
Hide file tree
Showing 10 changed files with 539 additions and 37 deletions.
36 changes: 30 additions & 6 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@
log = getLogger("neo4j")


class ClientStateManagerBase(abc.ABC):
@abc.abstractmethod
def __init__(self, init_state, on_change=None):
...

@abc.abstractmethod
def transition(self, message):
...


class Bolt(abc.ABC):
""" Server connection for Bolt protocol.
Expand All @@ -125,6 +135,10 @@ class Bolt(abc.ABC):
# The socket
in_use = False

# The database name the connection was last used with
# (BEGIN for explicit transactions, RUN for auto-commit transactions)
last_database = None

# The socket
_closing = False
_closed = False
Expand Down Expand Up @@ -399,6 +413,10 @@ def __del__(self):
except OSError:
pass

@abc.abstractmethod
def _get_client_state_manager(self):
...

@abc.abstractmethod
def route(self, database=None, imp_user=None, bookmarks=None):
""" Fetch a routing table from the server for the given
Expand Down Expand Up @@ -504,6 +522,8 @@ def _append(self, signature, fields=(), response=None):
self.packer.pack_struct(signature, fields)
self.outbox.wrap_message()
self.responses.append(response)
if response:
self._get_client_state_manager().transition(response.message)

def _send_all(self):
with self.outbox.view() as data:
Expand Down Expand Up @@ -867,8 +887,10 @@ def deactivate(self, address):
if not self.connections[address]:
del self.connections[address]

def on_write_failure(self, address):
raise WriteServiceUnavailable("No write service available for pool {}".format(self))
def on_write_failure(self, address, database):
raise WriteServiceUnavailable(
"No write service available for pool {}".format(self)
)

def close(self):
""" Close all connections and empty the pool.
Expand Down Expand Up @@ -1342,13 +1364,15 @@ def deactivate(self, address):
log.debug("[#0000] C: <ROUTING> table=%r", self.routing_tables)
super(Neo4jPool, self).deactivate(address)

def on_write_failure(self, address):
def on_write_failure(self, address, database):
""" Remove a writer address from the routing table, if present.
"""
log.debug("[#0000] C: <ROUTING> Removing writer %r", address)
log.debug("[#0000] C: <ROUTING> Removing writer %r for database %r",
address, database)
with self.refresh_lock:
for database in self.routing_tables.keys():
self.routing_tables[database].writers.discard(address)
table = self.routing_tables.get(database)
if table is not None:
table.writers.discard(address)
log.debug("[#0000] C: <ROUTING> table=%r", self.routing_tables)


Expand Down
29 changes: 29 additions & 0 deletions neo4j/io/_bolt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import abc


class ClientStateManagerBase(abc.ABC):
@abc.abstractmethod
def __init__(self, init_state, on_change=None):
...

@abc.abstractmethod
def transition(self, message):
...
89 changes: 68 additions & 21 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Bolt,
check_supported_server_product,
)
from neo4j.io._bolt import ClientStateManagerBase
from neo4j.io._common import (
CommitResponse,
InitResponse,
Expand All @@ -55,7 +56,7 @@
log = getLogger("neo4j")


class ServerStates(Enum):
class BoltStates(Enum):
CONNECTED = "CONNECTED"
READY = "READY"
STREAMING = "STREAMING"
Expand All @@ -65,25 +66,25 @@ class ServerStates(Enum):

class ServerStateManager:
_STATE_TRANSITIONS = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
BoltStates.CONNECTED: {
"hello": BoltStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
BoltStates.READY: {
"run": BoltStates.STREAMING,
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
BoltStates.STREAMING: {
"pull": BoltStates.READY,
"discard": BoltStates.READY,
"reset": BoltStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
BoltStates.TX_READY_OR_TX_STREAMING: {
"commit": BoltStates.READY,
"rollback": BoltStates.READY,
"reset": BoltStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
BoltStates.FAILED: {
"reset": BoltStates.READY,
}
}

Expand All @@ -102,6 +103,39 @@ def transition(self, message, metadata):
self._on_change(state_before, self.state)


class ClientStateManager(ClientStateManagerBase):
_STATE_TRANSITIONS = {
BoltStates.CONNECTED: {
"hello": BoltStates.READY,
},
BoltStates.READY: {
"run": BoltStates.STREAMING,
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
},
BoltStates.STREAMING: {
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
"reset": BoltStates.READY,
},
BoltStates.TX_READY_OR_TX_STREAMING: {
"commit": BoltStates.READY,
"rollback": BoltStates.READY,
"reset": BoltStates.READY,
},
}

def __init__(self, init_state, on_change=None):
self.state = init_state
self._on_change = on_change

def transition(self, message):
state_before = self.state
self.state = self._STATE_TRANSITIONS \
.get(self.state, {}) \
.get(message, self.state)
if state_before != self.state and callable(self._on_change):
self._on_change(state_before, self.state)


class Bolt3(Bolt):
""" Protocol handler for Bolt 3.
Expand All @@ -117,13 +151,23 @@ class Bolt3(Bolt):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
BoltStates.CONNECTED, on_change=self._on_server_state_change
)
self._client_state_manager = ClientStateManager(
BoltStates.CONNECTED, on_change=self._on_client_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
log.debug("[#%04X] Server State: %s > %s", self.local_port,
old_state.name, new_state.name)

def _on_client_state_change(self, old_state, new_state):
log.debug("[#%04X] Client state: %s > %s",
self.local_port, old_state.name, new_state.name)

def _get_client_state_manager(self):
return self._client_state_manager

@property
def is_reset(self):
# We can't be sure of the server's state if there are still pending
Expand All @@ -132,7 +176,7 @@ def is_reset(self):
if (self.responses and self.responses[-1]
and self.responses[-1].message == "reset"):
return True
return self._server_state_manager.state == ServerStates.READY
return self._server_state_manager.state == BoltStates.READY

@property
def encrypted(self):
Expand Down Expand Up @@ -342,7 +386,7 @@ def fetch_message(self):
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
self._server_state_manager.state = BoltStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand All @@ -351,7 +395,10 @@ def fetch_message(self):
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
self.pool.on_write_failure(address=self.unresolved_address),
self.pool.on_write_failure(
address=self.unresolved_address,
database=self.last_database,
),
raise
except Neo4jError as e:
if self.pool and e.invalidates_all_connections():
Expand Down
34 changes: 28 additions & 6 deletions neo4j/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
tx_timeout_as_ms,
)
from neo4j.io._bolt3 import (
BoltStates,
ClientStateManager,
ServerStateManager,
ServerStates,
)


Expand All @@ -74,13 +75,23 @@ class Bolt4x0(Bolt):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
BoltStates.CONNECTED, on_change=self._on_server_state_change
)
self._client_state_manager = ClientStateManager(
BoltStates.CONNECTED, on_change=self._on_client_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
log.debug("[#%04X] Server state: %s > %s", self.local_port,
old_state.name, new_state.name)

def _on_client_state_change(self, old_state, new_state):
log.debug("[#%04X] Client state: %s > %s",
self.local_port, old_state.name, new_state.name)

def _get_client_state_manager(self):
return self._client_state_manager

@property
def is_reset(self):
# We can't be sure of the server's state if there are still pending
Expand All @@ -89,7 +100,7 @@ def is_reset(self):
if (self.responses and self.responses[-1]
and self.responses[-1].message == "reset"):
return True
return self._server_state_manager.state == ServerStates.READY
return self._server_state_manager.state == BoltStates.READY

@property
def encrypted(self):
Expand Down Expand Up @@ -169,6 +180,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
if db:
extra["db"] = db
client_state = self._client_state_manager.state
if client_state != BoltStates.TX_READY_OR_TX_STREAMING:
self.last_database = db
if bookmarks:
try:
extra["bookmarks"] = list(bookmarks)
Expand Down Expand Up @@ -217,6 +231,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
if db:
extra["db"] = db
self.last_database = db
if bookmarks:
try:
extra["bookmarks"] = list(bookmarks)
Expand Down Expand Up @@ -294,7 +309,7 @@ def fetch_message(self):
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
self._server_state_manager.state = BoltStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand All @@ -303,7 +318,10 @@ def fetch_message(self):
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
self.pool.on_write_failure(address=self.unresolved_address),
self.pool.on_write_failure(
address=self.unresolved_address,
database=self.last_database,
),
raise
except Neo4jError as e:
if self.pool and e.invalidates_all_connections():
Expand Down Expand Up @@ -471,6 +489,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["mode"] = "r"
if db:
extra["db"] = db
client_state = self._client_state_manager.state
if client_state != BoltStates.TX_READY_OR_TX_STREAMING:
self.last_database = db
if imp_user:
extra["imp_user"] = imp_user
if bookmarks:
Expand Down Expand Up @@ -502,6 +523,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["mode"] = "r"
if db:
extra["db"] = db
self.last_database = db
if imp_user:
extra["imp_user"] = imp_user
if bookmarks:
Expand Down
Loading

0 comments on commit 0823655

Please sign in to comment.