From af795a89fca39da28494986a5ae379fb73aae2c9 Mon Sep 17 00:00:00 2001 From: Jacopo Farina Date: Sun, 8 Apr 2018 22:05:22 +0200 Subject: [PATCH] Use Union type to describe both types of payload --- jwt/api_jws.py | 18 +++++----- jwt/api_jwt.py | 8 ++--- tests/test_api_jws.py | 76 +++++++++++++++++++++---------------------- 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 37f5f51e..3e429a75 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -4,7 +4,7 @@ from collections import Mapping try: # import required by mypy to perform type checking, not used for normal execution - from typing import Callable, Dict, List, Optional # NOQA + from typing import Callable, Dict, List, Optional, Union # NOQA except ImportError: pass @@ -74,13 +74,13 @@ def get_algorithms(self): """ return list(self._valid_algs) - def encode_bytes(self, - payload, # type: bytes - key, # type: str - algorithm='HS256', # type: str - headers=None, # type: Optional[Dict] - json_encoder=None # type: Optional[Callable] - ): + def encode(self, + payload, # type: Union[Dict, bytes] + key, # type: str + algorithm='HS256', # type: str + headers=None, # type: Optional[Dict] + json_encoder=None # type: Optional[Callable] + ): segments = [] if algorithm is None: @@ -236,7 +236,7 @@ def _validate_kid(self, kid): _jws_global_obj = PyJWS() -encode = _jws_global_obj.encode_bytes +encode = _jws_global_obj.encode decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 1eb12a0f..90fa0a17 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta try: # import required by mypy to perform type checking, not used for normal execution - from typing import Callable, Dict, List, Optional # NOQA + from typing import Callable, Dict, List, Optional, Union # NOQA except ImportError: pass @@ -39,7 +39,7 @@ def _get_default_options(): } def encode(self, - payload, # type: Dict + payload, # type: Union[Dict, bytes] key, # type: str algorithm='HS256', # type: str headers=None, # type: Optional[Dict] @@ -54,7 +54,7 @@ def encode(self, for time_claim in ['exp', 'iat', 'nbf']: # Convert datetime to a intDate value in known time-format claims if isinstance(payload.get(time_claim), datetime): - payload[time_claim] = timegm(payload[time_claim].utctimetuple()) + payload[time_claim] = timegm(payload[time_claim].utctimetuple()) # type: ignore json_payload = json.dumps( payload, @@ -62,7 +62,7 @@ def encode(self, cls=json_encoder ).encode('utf-8') - return super(PyJWT, self).encode_bytes( + return super(PyJWT, self).encode( json_payload, key, algorithm, headers, json_encoder ) diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 6efb879f..4bf8d398 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -75,7 +75,7 @@ def test_override_options(self): assert not jws.options['verify_signature'] def test_non_object_options_dont_persist(self, jws, payload): - token = jws.encode_bytes(payload, 'secret') + token = jws.encode(payload, 'secret') jws.decode(token, 'secret', options={'verify_signature': False}) @@ -87,14 +87,14 @@ def test_options_must_be_dict(self, jws): def test_encode_decode(self, jws, payload): secret = 'secret' - jws_message = jws.encode_bytes(payload, secret) + jws_message = jws.encode(payload, secret) decoded_payload = jws.decode(jws_message, secret) assert decoded_payload == payload def test_decode_fails_when_alg_is_not_on_method_algorithms_param(self, jws, payload): secret = 'secret' - jws_token = jws.encode_bytes(payload, secret, algorithm='HS256') + jws_token = jws.encode(payload, secret, algorithm='HS256') jws.decode(jws_token, secret) with pytest.raises(InvalidAlgorithmError): @@ -155,10 +155,10 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws): def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload): - jws.encode_bytes(payload, 'secret', algorithm='HS256') + jws.encode(payload, 'secret', algorithm='HS256') with pytest.raises(NotImplementedError) as context: - jws.encode_bytes(payload, None, algorithm='hs256') + jws.encode(payload, None, algorithm='hs256') exception = context.value assert str(exception) == 'Algorithm not supported' @@ -177,7 +177,7 @@ def test_decode_algorithm_param_should_be_case_sensitive(self, jws): def test_bad_secret(self, jws, payload): right_secret = 'foo' bad_secret = 'bar' - jws_message = jws.encode_bytes(payload, right_secret) + jws_message = jws.encode(payload, right_secret) with pytest.raises(DecodeError) as excinfo: # Backward compat for ticket #315 @@ -258,7 +258,7 @@ def test_load_verify_valid_jws(self, jws, payload): def test_allow_skip_verification(self, jws, payload): right_secret = 'foo' - jws_message = jws.encode_bytes(payload, right_secret) + jws_message = jws.encode(payload, right_secret) decoded_payload = jws.decode(jws_message, verify=False) assert decoded_payload == payload @@ -301,7 +301,7 @@ def test_decode_no_algorithms_verify_signature_false(self, jws): def test_load_no_verification(self, jws, payload): right_secret = 'foo' - jws_message = jws.encode_bytes(payload, right_secret) + jws_message = jws.encode(payload, right_secret) decoded_payload = jws.decode(jws_message, key=None, verify=False) @@ -309,14 +309,14 @@ def test_load_no_verification(self, jws, payload): def test_no_secret(self, jws, payload): right_secret = 'foo' - jws_message = jws.encode_bytes(payload, right_secret) + jws_message = jws.encode(payload, right_secret) with pytest.raises(DecodeError): jws.decode(jws_message) def test_verify_signature_with_no_secret(self, jws, payload): right_secret = 'foo' - jws_message = jws.encode_bytes(payload, right_secret) + jws_message = jws.encode(payload, right_secret) with pytest.raises(DecodeError) as exc: jws.decode(jws_message) @@ -335,24 +335,24 @@ def test_verify_signature_with_no_algo_header_throws_exception(self, jws, payloa def test_invalid_crypto_alg(self, jws, payload): with pytest.raises(NotImplementedError): - jws.encode_bytes(payload, 'secret', algorithm='HS1024') + jws.encode(payload, 'secret', algorithm='HS1024') @pytest.mark.skipif(has_crypto, reason='Scenario requires cryptography to not be installed') def test_missing_crypto_library_better_error_messages(self, jws, payload): with pytest.raises(NotImplementedError) as excinfo: - jws.encode_bytes(payload, 'secret', algorithm='RS256') + jws.encode(payload, 'secret', algorithm='RS256') assert 'cryptography' in str(excinfo.value) def test_unicode_secret(self, jws, payload): secret = '\xc2' - jws_message = jws.encode_bytes(payload, secret) + jws_message = jws.encode(payload, secret) decoded_payload = jws.decode(jws_message, secret) assert decoded_payload == payload def test_nonascii_secret(self, jws, payload): secret = '\xc2' # char value that ascii codec cannot decode - jws_message = jws.encode_bytes(payload, secret) + jws_message = jws.encode(payload, secret) decoded_payload = jws.decode(jws_message, secret) @@ -360,7 +360,7 @@ def test_nonascii_secret(self, jws, payload): def test_bytes_secret(self, jws, payload): secret = b'\xc2' # char value that ascii codec cannot decode - jws_message = jws.encode_bytes(payload, secret) + jws_message = jws.encode(payload, secret) decoded_payload = jws.decode(jws_message, secret) @@ -415,18 +415,18 @@ def test_decode_invalid_crypto_padding(self, jws): assert 'Invalid crypto padding' in str(exc.value) def test_decode_with_algo_none_should_fail(self, jws, payload): - jws_message = jws.encode_bytes(payload, key=None, algorithm=None) + jws_message = jws.encode(payload, key=None, algorithm=None) with pytest.raises(DecodeError): jws.decode(jws_message) def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload): - jws_message = jws.encode_bytes(payload, key=None, algorithm=None) + jws_message = jws.encode(payload, key=None, algorithm=None) jws.decode(jws_message, verify=False) def test_get_unverified_header_returns_header_values(self, jws, payload): - jws_message = jws.encode_bytes(payload, key='secret', algorithm='HS256', - headers={'kid': 'toomanysecrets'}) + jws_message = jws.encode(payload, key='secret', algorithm='HS256', + headers={'kid': 'toomanysecrets'}) header = jws.get_unverified_header(jws_message) @@ -451,7 +451,7 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload): with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()), password=None, backend=default_backend()) - jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS256') + jws_message = jws.encode(payload, priv_rsakey, algorithm='RS256') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()), @@ -462,7 +462,7 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload): # string-formatted key with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: priv_rsakey = rsa_priv_file.read() - jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS256') + jws_message = jws.encode(payload, priv_rsakey, algorithm='RS256') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: pub_rsakey = rsa_pub_file.read() @@ -474,7 +474,7 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload): with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()), password=None, backend=default_backend()) - jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS384') + jws_message = jws.encode(payload, priv_rsakey, algorithm='RS384') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()), @@ -484,7 +484,7 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload): # string-formatted key with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: priv_rsakey = rsa_priv_file.read() - jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS384') + jws_message = jws.encode(payload, priv_rsakey, algorithm='RS384') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: pub_rsakey = rsa_pub_file.read() @@ -496,7 +496,7 @@ def test_encode_decode_with_rsa_sha512(self, jws, payload): with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()), password=None, backend=default_backend()) - jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS512') + jws_message = jws.encode(payload, priv_rsakey, algorithm='RS512') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()), @@ -506,7 +506,7 @@ def test_encode_decode_with_rsa_sha512(self, jws, payload): # string-formatted key with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: priv_rsakey = rsa_priv_file.read() - jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS512') + jws_message = jws.encode(payload, priv_rsakey, algorithm='RS512') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: pub_rsakey = rsa_pub_file.read() @@ -538,7 +538,7 @@ def test_encode_decode_with_ecdsa_sha256(self, jws, payload): with open('tests/keys/testkey_ec', 'r') as ec_priv_file: priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()), password=None, backend=default_backend()) - jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES256') + jws_message = jws.encode(payload, priv_eckey, algorithm='ES256') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), @@ -548,7 +548,7 @@ def test_encode_decode_with_ecdsa_sha256(self, jws, payload): # string-formatted key with open('tests/keys/testkey_ec', 'r') as ec_priv_file: priv_eckey = ec_priv_file.read() - jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES256') + jws_message = jws.encode(payload, priv_eckey, algorithm='ES256') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: pub_eckey = ec_pub_file.read() @@ -561,7 +561,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload): with open('tests/keys/testkey_ec', 'r') as ec_priv_file: priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()), password=None, backend=default_backend()) - jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES384') + jws_message = jws.encode(payload, priv_eckey, algorithm='ES384') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), @@ -571,7 +571,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload): # string-formatted key with open('tests/keys/testkey_ec', 'r') as ec_priv_file: priv_eckey = ec_priv_file.read() - jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES384') + jws_message = jws.encode(payload, priv_eckey, algorithm='ES384') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: pub_eckey = ec_pub_file.read() @@ -583,7 +583,7 @@ def test_encode_decode_with_ecdsa_sha512(self, jws, payload): with open('tests/keys/testkey_ec', 'r') as ec_priv_file: priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()), password=None, backend=default_backend()) - jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES521') + jws_message = jws.encode(payload, priv_eckey, algorithm='ES521') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), backend=default_backend()) @@ -592,7 +592,7 @@ def test_encode_decode_with_ecdsa_sha512(self, jws, payload): # string-formatted key with open('tests/keys/testkey_ec', 'r') as ec_priv_file: priv_eckey = ec_priv_file.read() - jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES521') + jws_message = jws.encode(payload, priv_eckey, algorithm='ES521') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: pub_eckey = ec_pub_file.read() @@ -618,7 +618,7 @@ def test_skip_check_signature(self, jws): jws.decode(token, 'secret', options={'verify_signature': False}) def test_decode_options_must_be_dict(self, jws, payload): - token = jws.encode_bytes(payload, 'secret') + token = jws.encode(payload, 'secret') with pytest.raises(TypeError): jws.decode(token, 'secret', options=object()) @@ -640,10 +640,10 @@ def default(self, o): } with pytest.raises(TypeError): - jws.encode_bytes(payload, 'secret', headers=data) + jws.encode(payload, 'secret', headers=data) - token = jws.encode_bytes(payload, 'secret', headers=data, - json_encoder=CustomJSONEncoder) + token = jws.encode(payload, 'secret', headers=data, + json_encoder=CustomJSONEncoder) header = force_bytes(force_unicode(token).split('.')[0]) header = json.loads(force_unicode(base64url_decode(header))) @@ -653,7 +653,7 @@ def default(self, o): def test_encode_headers_parameter_adds_headers(self, jws, payload): headers = {'testheader': True} - token = jws.encode_bytes(payload, 'secret', headers=headers) + token = jws.encode(payload, 'secret', headers=headers) if not isinstance(token, string_types): token = token.decode() @@ -671,11 +671,11 @@ def test_encode_headers_parameter_adds_headers(self, jws, payload): def test_encode_fails_on_invalid_kid_types(self, jws, payload): with pytest.raises(InvalidTokenError) as exc: - jws.encode_bytes(payload, 'secret', headers={'kid': 123}) + jws.encode(payload, 'secret', headers={'kid': 123}) assert 'Key ID header parameter must be a string' == str(exc.value) with pytest.raises(InvalidTokenError) as exc: - jws.encode_bytes(payload, 'secret', headers={'kid': None}) + jws.encode(payload, 'secret', headers={'kid': None}) assert 'Key ID header parameter must be a string' == str(exc.value)