Skip to content

Commit

Permalink
Use Union type to describe both types of payload
Browse files Browse the repository at this point in the history
  • Loading branch information
jacopofar committed Apr 8, 2018
1 parent 060ff60 commit af795a8
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 51 deletions.
18 changes: 9 additions & 9 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import Mapping
try:
# import required by mypy to perform type checking, not used for normal execution
from typing import Callable, Dict, List, Optional # NOQA
from typing import Callable, Dict, List, Optional, Union # NOQA
except ImportError:
pass

Expand Down Expand Up @@ -74,13 +74,13 @@ def get_algorithms(self):
"""
return list(self._valid_algs)

def encode_bytes(self,
payload, # type: bytes
key, # type: str
algorithm='HS256', # type: str
headers=None, # type: Optional[Dict]
json_encoder=None # type: Optional[Callable]
):
def encode(self,
payload, # type: Union[Dict, bytes]
key, # type: str
algorithm='HS256', # type: str
headers=None, # type: Optional[Dict]
json_encoder=None # type: Optional[Callable]
):
segments = []

if algorithm is None:
Expand Down Expand Up @@ -236,7 +236,7 @@ def _validate_kid(self, kid):


_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode_bytes
encode = _jws_global_obj.encode
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
Expand Down
8 changes: 4 additions & 4 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime, timedelta
try:
# import required by mypy to perform type checking, not used for normal execution
from typing import Callable, Dict, List, Optional # NOQA
from typing import Callable, Dict, List, Optional, Union # NOQA
except ImportError:
pass

Expand Down Expand Up @@ -39,7 +39,7 @@ def _get_default_options():
}

def encode(self,
payload, # type: Dict
payload, # type: Union[Dict, bytes]
key, # type: str
algorithm='HS256', # type: str
headers=None, # type: Optional[Dict]
Expand All @@ -54,15 +54,15 @@ def encode(self,
for time_claim in ['exp', 'iat', 'nbf']:
# Convert datetime to a intDate value in known time-format claims
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
payload[time_claim] = timegm(payload[time_claim].utctimetuple()) # type: ignore

json_payload = json.dumps(
payload,
separators=(',', ':'),
cls=json_encoder
).encode('utf-8')

return super(PyJWT, self).encode_bytes(
return super(PyJWT, self).encode(
json_payload, key, algorithm, headers, json_encoder
)

Expand Down
76 changes: 38 additions & 38 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_override_options(self):
assert not jws.options['verify_signature']

def test_non_object_options_dont_persist(self, jws, payload):
token = jws.encode_bytes(payload, 'secret')
token = jws.encode(payload, 'secret')

jws.decode(token, 'secret', options={'verify_signature': False})

Expand All @@ -87,14 +87,14 @@ def test_options_must_be_dict(self, jws):

def test_encode_decode(self, jws, payload):
secret = 'secret'
jws_message = jws.encode_bytes(payload, secret)
jws_message = jws.encode(payload, secret)
decoded_payload = jws.decode(jws_message, secret)

assert decoded_payload == payload

def test_decode_fails_when_alg_is_not_on_method_algorithms_param(self, jws, payload):
secret = 'secret'
jws_token = jws.encode_bytes(payload, secret, algorithm='HS256')
jws_token = jws.encode(payload, secret, algorithm='HS256')
jws.decode(jws_token, secret)

with pytest.raises(InvalidAlgorithmError):
Expand Down Expand Up @@ -155,10 +155,10 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):

def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):

jws.encode_bytes(payload, 'secret', algorithm='HS256')
jws.encode(payload, 'secret', algorithm='HS256')

with pytest.raises(NotImplementedError) as context:
jws.encode_bytes(payload, None, algorithm='hs256')
jws.encode(payload, None, algorithm='hs256')

exception = context.value
assert str(exception) == 'Algorithm not supported'
Expand All @@ -177,7 +177,7 @@ def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
def test_bad_secret(self, jws, payload):
right_secret = 'foo'
bad_secret = 'bar'
jws_message = jws.encode_bytes(payload, right_secret)
jws_message = jws.encode(payload, right_secret)

with pytest.raises(DecodeError) as excinfo:
# Backward compat for ticket #315
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_load_verify_valid_jws(self, jws, payload):

def test_allow_skip_verification(self, jws, payload):
right_secret = 'foo'
jws_message = jws.encode_bytes(payload, right_secret)
jws_message = jws.encode(payload, right_secret)
decoded_payload = jws.decode(jws_message, verify=False)

assert decoded_payload == payload
Expand Down Expand Up @@ -301,22 +301,22 @@ def test_decode_no_algorithms_verify_signature_false(self, jws):

def test_load_no_verification(self, jws, payload):
right_secret = 'foo'
jws_message = jws.encode_bytes(payload, right_secret)
jws_message = jws.encode(payload, right_secret)

decoded_payload = jws.decode(jws_message, key=None, verify=False)

assert decoded_payload == payload

def test_no_secret(self, jws, payload):
right_secret = 'foo'
jws_message = jws.encode_bytes(payload, right_secret)
jws_message = jws.encode(payload, right_secret)

with pytest.raises(DecodeError):
jws.decode(jws_message)

def test_verify_signature_with_no_secret(self, jws, payload):
right_secret = 'foo'
jws_message = jws.encode_bytes(payload, right_secret)
jws_message = jws.encode(payload, right_secret)

with pytest.raises(DecodeError) as exc:
jws.decode(jws_message)
Expand All @@ -335,32 +335,32 @@ def test_verify_signature_with_no_algo_header_throws_exception(self, jws, payloa

def test_invalid_crypto_alg(self, jws, payload):
with pytest.raises(NotImplementedError):
jws.encode_bytes(payload, 'secret', algorithm='HS1024')
jws.encode(payload, 'secret', algorithm='HS1024')

@pytest.mark.skipif(has_crypto, reason='Scenario requires cryptography to not be installed')
def test_missing_crypto_library_better_error_messages(self, jws, payload):
with pytest.raises(NotImplementedError) as excinfo:
jws.encode_bytes(payload, 'secret', algorithm='RS256')
jws.encode(payload, 'secret', algorithm='RS256')
assert 'cryptography' in str(excinfo.value)

def test_unicode_secret(self, jws, payload):
secret = '\xc2'
jws_message = jws.encode_bytes(payload, secret)
jws_message = jws.encode(payload, secret)
decoded_payload = jws.decode(jws_message, secret)

assert decoded_payload == payload

def test_nonascii_secret(self, jws, payload):
secret = '\xc2' # char value that ascii codec cannot decode
jws_message = jws.encode_bytes(payload, secret)
jws_message = jws.encode(payload, secret)

decoded_payload = jws.decode(jws_message, secret)

assert decoded_payload == payload

def test_bytes_secret(self, jws, payload):
secret = b'\xc2' # char value that ascii codec cannot decode
jws_message = jws.encode_bytes(payload, secret)
jws_message = jws.encode(payload, secret)

decoded_payload = jws.decode(jws_message, secret)

Expand Down Expand Up @@ -415,18 +415,18 @@ def test_decode_invalid_crypto_padding(self, jws):
assert 'Invalid crypto padding' in str(exc.value)

def test_decode_with_algo_none_should_fail(self, jws, payload):
jws_message = jws.encode_bytes(payload, key=None, algorithm=None)
jws_message = jws.encode(payload, key=None, algorithm=None)

with pytest.raises(DecodeError):
jws.decode(jws_message)

def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload):
jws_message = jws.encode_bytes(payload, key=None, algorithm=None)
jws_message = jws.encode(payload, key=None, algorithm=None)
jws.decode(jws_message, verify=False)

def test_get_unverified_header_returns_header_values(self, jws, payload):
jws_message = jws.encode_bytes(payload, key='secret', algorithm='HS256',
headers={'kid': 'toomanysecrets'})
jws_message = jws.encode(payload, key='secret', algorithm='HS256',
headers={'kid': 'toomanysecrets'})

header = jws.get_unverified_header(jws_message)

Expand All @@ -451,7 +451,7 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload):
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS256')
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS256')

with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()),
Expand All @@ -462,7 +462,7 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload):
# string-formatted key
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
priv_rsakey = rsa_priv_file.read()
jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS256')
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS256')

with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
pub_rsakey = rsa_pub_file.read()
Expand All @@ -474,7 +474,7 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload):
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS384')
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS384')

with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()),
Expand All @@ -484,7 +484,7 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload):
# string-formatted key
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
priv_rsakey = rsa_priv_file.read()
jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS384')
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS384')

with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
pub_rsakey = rsa_pub_file.read()
Expand All @@ -496,7 +496,7 @@ def test_encode_decode_with_rsa_sha512(self, jws, payload):
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS512')
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS512')

with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()),
Expand All @@ -506,7 +506,7 @@ def test_encode_decode_with_rsa_sha512(self, jws, payload):
# string-formatted key
with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file:
priv_rsakey = rsa_priv_file.read()
jws_message = jws.encode_bytes(payload, priv_rsakey, algorithm='RS512')
jws_message = jws.encode(payload, priv_rsakey, algorithm='RS512')

with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file:
pub_rsakey = rsa_pub_file.read()
Expand Down Expand Up @@ -538,7 +538,7 @@ def test_encode_decode_with_ecdsa_sha256(self, jws, payload):
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES256')
jws_message = jws.encode(payload, priv_eckey, algorithm='ES256')

with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()),
Expand All @@ -548,7 +548,7 @@ def test_encode_decode_with_ecdsa_sha256(self, jws, payload):
# string-formatted key
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
priv_eckey = ec_priv_file.read()
jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES256')
jws_message = jws.encode(payload, priv_eckey, algorithm='ES256')

with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
pub_eckey = ec_pub_file.read()
Expand All @@ -561,7 +561,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload):
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES384')
jws_message = jws.encode(payload, priv_eckey, algorithm='ES384')

with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()),
Expand All @@ -571,7 +571,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload):
# string-formatted key
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
priv_eckey = ec_priv_file.read()
jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES384')
jws_message = jws.encode(payload, priv_eckey, algorithm='ES384')

with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
pub_eckey = ec_pub_file.read()
Expand All @@ -583,7 +583,7 @@ def test_encode_decode_with_ecdsa_sha512(self, jws, payload):
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()),
password=None, backend=default_backend())
jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES521')
jws_message = jws.encode(payload, priv_eckey, algorithm='ES521')

with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), backend=default_backend())
Expand All @@ -592,7 +592,7 @@ def test_encode_decode_with_ecdsa_sha512(self, jws, payload):
# string-formatted key
with open('tests/keys/testkey_ec', 'r') as ec_priv_file:
priv_eckey = ec_priv_file.read()
jws_message = jws.encode_bytes(payload, priv_eckey, algorithm='ES521')
jws_message = jws.encode(payload, priv_eckey, algorithm='ES521')

with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file:
pub_eckey = ec_pub_file.read()
Expand All @@ -618,7 +618,7 @@ def test_skip_check_signature(self, jws):
jws.decode(token, 'secret', options={'verify_signature': False})

def test_decode_options_must_be_dict(self, jws, payload):
token = jws.encode_bytes(payload, 'secret')
token = jws.encode(payload, 'secret')

with pytest.raises(TypeError):
jws.decode(token, 'secret', options=object())
Expand All @@ -640,10 +640,10 @@ def default(self, o):
}

with pytest.raises(TypeError):
jws.encode_bytes(payload, 'secret', headers=data)
jws.encode(payload, 'secret', headers=data)

token = jws.encode_bytes(payload, 'secret', headers=data,
json_encoder=CustomJSONEncoder)
token = jws.encode(payload, 'secret', headers=data,
json_encoder=CustomJSONEncoder)

header = force_bytes(force_unicode(token).split('.')[0])
header = json.loads(force_unicode(base64url_decode(header)))
Expand All @@ -653,7 +653,7 @@ def default(self, o):

def test_encode_headers_parameter_adds_headers(self, jws, payload):
headers = {'testheader': True}
token = jws.encode_bytes(payload, 'secret', headers=headers)
token = jws.encode(payload, 'secret', headers=headers)

if not isinstance(token, string_types):
token = token.decode()
Expand All @@ -671,11 +671,11 @@ def test_encode_headers_parameter_adds_headers(self, jws, payload):

def test_encode_fails_on_invalid_kid_types(self, jws, payload):
with pytest.raises(InvalidTokenError) as exc:
jws.encode_bytes(payload, 'secret', headers={'kid': 123})
jws.encode(payload, 'secret', headers={'kid': 123})

assert 'Key ID header parameter must be a string' == str(exc.value)

with pytest.raises(InvalidTokenError) as exc:
jws.encode_bytes(payload, 'secret', headers={'kid': None})
jws.encode(payload, 'secret', headers={'kid': None})

assert 'Key ID header parameter must be a string' == str(exc.value)

0 comments on commit af795a8

Please sign in to comment.