Skip to content

Commit

Permalink
Don't send RESET on READY (clean) connections (#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdedude authored Aug 5, 2021
1 parent 6679891 commit cad3718
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 39 deletions.
6 changes: 1 addition & 5 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class Bolt(abc.ABC):
PROTOCOL_VERSION = None

# flag if connection needs RESET to go back to READY state
_is_reset = True
is_reset = False

# The socket
in_use = False
Expand Down Expand Up @@ -460,10 +460,6 @@ def rollback(self, **handlers):
""" Appends a ROLLBACK message to the output queue."""
pass

@property
def is_reset(self):
return self._is_reset

@abc.abstractmethod
def reset(self):
""" Appends a RESET message to the outgoing queue, sends it and consumes
Expand Down
94 changes: 81 additions & 13 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from logging import getLogger
from ssl import SSLSocket

Expand Down Expand Up @@ -52,6 +53,53 @@
log = getLogger("neo4j")


class ServerStates(Enum):
CONNECTED = "CONNECTED"
READY = "READY"
STREAMING = "STREAMING"
TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING"
FAILED = "FAILED"


class ServerStateManager:
_STATE_TRANSITIONS = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
}
}

def __init__(self, init_state, on_change=None):
self.state = init_state
self._on_change = on_change

def transition(self, message, metadata):
if metadata.get("has_more"):
return
state_before = self.state
self.state = self._STATE_TRANSITIONS\
.get(self.state, {})\
.get(message, self.state)
if state_before != self.state and callable(self._on_change):
self._on_change(state_before, self.state)


class Bolt3(Bolt):
""" Protocol handler for Bolt 3.
Expand All @@ -64,6 +112,25 @@ class Bolt3(Bolt):

supports_multiple_databases = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
old_state.name, new_state.name)

@property
def is_reset(self):
if self.responses:
# We can't be sure of the server's state as there are still pending
# responses. Unless the last message we sent was RESET. In that case
# the server state will always be READY when we're done.
return self.responses[-1].message == "reset"
return self._server_state_manager.state == ServerStates.READY

@property
def encrypted(self):
return isinstance(self.socket, SSLSocket)
Expand Down Expand Up @@ -92,7 +159,8 @@ def hello(self):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=self.server_info.update))
response=InitResponse(self, "hello",
on_success=self.server_info.update))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down Expand Up @@ -155,21 +223,20 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
self._append(b"\x10", fields, CommitResponse(self, **handlers))
self._append(b"\x10", fields, CommitResponse(self, "run",
**handlers))
else:
self._append(b"\x10", fields, Response(self, **handlers))
self._is_reset = False
self._append(b"\x10", fields, Response(self, "run", **handlers))

def discard(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
self._append(b"\x2F", (), Response(self, **handlers))
self._append(b"\x2F", (), Response(self, "discard", **handlers))

def pull(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: PULL_ALL", self.local_port)
self._append(b"\x3F", (), Response(self, **handlers))
self._is_reset = False
self._append(b"\x3F", (), Response(self, "pull", **handlers))

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
if db is not None:
Expand All @@ -193,16 +260,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, **handlers))
self._is_reset = False
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
self._append(b"\x12", (), CommitResponse(self, **handlers))
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

def rollback(self, **handlers):
log.debug("[#%04X] C: ROLLBACK", self.local_port)
self._append(b"\x13", (), Response(self, **handlers))
self._append(b"\x13", (), Response(self, "rollback", **handlers))

def reset(self):
""" Add a RESET message to the outgoing queue, send
Expand All @@ -213,10 +279,9 @@ def fail(metadata):
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

log.debug("[#%04X] C: RESET", self.local_port)
self._append(b"\x0F", response=Response(self, on_failure=fail))
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def fetch_message(self):
""" Receive at most one message from the server, if available.
Expand Down Expand Up @@ -249,12 +314,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._server_state_manager.transition(response.message,
summary_metadata)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down
59 changes: 43 additions & 16 deletions neo4j/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from logging import getLogger
from ssl import SSLSocket

Expand All @@ -37,7 +38,6 @@
Neo4jError,
NotALeader,
ServiceUnavailable,
SessionExpired,
)
from neo4j.io import (
Bolt,
Expand All @@ -48,6 +48,10 @@
InitResponse,
Response,
)
from neo4j.io._bolt3 import (
ServerStateManager,
ServerStates,
)


