This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
test_oidc2.py
405 lines (347 loc) · 15.9 KB
/
test_oidc2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
from copy import deepcopy
from typing import Any, Dict
from unittest import TestCase
import yaml
from parameterized import parameterized
from pydantic import ValidationError
from synapse.config.oidc2 import (
ClientAuthMethods,
LegacyOIDCProviderModel,
OIDCProviderModel,
)
# Baseline provider config shared by the tests below. It is parsed once at import
# time; each test works on its own deep copy (see `PydanticOIDCTestCase.setUp`).
#
# NOTE: the indentation inside the YAML document is significant. The keys under
# `client_secret_jwt_key` and `user_mapping_provider` must be nested, and `value`
# must belong to the same list item as `attribute`; otherwise they are parsed as
# unrelated top-level keys and the single attribute requirement has no `value`.
SAMPLE_CONFIG = yaml.safe_load(
    """
idp_id: my_idp
idp_name: My OpenID provider
idp_icon: "mxc://example.com/blahblahblah"
idp_brand: "brandy"
issuer: "https://accountns.exeample.com"
client_id: "provided-by-your-issuer"
client_secret_jwt_key:
    key: DUMMY_PRIVATE_KEY
    jwt_header:
        alg: ES256
        kid: potato123
    jwt_payload:
        iss: issuer456
client_auth_method: "client_secret_post"
scopes: ["name", "email", "openid"]
authorization_endpoint: https://example.com/auth/authorize?response_mode=form_post
token_endpoint: https://id.example.com/dummy_url_here
jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
user_mapping_provider:
    config:
        email_template: "{{ user.email }}"
        localpart_template: "{{ user.email|localpart_from_email }}"
        confirm_localpart: true
attribute_requirements:
    - attribute: userGroup
      value: "synapseUsers"
"""
)
class PydanticOIDCTestCase(TestCase):
    """Examples to build confidence that pydantic is doing the validation we think
    it's doing.

    Every test starts from a fresh deep copy of ``SAMPLE_CONFIG`` (see ``setUp``),
    mutates one or two keys, and checks whether ``parse_obj`` accepts or rejects
    the result. Loops over several candidate values run each value in its own
    ``subTest`` so that a failure names the offending value and the remaining
    values are still exercised.
    """

    # Each test gets a dummy config it can change as it sees fit
    config: Dict[str, Any]

    def setUp(self) -> None:
        """Give the test a private copy of the sample config to mutate."""
        self.config = deepcopy(SAMPLE_CONFIG)

    def test_example_config(self) -> None:
        """The unmodified sample config should parse without error."""
        OIDCProviderModel.parse_obj(self.config)

    def test_idp_id(self) -> None:
        """Example of using a Pydantic constr() field without a default."""
        # Enforce that idp_id is required.
        del self.config["idp_id"]
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        # Enforce that idp_id is a string.
        for bad_value in 123, None, ["a"], {"a": "b"}:
            with self.subTest(bad_value=bad_value):
                self.config["idp_id"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        bad_strings = (
            # Enforce a length between 1 and 250.
            ("empty string", ""),
            ("251 characters", "a" * 251),
            # Enforce the regex
            ("prohibited character", "$"),
            # What happens with a really long string of prohibited characters?
            ("500 prohibited characters", "$" * 500),
        )
        for description, bad_string in bad_strings:
            with self.subTest(description):
                self.config["idp_id"] = bad_string
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

    def test_legacy_model(self) -> None:
        """Example of widening a field's type in a subclass."""
        # Check that parsing the sample config doesn't raise an error.
        LegacyOIDCProviderModel.parse_obj(self.config)

        # Check we have default values for the attributes which have a legacy fallback
        del self.config["idp_id"]
        del self.config["idp_name"]
        model = LegacyOIDCProviderModel.parse_obj(self.config)
        self.assertEqual(model.idp_id, "oidc")
        self.assertEqual(model.idp_name, "OIDC")

        # Check we still reject bad types
        expected_bad_fields = {("idp_id",), ("idp_name",)}
        for bad_value in 123, [], {}, None:
            with self.subTest(bad_value=bad_value):
                self.config["idp_id"] = bad_value
                self.config["idp_name"] = bad_value
                with self.assertRaises(ValidationError) as e:
                    LegacyOIDCProviderModel.parse_obj(self.config)

                # And while we're at it, check that we spot errors in both fields
                errors = e.exception.errors()
                reported_bad_fields = {item["loc"] for item in errors}
                self.assertEqual(reported_bad_fields, expected_bad_fields, errors)

    def test_issuer(self) -> None:
        """Example of a StrictStr field without a default."""
        # Empty and nonempty strings should be accepted.
        for good_value in "", "hello", "hello" * 1000, "☃":
            with self.subTest(good_value=good_value):
                self.config["issuer"] = good_value
                OIDCProviderModel.parse_obj(self.config)

        # Invalid types should be rejected.
        for bad_value in 123, None, ["h", "e", "l", "l", "o"], {"hello": "there"}:
            with self.subTest(bad_value=bad_value):
                self.config["issuer"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # A missing issuer should be rejected.
        del self.config["issuer"]
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

    def test_idp_brand(self) -> None:
        """Example of an Optional[StrictStr] field."""
        # Empty and nonempty strings should be accepted.
        for good_value in "", "hello", "hello" * 1000, "☃":
            with self.subTest(good_value=good_value):
                self.config["idp_brand"] = good_value
                OIDCProviderModel.parse_obj(self.config)

        # Invalid types should be rejected.
        for bad_value in 123, ["h", "e", "l", "l", "o"], {"hello": "there"}:
            with self.subTest(bad_value=bad_value):
                self.config["idp_brand"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # A lack of an idp_brand is fine...
        del self.config["idp_brand"]
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertIsNone(model.idp_brand)

        # ... and interpreted the same as an explicit `None`.
        self.config["idp_brand"] = None
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertIsNone(model.idp_brand)

    def test_idp_icon(self) -> None:
        """Example of a field with a custom validator."""
        # Declared up front because the two loops below iterate over values of
        # unrelated types.
        bad_value: object

        # Test that bad types are rejected, even with our validator in place
        for bad_value in None, {}, [], 123, 45.6:
            with self.subTest(bad_value=bad_value):
                self.config["idp_icon"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # Test that bad strings are rejected by our validator
        for bad_value in "", "notaurl", "https://example.com", "mxc://mxc://mxc://":
            with self.subTest(bad_value=bad_value):
                self.config["idp_icon"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

    def test_discover(self) -> None:
        """Example of a StrictBool field with a default."""
        # Booleans are permitted.
        for value in True, False:
            with self.subTest(value=value):
                self.config["discover"] = value
                model = OIDCProviderModel.parse_obj(self.config)
                self.assertEqual(model.discover, value)

        # Invalid types should be rejected.
        for bad_value in (
            -1.0,
            0,
            1,
            float("nan"),
            "yes",
            "NO",
            "True",
            "true",
            None,
            "None",
            "null",
            ["a"],
            {"a": "b"},
        ):
            with self.subTest(bad_value=bad_value):
                self.config["discover"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # A missing value is okay, because this field has a default.
        del self.config["discover"]
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertIs(model.discover, True)

    def test_client_auth_method(self) -> None:
        """This is an example of using a Pydantic string enum field."""
        # check the allowed values are permitted and deserialise to an enum member
        for method in "client_secret_basic", "client_secret_post", "none":
            with self.subTest(method=method):
                self.config["client_auth_method"] = method
                model = OIDCProviderModel.parse_obj(self.config)
                self.assertIs(model.client_auth_method, ClientAuthMethods[method])

        # check the default applies if no auth method is provided.
        del self.config["client_auth_method"]
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertIs(model.client_auth_method, ClientAuthMethods.client_secret_basic)

        # Check invalid types are rejected
        for bad_value in 123, ["client_secret_basic"], {"a": 1}, None:
            with self.subTest(bad_value=bad_value):
                self.config["client_auth_method"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # Check that disallowed strings are rejected
        self.config["client_auth_method"] = "No, Luke, _I_ am your father!"
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

    def test_scopes(self) -> None:
        """Example of a variable-length tuple of StrictStr with a default."""
        # Check that the parsed object holds a tuple
        self.config["scopes"] = []
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertEqual(model.scopes, ())

        # Check a variety of list lengths are accepted.
        for good_value in ["aa"], ["hello", "world"], ["a"] * 4, [""] * 20:
            with self.subTest(good_value=good_value):
                self.config["scopes"] = good_value
                model = OIDCProviderModel.parse_obj(self.config)
                self.assertEqual(model.scopes, tuple(good_value))

        # Check invalid types are rejected.
        for bad_value in (
            "",
            "abc",
            123,
            {},
            {"a": 1},
            None,
            [None],
            [["a"]],
            [{}],
            [456],
        ):
            with self.subTest(bad_value=bad_value):
                self.config["scopes"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # Check that "scopes" may be omitted.
        del self.config["scopes"]
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertEqual(model.scopes, ("openid",))

    @parameterized.expand(["authorization_endpoint", "token_endpoint"])
    def test_endpoints_required_when_discovery_disabled(self, key: str) -> None:
        """Example of a validator that applies to multiple fields."""
        # Test that this field is required if discovery is disabled
        self.config["discover"] = False

        self.config[key] = None
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        del self.config[key]
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        # We don't validate that the endpoint is a sensible URL; anything str will do
        self.config[key] = "blahblah"
        OIDCProviderModel.parse_obj(self.config)

        def check_all_cases_pass() -> None:
            """Parse with the endpoint set to None, then absent, then a string."""
            self.config[key] = None
            OIDCProviderModel.parse_obj(self.config)

            del self.config[key]
            OIDCProviderModel.parse_obj(self.config)

            self.config[key] = "blahblah"
            OIDCProviderModel.parse_obj(self.config)

        # With discovery enabled, all three cases are accepted.
        self.config["discover"] = True
        check_all_cases_pass()

        # If not specified, discovery is also on by default.
        del self.config["discover"]
        check_all_cases_pass()

    def test_userinfo_endpoint(self) -> None:
        """Example of a more fiddly validator"""
        # This field is required if discovery is disabled and the openid scope
        # not requested.
        self.assertNotIn("userinfo_endpoint", self.config)
        self.config["discover"] = False
        self.config["scopes"] = ()
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        # Still an error even if other scopes are provided
        self.config["discover"] = False
        self.config["scopes"] = ("potato", "tomato")
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        # Passing an explicit None for userinfo_endpoint should also be an error.
        self.config["discover"] = False
        self.config["scopes"] = ()
        self.config["userinfo_endpoint"] = None
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        # No error if we enable discovery.
        self.config["discover"] = True
        self.config["scopes"] = ()
        self.config["userinfo_endpoint"] = None
        OIDCProviderModel.parse_obj(self.config)

        # No error if we enable the openid scope.
        self.config["discover"] = False
        self.config["scopes"] = ("openid",)
        self.config["userinfo_endpoint"] = None
        OIDCProviderModel.parse_obj(self.config)

        # No error if we don't specify scopes. (They default to `("openid", )`)
        self.config["discover"] = False
        del self.config["scopes"]
        self.config["userinfo_endpoint"] = None
        OIDCProviderModel.parse_obj(self.config)

    def test_attribute_requirements(self) -> None:
        """Example of a field involving a nested model."""
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertIsInstance(model.attribute_requirements, tuple)
        self.assertEqual(
            len(model.attribute_requirements), 1, model.attribute_requirements
        )

        # Bad types should be rejected
        bad_value: object
        for bad_value in 123, 456.0, False, None, {}, ["hello"]:
            with self.subTest(bad_value=bad_value):
                self.config["attribute_requirements"] = bad_value
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)

        # An empty list of requirements is okay, ...
        self.config["attribute_requirements"] = []
        OIDCProviderModel.parse_obj(self.config)

        # ...as is an omitted list of requirements...
        del self.config["attribute_requirements"]
        OIDCProviderModel.parse_obj(self.config)

        # ...but not an explicit None.
        self.config["attribute_requirements"] = None
        with self.assertRaises(ValidationError):
            OIDCProviderModel.parse_obj(self.config)

        # Multiple requirements are fine.
        self.config["attribute_requirements"] = [{"attribute": "k", "value": "v"}] * 3
        model = OIDCProviderModel.parse_obj(self.config)
        self.assertEqual(
            len(model.attribute_requirements), 3, model.attribute_requirements
        )

        bad_requirement_lists = (
            # The submodel's field types should be enforced too.
            [{"attribute": "key", "value": 123}],
            [{"attribute": 123, "value": "val"}],
            [{"attribute": "a", "value": ["b"]}],
            [{"attribute": "a", "value": None}],
            # Missing fields in the submodel are an error.
            [{"attribute": "a"}],
            [{"value": "v"}],
            [{}],
            # Extra fields in the submodel are an error.
            [{"attribute": "a", "value": "v", "answer": "forty-two"}],
        )
        for bad_requirements in bad_requirement_lists:
            with self.subTest(bad_requirements=bad_requirements):
                self.config["attribute_requirements"] = bad_requirements
                with self.assertRaises(ValidationError):
                    OIDCProviderModel.parse_obj(self.config)