From b09dc68c65ed071d3a52ff710736d640cfba0e01 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 23 Mar 2020 22:22:50 +0000 Subject: [PATCH 1/2] Change how we stub out HTTP requests in the tests By having a FakeChannel instead of a DummyRequest, we can make sure that functionality which relies on our own Request impl works correctly. --- changelog.d/85.misc | 1 + tests/test_gcm.py | 32 +++------ tests/testutils.py | 154 +++++++++++++++++++++++--------------------- 3 files changed, 89 insertions(+), 98 deletions(-) create mode 100644 changelog.d/85.misc diff --git a/changelog.d/85.misc b/changelog.d/85.misc new file mode 100644 index 00000000..dc585ab4 --- /dev/null +++ b/changelog.d/85.misc @@ -0,0 +1 @@ +Change how we stub out HTTP requests in the tests. diff --git a/tests/test_gcm.py b/tests/test_gcm.py index 87496ae9..906e2af9 100644 --- a/tests/test_gcm.py +++ b/tests/test_gcm.py @@ -68,9 +68,7 @@ def test_expected(self): 200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]} ) - req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE])) - - resp = self._collect_request(req) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) self.assertEquals(resp, {"rejected": []}) self.assertEquals(gcm.num_requests, 1) @@ -85,9 +83,7 @@ def test_rejected(self): 200, {"results": [{"registration_id": "spqr", "error": "NotRegistered"}]} ) - req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE])) - - resp = self._collect_request(req) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) self.assertEquals(resp, {"rejected": ["spqr"]}) self.assertEquals(gcm.num_requests, 1) @@ -103,9 +99,7 @@ def test_regenerated_id(self): 200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]} ) - req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE])) - - resp = self._collect_request(req) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) self.assertEquals(resp, {"rejected": []}) @@ -113,9 +107,7 @@ def test_regenerated_id(self): 200, {"results": [{"registration_id": "spqr_new", "message_id": "msg43"}]} ) - req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE])) - - resp = self._collect_request(req) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) self.assertEquals(gcm.last_request_body["to"], "spqr_new") @@ -138,12 +130,10 @@ def test_batching(self): }, ) - req = self._make_request( + resp = self._request( self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2]) ) - resp = self._collect_request(req) - self.assertEquals(resp, {"rejected": []}) self.assertEquals(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"]) self.assertEquals(gcm.num_requests, 1) @@ -166,12 +156,10 @@ def test_batching_individual_failure(self): }, ) - req = self._make_request( + resp = self._request( self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2]) ) - resp = self._collect_request(req) - self.assertEquals(resp, {"rejected": ["spqr2"]}) self.assertEquals(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"]) self.assertEquals(gcm.num_requests, 1) @@ -188,9 +176,7 @@ def test_regenerated_failure(self): 200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]} ) - req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE])) - - resp = self._collect_request(req) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) self.assertEquals(resp, {"rejected": []}) @@ -202,9 +188,7 @@ def test_regenerated_failure(self): {"results": [{"registration_id": "spqr_new", "error": "NotRegistered"}]}, ) - req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE])) - - resp = self._collect_request(req) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) self.assertEquals(gcm.last_request_body["to"], "spqr_new") diff --git a/tests/testutils.py b/tests/testutils.py index f1155b67..a652f5c2 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -15,16 +15,17 @@ import json from io import BytesIO from threading import Condition +from typing import BinaryIO, Optional, Union +import attr from twisted.internet.defer import ensureDeferred from twisted.test.proto_helpers import MemoryReactorClock from twisted.trial import unittest from twisted.web.http_headers import Headers -from twisted.web.server import NOT_DONE_YET -from twisted.web.test.requesthelper import DummyRequest as UnaugmentedDummyRequest +from twisted.web.server import Request from sygnal.http import PushGatewayApiServer -from sygnal.sygnal import Sygnal, merge_left_with_defaults, CONFIG_DEFAULTS +from sygnal.sygnal import CONFIG_DEFAULTS, Sygnal, merge_left_with_defaults REQ_PATH = b"/_matrix/push/v1/notify" @@ -78,77 +79,36 @@ def _make_dummy_notification(self, devices): } } - def _make_request(self, payload, headers=None): + def _request(self, payload) -> Union[dict, int]: """ - Make a dummy request to the notify endpoint with the specified - Args: - payload: payload to be JSON encoded - headers (dict, optional): A L{dict} mapping header names as L{bytes} - to L{list}s of header values as L{bytes} - - Returns (DummyRequest): - A dummy request corresponding to the request arguments supplied. - - """ - pathparts = REQ_PATH.split(b"/") - if pathparts[0] == b"": - pathparts = pathparts[1:] - dreq = DummyRequest(pathparts) - dreq.requestHeaders = Headers(headers or {}) - dreq.responseCode = 200 # default to 200 - - if isinstance(payload, dict): - payload = json.dumps(payload) - - dreq.content = BytesIO(payload.encode()) - dreq.method = "POST" + Make a dummy request to the notify endpoint with the specified payload - return dreq - - def _collect_request(self, request): - """ - Collects (waits until done and then returns the result of) the request. Args: - request (Request): a request to collect + payload: payload to be JSON encoded Returns (dict or int): If successful (200 response received), the response is JSON decoded and the resultant dict is returned. If the response code is not 200, returns the response code. """ - resource = self.v1api.site.getResourceFor(request) - rendered = resource.render(request) - - if request.responseCode != 200: - return request.responseCode + if isinstance(payload, dict): + payload = json.dumps(payload) + content = BytesIO(payload.encode()) - if isinstance(rendered, str): - return json.loads(rendered) - elif rendered == NOT_DONE_YET: + channel = FakeChannel(self.v1api.site, self.sygnal.reactor) + channel.process_request(b"POST", REQ_PATH, content) - while not request.finished: - # we need to advance until the request has been finished - self.sygnal.reactor.advance(1) - self.sygnal.reactor.wait_for_work(lambda: request.finished) + while not channel.done: + # we need to advance until the request has been finished + self.sygnal.reactor.advance(1) + self.sygnal.reactor.wait_for_work(lambda: channel.done) - assert request.finished > 0 + assert channel.done - if request.responseCode != 200: - return request.responseCode + if channel.result.code != 200: + return channel.result.code - written_bytes = b"".join(request.written) - return json.loads(written_bytes) - else: - raise RuntimeError(f"Can't collect: {rendered}") - - def _request(self, *args, **kwargs): - """ - Makes and collects a request. - See L{_make_request} and L{_collect_request}. - """ - request = self._make_request(*args, **kwargs) - - return self._collect_request(request) + return json.loads(channel.response_body) class ExtendedMemoryReactorClock(MemoryReactorClock): @@ -192,20 +152,6 @@ def wait_for_work(self, early_stop=lambda: False): self.work_notifier.release() -class DummyRequest(UnaugmentedDummyRequest): - """ - Tracks the response code in the 'code' field, like a normal Request. - """ - - def __init__(self, postpath, session=None, client=None): - super().__init__(postpath, session, client) - self.code = 200 - - def setResponseCode(self, code, message=None): - super().setResponseCode(code, message) - self.code = code - - class DummyResponse(object): def __init__(self, code): self.code = code @@ -216,3 +162,63 @@ async def dummy(*_args, **_kwargs): return ret_val return dummy + + +@attr.s +class HTTPResult: + """Holds the result data for FakeChannel""" + + version = attr.ib(type=str) + code = attr.ib(type=int) + reason = attr.ib(type=str) + headers = attr.ib(type=Headers) + + +@attr.s +class FakeChannel(object): + """ + A fake Twisted Web Channel (the part that interfaces with the + wire). + """ + + site = attr.ib() + _reactor = attr.ib() + _producer = None + + result = attr.ib(type=Optional[HTTPResult], default=None) + response_body = b"" + done = attr.ib(type=bool, default=False) + + @property + def code(self): + if not self.result: + raise Exception("No result yet.") + return int(self.result.code) + + def writeHeaders(self, version, code, reason, headers): + self.result = HTTPResult(version, int(code), reason, headers) + + def write(self, content): + assert isinstance(content, bytes), "Should be bytes! " + repr(content) + self.response_body += content + + def requestDone(self, _self): + self.done = True + + def getPeer(self): + return None + + def getHost(self): + return None + + @property + def transport(self): + return None + + def process_request(self, method: bytes, request_path: bytes, content: BinaryIO): + """pretend that a request has arrived, and process it""" + + # this is normally done by HTTPChannel, in its various lineReceived etc methods + req = self.site.requestFactory(self) # type: Request + req.content = content + req.requestReceived(method, request_path, b"1.1") From 6f3662428057fa02c7d27159577dceb94039d105 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 23 Mar 2020 19:53:24 +0000 Subject: [PATCH 2/2] Fix warnings about finish() after disconnect --- changelog.d/84.bugfix | 1 + sygnal/http.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/84.bugfix diff --git a/changelog.d/84.bugfix b/changelog.d/84.bugfix new file mode 100644 index 00000000..27e86772 --- /dev/null +++ b/changelog.d/84.bugfix @@ -0,0 +1 @@ +Fix warnings about finish() after disconnect. diff --git a/sygnal/http.py b/sygnal/http.py index 7efecc1b..8aa2c52d 100644 --- a/sygnal/http.py +++ b/sygnal/http.py @@ -240,7 +240,9 @@ async def _handle_dispatch(self, root_span, request, log, notif, context): request.setResponseCode(500) log.error("Exception whilst dispatching notification.", exc_info=True) finally: - request.finish() + if not request._disconnected: + request.finish() + PUSHGATEWAY_HTTP_RESPONSES_COUNTER.labels(code=request.code).inc() root_span.set_tag(tags.HTTP_STATUS_CODE, request.code)