Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[4.4] Warnings overhaul #789

Merged
merged 10 commits into from
Sep 5, 2022
4 changes: 0 additions & 4 deletions neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,11 @@ def session(self, **config):
"""
from neo4j.work.simple import Session
session_config = SessionConfig(self._default_workspace_config, config)
SessionConfig.consume(config) # Consume the config
return Session(self._pool, session_config)

def pipeline(self, **config):
from neo4j.work.pipelining import Pipeline, PipelineConfig
pipeline_config = PipelineConfig(self._default_workspace_config, config)
PipelineConfig.consume(config) # Consume the config
return Pipeline(self._pool, pipeline_config)

@experimental("The configuration may change in the future.")
Expand Down Expand Up @@ -427,13 +425,11 @@ def __init__(self, pool, default_workspace_config):

def session(self, **config):
session_config = SessionConfig(self._default_workspace_config, config)
SessionConfig.consume(config) # Consume the config
return Session(self._pool, session_config)

def pipeline(self, **config):
from neo4j.work.pipelining import Pipeline, PipelineConfig
pipeline_config = PipelineConfig(self._default_workspace_config, config)
PipelineConfig.consume(config) # Consume the config
return Pipeline(self._pool, pipeline_config)

@experimental("The configuration may change in the future.")
Expand Down
12 changes: 11 additions & 1 deletion neo4j/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,19 @@ def set_attr(k, v):
else:
raise AttributeError(k)

rejected_keys = []
for key, value in data_dict.items():
if value is not None:
set_attr(key, value)
try:
set_attr(key, value)
except AttributeError as exc:
if not exc.args == (key,):
raise
rejected_keys.append(key)

if rejected_keys:
raise ConfigurationError("Unexpected config keys: "
+ ", ".join(rejected_keys))

def __init__(self, *args, **kwargs):
for arg in args:
Expand Down
22 changes: 13 additions & 9 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,18 +277,19 @@ def get_handshake(cls):
return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")

@classmethod
def ping(cls, address, *, timeout=None, **config):
def ping(cls, address, *, timeout=None, pool_config=None):
""" Attempt to establish a Bolt connection, returning the
agreed Bolt protocol version if successful.
"""
config = PoolConfig.consume(config)
if pool_config is None:
pool_config = PoolConfig()
try:
s, protocol_version, handshake, data = BoltSocket.connect(
address,
timeout=timeout,
custom_resolver=config.resolver,
ssl_context=config.get_ssl_context(),
keep_alive=config.keep_alive,
custom_resolver=pool_config.resolver,
ssl_context=pool_config.get_ssl_context(),
keep_alive=pool_config.keep_alive,
)
except (ServiceUnavailable, SessionExpired, BoltHandshakeError):
return None
Expand All @@ -297,7 +298,8 @@ def ping(cls, address, *, timeout=None, **config):
return protocol_version

@classmethod
def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config):
def open(cls, address, *, auth=None, timeout=None, routing_context=None,
pool_config=None):
""" Open a new Bolt connection to a given server address.

:param address:
Expand All @@ -316,7 +318,8 @@ def time_remaining():
return t if t > 0 else 0

t0 = perf_counter()
pool_config = PoolConfig.consume(pool_config)
if pool_config is None:
pool_config = PoolConfig()

socket_connection_timeout = pool_config.connection_timeout
if socket_connection_timeout is None:
Expand Down Expand Up @@ -906,7 +909,7 @@ def open(cls, address, *, auth, pool_config, workspace_config):
def opener(addr, timeout):
return Bolt.open(
addr, auth=auth, timeout=timeout, routing_context=None,
**pool_config
pool_config=pool_config
)

pool = cls(opener, pool_config, workspace_config, address)
Expand Down Expand Up @@ -951,7 +954,8 @@ def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=N

def opener(addr, timeout):
return Bolt.open(addr, auth=auth, timeout=timeout,
routing_context=routing_context, **pool_config)
routing_context=routing_context,
pool_config=pool_config)

