Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whitelist webpush endpoints #182

Merged
merged 12 commits into from
Mar 22, 2021
1 change: 1 addition & 0 deletions changelog.d/182.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add 'allowed_endpoints' configuration option for limiting the endpoints that WebPush pushkins will contact.
13 changes: 13 additions & 0 deletions docs/applications.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ You also need to set an e-mail address in `vapid_contact_email` in the config fi
where the push gateway operator can reach you in case they need to notify you
about your usage of their API.

Since for webpush, the push gateway endpoint is variable and comes from the browser
through the push data, you may not want to have your sygnal instance connect to any
random addressable server. For this, you can set the `allowed_endpoints` option to
a list of allowed endpoints. Globs are supported. For example, to allow Firefox,
Chrome and Opera (Google) and Edge as a push gateway, you can use this:

```yaml
allowed_endpoints:
- "updates.push.services.mozilla.com"
- "fcm.googleapis.com"
- "*.notify.windows.com"
```

#### Push key and expected push data

In your web application, [the push manager subscribe method](https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe)
Expand Down
25 changes: 25 additions & 0 deletions sygnal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from logging import LoggerAdapter

from twisted.internet.defer import Deferred
Expand All @@ -37,3 +38,27 @@ async def twisted_sleep(delay, twisted_reactor):
class NotificationLoggerAdapter(LoggerAdapter):
def process(self, msg, kwargs):
return f"[{self.extra['request_id']}] {msg}", kwargs


def glob_to_regex(glob):
"""Converts a glob to a compiled regex object.

The regex is anchored at the beginning and end of the string.

Args:
glob (str)

Returns:
re.RegexObject
"""
res = ""
for c in glob:
if c == "*":
res = res + ".*"
elif c == "?":
res = res + "."
else:
res = res + re.escape(c)

# \A anchors at start of string, \Z at end of string
return re.compile(r"\A" + res + r"\Z", re.IGNORECASE)
24 changes: 23 additions & 1 deletion sygnal/webpushpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import os.path
from io import BytesIO
from typing import List, Optional, Pattern
from urllib.parse import urlparse

from prometheus_client import Gauge, Histogram
Expand All @@ -30,6 +31,7 @@

from .exceptions import PushkinSetupException
from .notifications import ConcurrencyLimitedPushkin
from .utils import glob_to_regex

QUEUE_TIME_HISTOGRAM = Histogram(
"sygnal_webpush_queue_time",
Expand Down Expand Up @@ -96,6 +98,14 @@ def __init__(self, name, sygnal, config):
)
self.http_agent_wrapper = HttpAgentWrapper(self.http_agent)

self.allowed_endpoints = None # type: Optional[List[Pattern]]
allowed_endpoints = self.get_config("allowed_endpoints")
if allowed_endpoints:
if not isinstance(allowed_endpoints, list):
raise PushkinSetupException(
"'allowed_endpoints' should be a list or not set"
)
self.allowed_endpoints = list(map(glob_to_regex, allowed_endpoints))
bwindels marked this conversation as resolved.
Show resolved Hide resolved
privkey_filename = self.get_config("vapid_private_key")
if not privkey_filename:
raise PushkinSetupException("'vapid_private_key' not set in config")
Expand All @@ -119,6 +129,18 @@ async def _dispatch_notification_unlimited(self, n, device, context):

endpoint = device.data.get("endpoint")
auth = device.data.get("auth")
endpoint_domain = urlparse(endpoint).netloc
if self.allowed_endpoints:
allowed = any(
regex.fullmatch(endpoint_domain) for regex in self.allowed_endpoints
)
if not allowed:
logger.error(
"push gateway %s is not in allowed_endpoints, blocking request",
endpoint_domain,
)
# abort, but don't reject push key
return []

if not p256dh or not endpoint or not auth:
logger.warn(
Expand Down Expand Up @@ -168,7 +190,7 @@ async def _dispatch_notification_unlimited(self, n, device, context):
logger.warn(
"Rejecting pushkey %s; gateway %s failed with %d: %s",
device.pushkey,
urlparse(endpoint).netloc,
endpoint_domain,
response.code,
response_text,
)
Expand Down