Skip to content

Commit

Permalink
Merge pull request scitokens#80 from djw8605/autodetect-method
Browse files Browse the repository at this point in the history
Adding autodetection of the algorithm
  • Loading branch information
bbockelm authored Jul 30, 2018
2 parents 7943cda + da8de8a commit 1fbf4cd
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
39 changes: 38 additions & 1 deletion src/scitokens/scitokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .utils import config
from .utils.errors import MissingIssuerException, InvalidTokenFormat, MissingKeyException, UnsupportedKeyException
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from cryptography.hazmat.primitives.asymmetric import rsa, ec

class SciToken(object):
"""
Expand All @@ -41,10 +42,28 @@ def __init__(self, key=None, algorithm=None, key_id=None, parent=None, claims=No
raise NotImplementedError()

self._key = key
derived_alg = None
if key:
derived_alg = self._derive_algorithm(key)

# Make sure we support the key algorithm
if key and not algorithm and not derived_alg:
# We don't know the key algorithm
raise UnsupportedKeyException("Key was given for SciToken, but algorithm was not "
"passed to SciToken creation and it cannot be derived "
"from the provided key")
elif derived_alg and not algorithm:
self._key_alg = derived_alg
elif derived_alg and algorithm and derived_alg != algorithm:
error_str = ("Key provided reports algorithm type: {0}, ".format(derived_alg) +
"while scitoken creation argument was {0}".format(algorithm))
raise UnsupportedKeyException(error_str)
elif key and algorithm:
self._key_alg = algorithm
else:
# If key is not specified, and neither is algorithm
self._key_alg = algorithm if algorithm is not None else config.get('default_alg')

self._key_alg = algorithm if algorithm is not None else config.get('default_alg')
if self._key_alg not in ["RS256", "ES256"]:
raise UnsupportedKeyException()
self._key_id = key_id
Expand All @@ -54,6 +73,24 @@ def __init__(self, key=None, algorithm=None, key_id=None, parent=None, claims=No
self.insecure = False
self._serialized_token = None

@staticmethod
def _derive_algorithm(key):
"""
Derive the algorithm type from the PEM contents of the key
returns: Key algorithm if known, otherwise None
"""

if isinstance(key, rsa.RSAPrivateKey):
return "RS256"
elif isinstance(key, ec.EllipticCurvePrivateKey):
if key.curve.name == "secp256r1":
return "ES256"

# If it gets here, we don't know what type of key
return None


def claims(self):
"""
Return an iterator of (key, value) pairs of claims, starting
Expand Down
31 changes: 31 additions & 0 deletions tests/test_create_scitoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,37 @@ def test_unsupported_key(self):
with self.assertRaises(UnsupportedKeyException):
scitokens.SciToken(key = self._private_key, algorithm="doesnotexist")

def test_autodetect_keytype(self):
"""
Test the autodetection of the key type
"""
private_key = generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)

ec_private_key = ec.generate_private_key(
ec.SECP256R1(), default_backend()
)

# Test when we give it the wrong algorithm type
with self.assertRaises(scitokens.scitokens.UnsupportedKeyException):
token = scitokens.SciToken(key = private_key, algorithm="ES256")

# Test when we give it the wrong algorithm type
with self.assertRaises(scitokens.scitokens.UnsupportedKeyException):
token = scitokens.SciToken(key = ec_private_key, algorithm="RS256")

# Test when we give an unsupported algorithm
unsupported_private_key = ec.generate_private_key(
ec.SECP192R1(), default_backend()
)
with self.assertRaises(scitokens.scitokens.UnsupportedKeyException):
token = scitokens.SciToken(key = unsupported_private_key)

token = scitokens.SciToken(key = ec_private_key, algorithm="ES256")
token.serialize(issuer="local")

if __name__ == '__main__':
unittest.main()

0 comments on commit 1fbf4cd

Please sign in to comment.