pool = cls(opener, pool_config, workspace_config, address)
return pool
Expand Down
2 changes: 1 addition & 1 deletion neo4j/io/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _handshake(cls, s, resolved_address):
def close_socket(cls, socket_):
try:
if isinstance(socket_, BoltSocket):
socket.close()
socket_.close()
else:
socket_.shutdown(SHUT_RDWR)
socket_.close()
Expand Down
10 changes: 5 additions & 5 deletions neo4j/time/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def __mul__(self, other):
:rtype: Duration
"""
if isinstance(other, float):
deprecation_warn("Multiplication with float will be deprecated in "
deprecation_warn("Multiplication with float will be removed in "
"5.0.")
if isinstance(other, (int, float)):
return Duration(
Expand Down Expand Up @@ -1627,7 +1627,7 @@ def from_clock_time(cls, clock_time, epoch):
ts = clock_time.seconds % 86400
nanoseconds = int(NANO_SECONDS * ts + clock_time.nanoseconds)
ticks = (epoch.time().ticks_ns + nanoseconds) % (86400 * NANO_SECONDS)
return Time.from_ticks_ns(ticks)
return cls.from_ticks_ns(ticks)

@classmethod
def __normalize_hour(cls, hour):
Expand Down Expand Up @@ -1657,8 +1657,8 @@ def __normalize_nanosecond(cls, hour, minute, second, nanosecond):
# TODO 5.0: remove -----------------------------------------------------
seconds, extra_ns = divmod(second, 1)
if extra_ns:
deprecation_warn("Float support second will be removed in 5.0. "
"Use `nanosecond` instead.")
deprecation_warn("Float support for `second` will be removed in "
"5.0. Use `nanosecond` instead.")
# ----------------------------------------------------------------------
hour, minute, second = cls.__normalize_second(hour, minute, second)
nanosecond = int(nanosecond
Expand Down Expand Up @@ -1753,7 +1753,7 @@ def nanosecond(self):
return self.__nanosecond

@property
@deprecated("hour_minute_second will be removed in 5.0. "
@deprecated("`hour_minute_second` will be removed in 5.0. "
"Use `hour_minute_second_nanosecond` instead.")
def hour_minute_second(self):
"""The time as a tuple of (hour, minute, second).
Expand Down
2 changes: 1 addition & 1 deletion testkit/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev \
libsqlite3-dev wget curl llvm libncurses5-dev xz-utils tk-dev \
libxml2-dev libxmlsec1-dev libffi-dev \
libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev \
ca-certificates && \
apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

Expand Down
2 changes: 1 addition & 1 deletion testkit/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

if __name__ == "__main__":
subprocess.check_call(
["python", "-m", "testkitbackend"],
["python", "-W", "error", "-m", "testkitbackend"],
stdout=sys.stdout, stderr=sys.stderr
)
4 changes: 1 addition & 3 deletions testkit/stress.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import subprocess
import os
import sys


if __name__ == "__main__":
# Until below works
sys.exit(0)
Expand All @@ -16,4 +14,4 @@
"NEO4J_PASSWORD": os.environ["TEST_NEO4J_PASS"],
"NEO4J_URI": uri}
subprocess.check_call(cmd, universal_newlines=True,
stderr=subprocess.STDOUT, env=env)
stdout=sys.stdout, stderr=sys.stderr, env=env)
5 changes: 4 additions & 1 deletion testkit/unittests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import subprocess
import sys


def run(args):
subprocess.run(
args, universal_newlines=True, stderr=subprocess.STDOUT, check=True)
args, universal_newlines=True, stdout=sys.stdout, stderr=sys.stderr,
check=True
)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions testkitbackend/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import warnings

from .server import Server


if __name__ == "__main__":
warnings.simplefilter("error")
server = Server(("0.0.0.0", 9876))
while True:
server.handle_request()
64 changes: 64 additions & 0 deletions testkitbackend/_warning_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re
import warnings
from contextlib import contextmanager


@contextmanager
def warning_check(category, message):
with warnings.catch_warnings(record=True) as warn_log:
warnings.filterwarnings("always", category=category, message=message)
yield
if len(warn_log) != 1:
raise AssertionError("Expected 1 warning, found %d: %s"
% (len(warn_log), warn_log))


