diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 24a65368c..0a9cf512d 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -61,28 +61,43 @@ class ServerStates(Enum): FAILED = "FAILED" -STATE_TRANSITIONS = { - ServerStates.CONNECTED: { - "hello": ServerStates.READY, - }, - ServerStates.READY: { - "run": ServerStates.STREAMING, - "begin": ServerStates.TX_READY_OR_TX_STREAMING, - }, - ServerStates.STREAMING: { - "pull": ServerStates.READY, - "discard": ServerStates.READY, - "reset": ServerStates.READY, - }, - ServerStates.TX_READY_OR_TX_STREAMING: { - "commit": ServerStates.READY, - "rollback": ServerStates.READY, - "reset": ServerStates.READY, - }, - ServerStates.FAILED: { - "reset": ServerStates.READY, +class ServerStateManager: + _STATE_TRANSITIONS = { + ServerStates.CONNECTED: { + "hello": ServerStates.READY, + }, + ServerStates.READY: { + "run": ServerStates.STREAMING, + "begin": ServerStates.TX_READY_OR_TX_STREAMING, + }, + ServerStates.STREAMING: { + "pull": ServerStates.READY, + "discard": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates.READY, + "rollback": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.FAILED: { + "reset": ServerStates.READY, + } } -} + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message, metadata): + if metadata.get("has_more"): + return + state_before = self.state + self.state = self._STATE_TRANSITIONS\ + .get(self.state, {})\ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) class Bolt3(Bolt): @@ -97,7 +112,15 @@ class Bolt3(Bolt): supports_multiple_databases = False - _server_state = ServerStates.CONNECTED + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) @property def is_reset(self): @@ -105,7 +128,7 @@ def is_reset(self): # we can't be sure of the server's state as there are still pending # responses. return False - return self._server_state == ServerStates.READY + return self._server_state_manager.state == ServerStates.READY @property def encrypted(self): @@ -213,7 +236,6 @@ def pull(self, n=-1, qid=-1, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append(b"\x3F", (), Response(self, "pull", **handlers)) - self._is_reset = False def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): if db is not None: @@ -238,7 +260,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, raise TypeError("Timeout must be specified as a number of seconds") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) - self._is_reset = False def commit(self, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) @@ -260,18 +281,6 @@ def fail(metadata): self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) self.send_all() self.fetch_all() - self._is_reset = True - - def _update_server_state_on_success(self, metadata, message): - if metadata.get("has_more"): - return - state_before = self._server_state - self._server_state = STATE_TRANSITIONS\ - .get(self._server_state, {})\ - .get(message, self._server_state) - if state_before != self._server_state: - log.debug("[#%04X] State: %s", self.local_port, - self._server_state.name) def fetch_message(self): """ Receive at most one message from the server, if available. @@ -304,15 +313,15 @@ def fetch_message(self): response.complete = True if summary_signature == b"\x70": log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) - self._update_server_state_on_success(summary_metadata, - response.message) + self._server_state_manager.transition(response.message, + summary_metadata) response.on_success(summary_metadata or {}) elif summary_signature == b"\x7E": log.debug("[#%04X] S: IGNORED", self.local_port) response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state = ServerStates.FAILED + self._server_state_manager.state = ServerStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index a9d476da2..d966a6972 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -49,8 +49,8 @@ Response, ) from neo4j.io._bolt3 import ( + ServerStateManager, ServerStates, - STATE_TRANSITIONS, ) @@ -69,7 +69,15 @@ class Bolt4x0(Bolt): supports_multiple_databases = True - _server_state = ServerStates.CONNECTED + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) @property def is_reset(self): @@ -77,7 +85,7 @@ def is_reset(self): # we can't be sure of the server's state as there are still pending # responses. return False - return self._server_state == ServerStates.READY + return self._server_state_manager.state == ServerStates.READY @property def encrypted(self): @@ -181,7 +189,6 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, **handlers)) else: self._append(b"\x10", fields, Response(self, "run", **handlers)) - self._is_reset = False def discard(self, n=-1, qid=-1, **handlers): extra = {"n": n} @@ -196,7 +203,6 @@ def pull(self, n=-1, qid=-1, **handlers): extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) - self._is_reset = False def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): @@ -222,7 +228,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, raise TypeError("Timeout must be specified as a number of seconds") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) - self._is_reset = False def commit(self, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) @@ -244,18 +249,6 @@ def fail(metadata): self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) self.send_all() self.fetch_all() - self._is_reset = True - - def _update_server_state_on_success(self, metadata, message): - if metadata.get("has_more"): - return - state_before = self._server_state - self._server_state = STATE_TRANSITIONS\ - .get(self._server_state, {})\ - .get(message, self._server_state) - if state_before != self._server_state: - log.debug("[#%04X] [%s]", self.local_port, - self._server_state.name) def fetch_message(self): """ Receive at most one message from the server, if available. @@ -288,15 +281,15 @@ def fetch_message(self): response.complete = True if summary_signature == b"\x70": log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) - self._update_server_state_on_success(summary_metadata, - response.message) + self._server_state_manager.transition(response.message, + summary_metadata) response.on_success(summary_metadata or {}) elif summary_signature == b"\x7E": log.debug("[#%04X] S: IGNORED", self.local_port) response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state = ServerStates.FAILED + self._server_state_manager.state = ServerStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable):