Skip to content

Commit

Permalink
Code clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdedude committed Aug 3, 2021
1 parent 315077f commit 74d346e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 61 deletions.
89 changes: 49 additions & 40 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,43 @@ class ServerStates(Enum):
FAILED = "FAILED"


STATE_TRANSITIONS = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
class ServerStateManager:
_STATE_TRANSITIONS = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
}
}
}

def __init__(self, init_state, on_change=None):
self.state = init_state
self._on_change = on_change

def transition(self, message, metadata):
if metadata.get("has_more"):
return
state_before = self.state
self.state = self._STATE_TRANSITIONS\
.get(self.state, {})\
.get(message, self.state)
if state_before != self.state and callable(self._on_change):
self._on_change(state_before, self.state)


class Bolt3(Bolt):
Expand All @@ -97,15 +112,23 @@ class Bolt3(Bolt):

supports_multiple_databases = False

_server_state = ServerStates.CONNECTED
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
old_state.name, new_state.name)

@property
def is_reset(self):
if self.responses:
# we can't be sure of the server's state as there are still pending
# responses.
return False
return self._server_state == ServerStates.READY
return self._server_state_manager.state == ServerStates.READY

@property
def encrypted(self):
Expand Down Expand Up @@ -213,7 +236,6 @@ def pull(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: PULL_ALL", self.local_port)
self._append(b"\x3F", (), Response(self, "pull", **handlers))
self._is_reset = False

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
if db is not None:
Expand All @@ -238,7 +260,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
self._is_reset = False

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
Expand All @@ -260,18 +281,6 @@ def fail(metadata):
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def _update_server_state_on_success(self, metadata, message):
if metadata.get("has_more"):
return
state_before = self._server_state
self._server_state = STATE_TRANSITIONS\
.get(self._server_state, {})\
.get(message, self._server_state)
if state_before != self._server_state:
log.debug("[#%04X] State: %s", self.local_port,
self._server_state.name)

def fetch_message(self):
""" Receive at most one message from the server, if available.
Expand Down Expand Up @@ -304,15 +313,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._update_server_state_on_success(summary_metadata,
response.message)
self._server_state_manager.transition(response.message,
summary_metadata)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state = ServerStates.FAILED
self._server_state_manager.state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down
35 changes: 14 additions & 21 deletions neo4j/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
Response,
)
from neo4j.io._bolt3 import (
ServerStateManager,
ServerStates,
STATE_TRANSITIONS,
)


Expand All @@ -69,15 +69,23 @@ class Bolt4x0(Bolt):

supports_multiple_databases = True

_server_state = ServerStates.CONNECTED
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
old_state.name, new_state.name)

@property
def is_reset(self):
if self.responses:
# we can't be sure of the server's state as there are still pending
# responses.
return False
return self._server_state == ServerStates.READY
return self._server_state_manager.state == ServerStates.READY

@property
def encrypted(self):
Expand Down Expand Up @@ -181,7 +189,6 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
**handlers))
else:
self._append(b"\x10", fields, Response(self, "run", **handlers))
self._is_reset = False

def discard(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
Expand All @@ -196,7 +203,6 @@ def pull(self, n=-1, qid=-1, **handlers):
extra["qid"] = qid
log.debug("[#%04X] C: PULL %r", self.local_port, extra)
self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))
self._is_reset = False

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
db=None, **handlers):
Expand All @@ -222,7 +228,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
self._is_reset = False

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
Expand All @@ -244,18 +249,6 @@ def fail(metadata):
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def _update_server_state_on_success(self, metadata, message):
if metadata.get("has_more"):
return
state_before = self._server_state
self._server_state = STATE_TRANSITIONS\
.get(self._server_state, {})\
.get(message, self._server_state)
if state_before != self._server_state:
log.debug("[#%04X] [%s]", self.local_port,
self._server_state.name)

def fetch_message(self):
""" Receive at most one message from the server, if available.
Expand Down Expand Up @@ -288,15 +281,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._update_server_state_on_success(summary_metadata,
response.message)
self._server_state_manager.transition(response.message,
summary_metadata)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state = ServerStates.FAILED
self._server_state_manager.state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down

0 comments on commit 74d346e

Please sign in to comment.