Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Pass the Requester down to the HttpTransactionCache #15200

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15200.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key.
34 changes: 23 additions & 11 deletions synapse/rest/admin/server_notice_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError
Expand All @@ -23,10 +23,10 @@
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.logging.opentracing import set_tag
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, Requester, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -70,10 +70,13 @@ def register(self, json_resource: HttpServer) -> None:
self.__class__.__name__,
)

async def on_POST(
self, request: SynapseRequest, txn_id: Optional[str] = None
async def _do(
self,
request: SynapseRequest,
requester: Requester,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message)
Expand Down Expand Up @@ -106,9 +109,18 @@ async def on_POST(

return HTTPStatus.OK, {"event_id": event.event_id}

def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, None)

async def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
clokep marked this conversation as resolved.
Show resolved Hide resolved
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, txn_id
)
174 changes: 108 additions & 66 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
Expand Down Expand Up @@ -151,15 +151,22 @@ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)

def on_PUT(
async def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester
)
clokep marked this conversation as resolved.
Show resolved Hide resolved

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester)

async def _do(
self, request: SynapseRequest, requester: Requester
) -> Tuple[int, JsonDict]:
room_id, _, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
Expand All @@ -172,9 +179,9 @@ def get_room_config(self, request: Request) -> JsonDict:


# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
class RoomStateEventRestServlet(RestServlet):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: This doesn't actually use self.txns since create_and_send_nonmember_event handles the transaction ID.

def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super().__init__()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
Expand Down Expand Up @@ -324,16 +331,16 @@ def __init__(self, hs: "HomeServer"):
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
register_txn_path(self, PATTERNS, http_server)

async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
event_type: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)

event_dict: JsonDict = {
Expand Down Expand Up @@ -362,18 +369,30 @@ async def on_POST(
set_tag("event_id", event_id)
return 200, {"event_id": event_id}

def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Tuple[int, str]:
return 200, "Not implemented"
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, event_type, None)

def on_PUT(
async def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
return await self.txns.fetch_or_execute_request(
request,
requester,
self._do,
request,
requester,
room_id,
event_type,
txn_id,
)


Expand All @@ -389,14 +408,13 @@ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)

async def on_POST(
async def _do(
clokep marked this conversation as resolved.
Show resolved Hide resolved
self,
request: SynapseRequest,
requester: Requester,
room_identifier: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

content = parse_json_object_from_request(request, allow_empty_body=True)

# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
Expand All @@ -420,22 +438,31 @@ async def on_POST(

return 200, {"room_id": room_id}

def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_identifier, None)

async def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, room_identifier, txn_id
)


# TODO: Needs unit testing
class PublicRoomListRestServlet(TransactionRestServlet):
class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)

def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super().__init__()
self.hs = hs
self.auth = hs.get_auth()

Expand Down Expand Up @@ -907,22 +934,25 @@ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)

async def on_POST(
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)

async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
await self.room_member_handler.forget(user=requester.user, room_id=room_id)

return 200, {}

def on_PUT(
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
return await self._do(requester, room_id)

async def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, requester, room_id
)


Expand All @@ -941,15 +971,14 @@ def register(self, http_server: HttpServer) -> None:
)
register_txn_path(self, PATTERNS, http_server)

async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
membership_action: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

if requester.is_guest and membership_action not in {
Membership.JOIN,
Membership.LEAVE,
Expand Down Expand Up @@ -1014,13 +1043,30 @@ async def on_POST(

return 200, return_value

def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
membership_action: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, membership_action, None)

async def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
return await self.txns.fetch_or_execute_request(
request,
requester,
self._do,
request,
requester,
room_id,
membership_action,
txn_id,
)


Expand All @@ -1036,14 +1082,14 @@ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)

async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
event_id: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)

try:
Expand Down Expand Up @@ -1094,13 +1140,23 @@ async def on_POST(
set_tag("event_id", event_id)
return 200, {"event_id": event_id}

def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, room_id, event_id, None)

async def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)

return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, room_id, event_id, txn_id
)


Expand Down Expand Up @@ -1224,7 +1280,6 @@ def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
with_get: bool = False,
clokep marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Registers a transaction-based path.

Expand All @@ -1236,7 +1291,6 @@ def register_txn_path(
regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None)
Expand All @@ -1254,18 +1308,6 @@ def register_txn_path(
on_PUT,
servlet.__class__.__name__,
)
on_GET = getattr(servlet, "on_GET", None)
if with_get:
if on_GET is None:
raise RuntimeError(
"register_txn_path called with with_get = True, but no on_GET method exists"
)
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
on_GET,
servlet.__class__.__name__,
)


class TimestampLookupRestServlet(RestServlet):
Expand Down
Loading