Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix client IPs being broken on Python 3 #3908

Merged
merged 11 commits into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ matrix:
- python: 3.6
env: TOX_ENV=py36

- python: 3.6
env: TOX_ENV=py36-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql

- python: 3.6
env: TOX_ENV=check_isort

Expand Down
1 change: 1 addition & 0 deletions changelog.d/3908.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix adding client IPs to the database failing on Python 3.
2 changes: 1 addition & 1 deletion synapse/http/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def getClientIP(self):
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')


class SynapseRequestFactory(object):
Expand Down
35 changes: 20 additions & 15 deletions synapse/storage/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,26 @@ def _update_client_ips_batch_txn(self, txn, to_update):
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry

self._simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
"user_id": user_id,
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": device_id,
},
values={
"last_seen": last_seen,
},
lock=False,
)
try:
self._simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
"user_id": user_id,
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": device_id,
},
values={
"last_seen": last_seen,
},
lock=False,
)
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP: %r", entry)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you log the reason for the failure (ie, e) as well as the fact it's failed?



@defer.inlineCallbacks
def get_last_client_ip_by_device(self, user_id, device_id):
Expand Down
8 changes: 5 additions & 3 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def info(self, *args, **kwargs):
return FakeLogger()


def make_request(method, path, content=b"", access_token=None):
def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
Expand All @@ -120,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None):
site = FakeSite()
channel = FakeChannel()

req = SynapseRequest(site, channel)
req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)

if access_token:
req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)

req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")

req.requestReceived(method, path, b"1.1")

return req, channel
Expand Down
202 changes: 167 additions & 35 deletions tests/storage/test_client_ips.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,35 +13,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import hmac
import json

from mock import Mock

from twisted.internet import defer

import tests.unittest
import tests.utils
from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import admin, login

from tests import unittest


class ClientIpStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
self.clock = None # type: tests.utils.MockClock
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs

@defer.inlineCallbacks
def setUp(self):
self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
def prepare(self, hs, reactor, clock):
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()

@defer.inlineCallbacks
def test_insert_new_client_ip(self):
self.clock.now = 12345678
self.reactor.advance(12345678)

user_id = "@user:id"
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
)

result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
# Trigger the storage loop
self.reactor.advance(10)

result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id")
)

r = result[(user_id, "device_id")]
self.assertDictContainsSubset(
Expand All @@ -55,18 +66,18 @@ def test_insert_new_client_ip(self):
r,
)

@defer.inlineCallbacks
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
user_id = "@user:server"
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
)
active = yield self.store.user_last_seen_monthly_active(user_id)
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)

@defer.inlineCallbacks
def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
Expand All @@ -76,38 +87,159 @@ def test_adding_monthly_active_user_when_full(self):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
)
active = yield self.store.user_last_seen_monthly_active(user_id)
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)

@defer.inlineCallbacks
def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = yield self.store.user_last_seen_monthly_active(user_id)
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)

yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
# Trigger the saving loop
self.reactor.advance(10)

self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
)
active = yield self.store.user_last_seen_monthly_active(user_id)
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)

@defer.inlineCallbacks
def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
yield self.store.register(user_id=user_id, token="123", password_hash=None)
self.get_success(
self.store.register(user_id=user_id, token="123", password_hash=None)
)

active = yield self.store.user_last_seen_monthly_active(user_id)
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)

yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
# Trigger the saving loop
self.reactor.advance(10)

self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
)
active = yield self.store.user_last_seen_monthly_active(user_id)
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)


class ClientIpAuthTestCase(unittest.HomeserverTestCase):

servlets = [admin.register_servlets, login.register_servlets]

def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs

def prepare(self, hs, reactor, clock):
self.hs.config.registration_shared_secret = u"shared"
self.store = self.hs.get_datastore()

# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()

body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
)
request, channel = self.make_request(
"POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
)
self.render(request)

self.assertEqual(channel.code, 200)
self.user_id = channel.json_body["user_id"]

def test_request_with_xforwarded(self):
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
self._runtest(
{b"X-Forwarded-For": b"127.9.0.1"},
"127.9.0.1",
{"request": XForwardedForRequest},
)

def test_request_from_getPeer(self):
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})

def _runtest(self, headers, expected_ip, make_request_args):
device_id = "bleb"

body = json.dumps(
{
"type": "m.login.password",
"user": "bob",
"password": "abc123",
"device_id": device_id,
}
)
request, channel = self.make_request(
"POST", "/_matrix/client/r0/login", body.encode('utf8'), **make_request_args
)
self.render(request)
self.assertEqual(channel.code, 200)
access_token = channel.json_body["access_token"].encode('ascii')

# Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds())

request, channel = self.make_request(
"GET",
"/_matrix/client/r0/admin/users/" + self.user_id,
body.encode('utf8'),
access_token=access_token,
**make_request_args
)
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")

# Add the optional headers
for h, v in headers.items():
request.requestHeaders.addRawHeader(h, v)
self.render(request)

# Advance so the save loop occurs
self.reactor.advance(100)

result = self.get_success(
self.store.get_last_client_ip_by_device(self.user_id, device_id)
)
r = result[(self.user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": self.user_id,
"device_id": device_id,
"ip": expected_ip,
"user_agent": "Mozzila pizza",
"last_seen": 123456100,
},
r,
)
7 changes: 5 additions & 2 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from twisted.trial import unittest

from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util.logcontext import LoggingContextFilter
Expand Down Expand Up @@ -237,7 +238,9 @@ def prepare(self, reactor, clock, homeserver):
Function to optionally be overridden in subclasses.
"""

def make_request(self, method, path, content=b""):
def make_request(
self, method, path, content=b"", access_token=None, request=SynapseRequest
):
"""
Create a SynapseRequest at the path using the method and containing the
given content.
Expand All @@ -255,7 +258,7 @@ def make_request(self, method, path, content=b""):
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')

return make_request(method, path, content)
return make_request(method, path, content, access_token, request)

def render(self, request):
"""
Expand Down
Loading