Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to tests/rest. (#12208)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Cloke <[email protected]>
  • Loading branch information
dklimpel and clokep authored Mar 11, 2022
1 parent e10a2fe commit 32c828d
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 85 deletions.
1 change: 1 addition & 0 deletions changelog.d/12208.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to tests files.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ exclude = (?x)
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
|tests/rest/media/v1/test_media_storage.py
|tests/rest/media/v1/test_url_preview.py
|tests/scripts/test_new_matrix_user.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
Expand Down
19 changes: 17 additions & 2 deletions tests/rest/client/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from http import HTTPStatus
from unittest.mock import Mock, call

from twisted.internet import defer, reactor
Expand All @@ -11,14 +26,14 @@


class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)

self.mock_http_response = (200, "GOOD JOB!")
self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_key = "foo"

@defer.inlineCallbacks
Expand Down
110 changes: 67 additions & 43 deletions tests/rest/media/v1/test_media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Optional
from typing import Any, BinaryIO, Dict, List, Optional, Union
from unittest.mock import Mock
from urllib import parse

Expand All @@ -26,18 +26,24 @@

from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
from synapse.types import RoomAlias
from synapse.util import Clock

from tests import unittest
from tests.server import FakeSite, make_request
from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import SMALL_PNG
from tests.utils import default_config

Expand All @@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):

needs_threadpool = True

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)

Expand All @@ -62,7 +68,7 @@ def prepare(self, reactor, clock, hs):
hs, self.primary_base_path, self.filepaths, storage_providers
)

def test_ensure_media_is_in_local_cache(self):
def test_ensure_media_is_in_local_cache(self) -> None:
media_id = "some_media_id"
test_body = "Test\n"

Expand Down Expand Up @@ -105,7 +111,7 @@ def test_ensure_media_is_in_local_cache(self):
self.assertEqual(test_body, body)


@attr.s(slots=True, frozen=True)
@attr.s(auto_attribs=True, slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
Expand All @@ -121,18 +127,18 @@ class _TestImage:
a 404 is expected.
"""

data = attr.ib(type=bytes)
content_type = attr.ib(type=bytes)
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes], default=None)
expected_scaled = attr.ib(type=Optional[bytes], default=None)
expected_found = attr.ib(default=True, type=bool)
data: bytes
content_type: bytes
extension: bytes
expected_cropped: Optional[bytes] = None
expected_scaled: Optional[bytes] = None
expected_found: bool = True


@parameterized_class(
("test_image",),
[
# smoll png
# small png
(
_TestImage(
SMALL_PNG,
Expand Down Expand Up @@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
user_id = "@test:user"

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

self.fetches = []

def get_file(destination, path, output_stream, args=None, max_size=None):
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
args: Optional[Dict[str, Union[str, List[str]]]] = None,
max_size: Optional[int] = None,
) -> Deferred:
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
Expand Down Expand Up @@ -238,7 +250,7 @@ def write_to(r):

return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
Expand All @@ -248,8 +260,9 @@ def prepare(self, reactor, clock, hs):

self.media_id = "example.com/12345"

def _req(self, content_disposition, include_content_type=True):

def _req(
self, content_disposition: Optional[bytes], include_content_type: bool = True
) -> FakeChannel:
channel = make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
Expand Down Expand Up @@ -288,7 +301,7 @@ def _req(self, content_disposition, include_content_type=True):

return channel

def test_handle_missing_content_type(self):
def test_handle_missing_content_type(self) -> None:
channel = self._req(
b"inline; filename=out" + self.test_image.extension,
include_content_type=False,
Expand All @@ -299,7 +312,7 @@ def test_handle_missing_content_type(self):
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
)

def test_disposition_filename_ascii(self):
def test_disposition_filename_ascii(self) -> None:
"""
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
Expand All @@ -315,7 +328,7 @@ def test_disposition_filename_ascii(self):
[b"inline; filename=out" + self.test_image.extension],
)

def test_disposition_filenamestar_utf8escaped(self):
def test_disposition_filenamestar_utf8escaped(self) -> None:
"""
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
correctly decode it as the UTF-8 string, and use filename* in the
Expand All @@ -335,7 +348,7 @@ def test_disposition_filenamestar_utf8escaped(self):
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)

def test_disposition_none(self):
def test_disposition_none(self) -> None:
"""
If there is no filename, one isn't passed on in the Content-Disposition
of the request.
Expand All @@ -348,26 +361,26 @@ def test_disposition_none(self):
)
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)

def test_thumbnail_crop(self):
def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)

def test_thumbnail_scale(self):
def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)

def test_invalid_type(self):
def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)

@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
def test_no_thumbnail_crop(self):
def test_no_thumbnail_crop(self) -> None:
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
Expand All @@ -376,13 +389,13 @@ def test_no_thumbnail_crop(self):
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
def test_no_thumbnail_scale(self):
def test_no_thumbnail_scale(self) -> None:
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)

def test_thumbnail_repeated_thumbnail(self):
def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
Expand Down Expand Up @@ -443,7 +456,9 @@ def test_thumbnail_repeated_thumbnail(self):
channel.result["body"],
)

def _test_thumbnail(self, method, expected_body, expected_found):
def _test_thumbnail(
self, method: str, expected_body: Optional[bytes], expected_found: bool
) -> None:
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
Expand Down Expand Up @@ -485,7 +500,7 @@ def _test_thumbnail(self, method, expected_body, expected_found):
)

@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
def test_same_quality(self, method, desired_size):
def test_same_quality(self, method: str, desired_size: int) -> None:
"""Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen."""
Expand Down Expand Up @@ -521,7 +536,7 @@ def test_same_quality(self, method, desired_size):
)
)

def test_x_robots_tag_header(self):
def test_x_robots_tag_header(self) -> None:
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media.
Expand All @@ -540,29 +555,38 @@ class TestSpamChecker:
`evil`.
"""

def __init__(self, config, api):
def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
self.config = config
self.api = api

def parse_config(config):
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config

async def check_event_for_spam(self, foo):
async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
return False # allow all events

async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
async def user_may_invite(
self,
inviter_userid: str,
invitee_userid: str,
room_id: str,
) -> bool:
return True # allow all invites

async def user_may_create_room(self, userid):
async def user_may_create_room(self, userid: str) -> bool:
return True # allow all room creations

async def user_may_create_room_alias(self, userid, room_alias):
async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> bool:
return True # allow all room aliases

async def user_may_publish_room(self, userid, room_id):
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
return True # allow publishing of all rooms

async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> bool:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)

Expand All @@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")

Expand All @@ -586,7 +610,7 @@ def prepare(self, reactor, clock, hs):

load_legacy_spam_checkers(hs)

def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = default_config("test")

config.update(
Expand All @@ -602,13 +626,13 @@ def default_config(self):

return config

def test_upload_innocent(self):
def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)

def test_upload_ban(self):
def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
Expand Down
Loading

0 comments on commit 32c828d

Please sign in to comment.