@contextmanager
def warnings_check(category_message_pairs):
with warnings.catch_warnings(record=True) as warn_log:
for category, message in category_message_pairs:
warnings.filterwarnings("always", category=category,
message=message)
yield
if len(warn_log) != len(category_message_pairs):
raise AssertionError(
"Expected %d warnings, found %d: %s"
% (len(category_message_pairs), len(warn_log), warn_log)
)
category_message_pairs = [
(category, re.compile(message, re.I))
for category, message in category_message_pairs
]
for category, matcher in category_message_pairs:
match = None
for i, warning in enumerate(warn_log):
if (
warning.category == category
and matcher.match(warning.message.args[0])
):
match = i
break
if match is None:
raise AssertionError(
"Expected warning not found: %r %r"
% (category, matcher.pattern)
)
warn_log.pop(match)
17 changes: 17 additions & 0 deletions testkitbackend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,23 @@ def __init__(self, rd, wr):
self._requestHandlers = dict(
[m for m in getmembers(requests, isfunction)])

def close(self):
for dict_of_closables in (
self.transactions,
{key: tracker.session for key, tracker in self.sessions.items()},
self.drivers,
):
for key, closable in dict_of_closables.items():
try:
closable.close()
except (Neo4jError, DriverError, OSError):
log.error(
"Error during TestKit backend garbage collection. "
"While collecting: (key: %s) %s\n%s",
key, closable, traceback.format_exc()
)
dict_of_closables.clear()

def next_key(self):
self.key = self.key + 1
return self.key
Expand Down
44 changes: 36 additions & 8 deletions testkitbackend/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
from os import path

Expand All @@ -22,6 +24,11 @@
import testkitbackend.totestkit as totestkit
from testkitbackend.fromtestkit import to_meta_and_timeout

from ._warning_check import (
warning_check,
warnings_check,
)


class FrontendError(Exception):
pass
Expand Down Expand Up @@ -97,9 +104,21 @@ def NewDriver(backend, data):
data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None)

data.mark_item_as_read("domainNameResolverRegistered")
driver = neo4j.GraphDatabase.driver(
data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs
)
expected_warnings = []
if "update_routing_table_timeout" in kwargs:
expected_warnings.append((
DeprecationWarning,
"The 'update_routing_table_timeout' config key is deprecated"
))
if "session_connection_timeout" in kwargs:
expected_warnings.append((
DeprecationWarning,
"The 'session_connection_timeout' config key is deprecated"
))
with warnings_check(expected_warnings):
driver = neo4j.GraphDatabase.driver(
data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs
)
key = backend.next_key()
backend.drivers[key] = driver
backend.send_response("Driver", {"id": key})
Expand All @@ -108,17 +127,26 @@ def NewDriver(backend, data):
def VerifyConnectivity(backend, data):
driver_id = data["driverId"]
driver = backend.drivers[driver_id]
driver.verify_connectivity()
with warning_check(
neo4j.ExperimentalWarning,
"The configuration may change in the future."
):
driver.verify_connectivity()
backend.send_response("Driver", {"id": driver_id})


def CheckMultiDBSupport(backend, data):
driver_id = data["driverId"]
driver = backend.drivers[driver_id]
backend.send_response(
"MultiDBSupport",
{"id": backend.next_key(), "available": driver.supports_multi_db()}
)
with warning_check(
neo4j.ExperimentalWarning,
"Feature support query, based on Bolt protocol version and Neo4j "
"server version will change in the future."
):
available = driver.supports_multi_db()
backend.send_response("MultiDBSupport", {
"id": backend.next_key(), "available": available
})


def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False):
Expand Down
7 changes: 5 additions & 2 deletions testkitbackend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def __init__(self, address):
class Handler(StreamRequestHandler):
def handle(self):
backend = Backend(self.rfile, self.wfile)
while backend.process_request():
pass
try:
while backend.process_request():
pass
finally:
backend.close()
print("Disconnected")
super(Server, self).__init__(address, Handler)
Loading