Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change how we stub out HTTP requests in the tests #85

Merged
merged 1 commit into from
Mar 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/85.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Change how we stub out HTTP requests in the tests.
32 changes: 8 additions & 24 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def test_expected(self):
200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
)

req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

resp = self._collect_request(req)
resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))

self.assertEquals(resp, {"rejected": []})
self.assertEquals(gcm.num_requests, 1)
Expand All @@ -85,9 +83,7 @@ def test_rejected(self):
200, {"results": [{"registration_id": "spqr", "error": "NotRegistered"}]}
)

req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

resp = self._collect_request(req)
resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))

self.assertEquals(resp, {"rejected": ["spqr"]})
self.assertEquals(gcm.num_requests, 1)
Expand All @@ -103,19 +99,15 @@ def test_regenerated_id(self):
200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
)

req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

resp = self._collect_request(req)
resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))

self.assertEquals(resp, {"rejected": []})

gcm.preload_with_response(
200, {"results": [{"registration_id": "spqr_new", "message_id": "msg43"}]}
)

req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

resp = self._collect_request(req)
resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))

self.assertEquals(gcm.last_request_body["to"], "spqr_new")

Expand All @@ -138,12 +130,10 @@ def test_batching(self):
},
)

req = self._make_request(
resp = self._request(
self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2])
)

resp = self._collect_request(req)

self.assertEquals(resp, {"rejected": []})
self.assertEquals(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEquals(gcm.num_requests, 1)
Expand All @@ -166,12 +156,10 @@ def test_batching_individual_failure(self):
},
)

req = self._make_request(
resp = self._request(
self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2])
)

resp = self._collect_request(req)

self.assertEquals(resp, {"rejected": ["spqr2"]})
self.assertEquals(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEquals(gcm.num_requests, 1)
Expand All @@ -188,9 +176,7 @@ def test_regenerated_failure(self):
200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
)

req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

resp = self._collect_request(req)
resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))

self.assertEquals(resp, {"rejected": []})

Expand All @@ -202,9 +188,7 @@ def test_regenerated_failure(self):
{"results": [{"registration_id": "spqr_new", "error": "NotRegistered"}]},
)

req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

resp = self._collect_request(req)
resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))

self.assertEquals(gcm.last_request_body["to"], "spqr_new")

Expand Down
154 changes: 80 additions & 74 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
import json
from io import BytesIO
from threading import Condition
from typing import BinaryIO, Optional, Union

import attr
from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.trial import unittest
from twisted.web.http_headers import Headers
from twisted.web.server import NOT_DONE_YET
from twisted.web.test.requesthelper import DummyRequest as UnaugmentedDummyRequest
from twisted.web.server import Request

from sygnal.http import PushGatewayApiServer
from sygnal.sygnal import Sygnal, merge_left_with_defaults, CONFIG_DEFAULTS
from sygnal.sygnal import CONFIG_DEFAULTS, Sygnal, merge_left_with_defaults

REQ_PATH = b"/_matrix/push/v1/notify"

Expand Down Expand Up @@ -78,77 +79,36 @@ def _make_dummy_notification(self, devices):
}
}

def _make_request(self, payload, headers=None):
def _request(self, payload) -> Union[dict, int]:
"""
Make a dummy request to the notify endpoint with the specified
Args:
payload: payload to be JSON encoded
headers (dict, optional): A L{dict} mapping header names as L{bytes}
to L{list}s of header values as L{bytes}

Returns (DummyRequest):
A dummy request corresponding to the request arguments supplied.

"""
pathparts = REQ_PATH.split(b"/")
if pathparts[0] == b"":
pathparts = pathparts[1:]
dreq = DummyRequest(pathparts)
dreq.requestHeaders = Headers(headers or {})
dreq.responseCode = 200 # default to 200

if isinstance(payload, dict):
payload = json.dumps(payload)

dreq.content = BytesIO(payload.encode())
dreq.method = "POST"
Make a dummy request to the notify endpoint with the specified payload

return dreq

def _collect_request(self, request):
"""
Collects (waits until done and then returns the result of) the request.
Args:
request (Request): a request to collect
payload: payload to be JSON encoded

Returns (dict or int):
If successful (200 response received), the response is JSON decoded
and the resultant dict is returned.
If the response code is not 200, returns the response code.
"""
resource = self.v1api.site.getResourceFor(request)
rendered = resource.render(request)

if request.responseCode != 200:
return request.responseCode
if isinstance(payload, dict):
payload = json.dumps(payload)
content = BytesIO(payload.encode())

if isinstance(rendered, str):
return json.loads(rendered)
elif rendered == NOT_DONE_YET:
channel = FakeChannel(self.v1api.site, self.sygnal.reactor)
channel.process_request(b"POST", REQ_PATH, content)

while not request.finished:
# we need to advance until the request has been finished
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: request.finished)
while not channel.done:
# we need to advance until the request has been finished
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: channel.done)

assert request.finished > 0
assert channel.done

if request.responseCode != 200:
return request.responseCode
if channel.result.code != 200:
return channel.result.code

written_bytes = b"".join(request.written)
return json.loads(written_bytes)
else:
raise RuntimeError(f"Can't collect: {rendered}")

def _request(self, *args, **kwargs):
"""
Makes and collects a request.
See L{_make_request} and L{_collect_request}.
"""
request = self._make_request(*args, **kwargs)

return self._collect_request(request)
return json.loads(channel.response_body)


class ExtendedMemoryReactorClock(MemoryReactorClock):
Expand Down Expand Up @@ -192,20 +152,6 @@ def wait_for_work(self, early_stop=lambda: False):
self.work_notifier.release()


class DummyRequest(UnaugmentedDummyRequest):
"""
Tracks the response code in the 'code' field, like a normal Request.
"""

def __init__(self, postpath, session=None, client=None):
super().__init__(postpath, session, client)
self.code = 200

def setResponseCode(self, code, message=None):
super().setResponseCode(code, message)
self.code = code


class DummyResponse(object):
def __init__(self, code):
self.code = code
Expand All @@ -216,3 +162,63 @@ async def dummy(*_args, **_kwargs):
return ret_val

return dummy


@attr.s
class HTTPResult:
"""Holds the result data for FakeChannel"""

version = attr.ib(type=str)
code = attr.ib(type=int)
reason = attr.ib(type=str)
headers = attr.ib(type=Headers)


@attr.s
class FakeChannel(object):
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
"""

site = attr.ib()
_reactor = attr.ib()
_producer = None

result = attr.ib(type=Optional[HTTPResult], default=None)
response_body = b""
done = attr.ib(type=bool, default=False)

@property
def code(self):
if not self.result:
raise Exception("No result yet.")
return int(self.result.code)

def writeHeaders(self, version, code, reason, headers):
self.result = HTTPResult(version, int(code), reason, headers)

def write(self, content):
assert isinstance(content, bytes), "Should be bytes! " + repr(content)
self.response_body += content

def requestDone(self, _self):
self.done = True

def getPeer(self):
return None

def getHost(self):
return None

@property
def transport(self):
return None

def process_request(self, method: bytes, request_path: bytes, content: BinaryIO):
"""pretend that a request has arrived, and process it"""

# this is normally done by HTTPChannel, in its various lineReceived etc methods
req = self.site.requestFactory(self) # type: Request
req.content = content
req.requestReceived(method, request_path, b"1.1")