log = getLogger("neo4j")
Expand All @@ -65,6 +69,25 @@ class Bolt4x0(Bolt):

supports_multiple_databases = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
old_state.name, new_state.name)

@property
def is_reset(self):
if self.responses:
# We can't be sure of the server's state as there are still pending
# responses. Unless the last message we sent was RESET. In that case
# the server state will always be READY when we're done.
return self.responses[-1].message == "reset"
return self._server_state_manager.state == ServerStates.READY

@property
def encrypted(self):
return isinstance(self.socket, SSLSocket)
Expand Down Expand Up @@ -93,7 +116,8 @@ def hello(self):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=self.server_info.update))
response=InitResponse(self, "hello",
on_success=self.server_info.update))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down Expand Up @@ -162,25 +186,24 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
self._append(b"\x10", fields, CommitResponse(self, **handlers))
self._append(b"\x10", fields, CommitResponse(self, "run",
**handlers))
else:
self._append(b"\x10", fields, Response(self, **handlers))
self._is_reset = False
self._append(b"\x10", fields, Response(self, "run", **handlers))

def discard(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
if qid != -1:
extra["qid"] = qid
log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
self._append(b"\x2F", (extra,), Response(self, **handlers))
self._append(b"\x2F", (extra,), Response(self, "discard", **handlers))

def pull(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
if qid != -1:
extra["qid"] = qid
log.debug("[#%04X] C: PULL %r", self.local_port, extra)
self._append(b"\x3F", (extra,), Response(self, **handlers))
self._is_reset = False
self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
db=None, **handlers):
Expand All @@ -205,16 +228,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, **handlers))
self._is_reset = False
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
self._append(b"\x12", (), CommitResponse(self, **handlers))
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

def rollback(self, **handlers):
log.debug("[#%04X] C: ROLLBACK", self.local_port)
self._append(b"\x13", (), Response(self, **handlers))
self._append(b"\x13", (), Response(self, "rollback", **handlers))

def reset(self):
""" Add a RESET message to the outgoing queue, send
Expand All @@ -225,10 +247,9 @@ def fail(metadata):
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)

log.debug("[#%04X] C: RESET", self.local_port)
self._append(b"\x0F", response=Response(self, on_failure=fail))
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def fetch_message(self):
""" Receive at most one message from the server, if available.
Expand Down Expand Up @@ -261,12 +282,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._server_state_manager.transition(response.message,
summary_metadata)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down Expand Up @@ -372,7 +396,9 @@ def fail(md):
else:
bookmarks = list(bookmarks)
self._append(b"\x66", (routing_context, bookmarks, database),
response=Response(self, on_success=metadata.update, on_failure=fail))
response=Response(self, "route",
on_success=metadata.update,
on_failure=fail))
self.send_all()
self.fetch_all()
return [metadata.get("rt")]
Expand Down Expand Up @@ -400,7 +426,8 @@ def on_success(metadata):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=on_success))
response=InitResponse(self, "hello",
on_success=on_success))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down
3 changes: 2 additions & 1 deletion neo4j/io/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ class Response:
more detail messages followed by one summary message).
"""

def __init__(self, connection, **handlers):
def __init__(self, connection, message, **handlers):
self.connection = connection
self.handlers = handlers
self.message = message
self.complete = False

def on_records(self, records):
Expand Down
6 changes: 2 additions & 4 deletions testkitbackend/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@
"stub.session_run_parameters.test_session_run_parameters.TestSessionRunParameters.test_empty_query":
"Driver rejects empty queries before sending it to the server",
"tls.tlsversions.TestTlsVersions.test_1_1":
"TLSv1.1 and below are disabled in the driver",
"stub.disconnects.test_disconnects.TestDisconnects.test_fail_on_reset":
"Driver silently ignores all errors on releasing connections back into the pool."
"TLSv1.1 and below are disabled in the driver"
},
"features": {
"AuthorizationExpiredTreatment": true,
"Optimization:ImplicitDefaultArguments": true,
"Optimization:MinimalResets": "Driver resets some clean connections when put back into pool",
"Optimization:MinimalResets": true,
"Optimization:ConnectionReuse": true,
"Optimization:PullPipelining": true,
"ConfHint:connection.recv_timeout_seconds": true,
Expand Down

0 comments on commit cad3718

Please sign in to comment.