Skip to content

Commit

Permalink
Add coverage and improve performance of is_ssh_key (#940)
Browse files Browse the repository at this point in the history
* Add coverage and improve performance of is_ssh_key

* simplify
  • Loading branch information
bdraco authored May 11, 2024
1 parent 97345a7 commit 2afbe32
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
17 changes: 3 additions & 14 deletions jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,26 +131,15 @@ def is_pem_format(key: bytes) -> bool:


# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
_CERT_SUFFIX = b"[email protected]"
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
_SSH_KEY_FORMATS = [
_SSH_KEY_FORMATS = (
b"ssh-ed25519",
b"ssh-rsa",
b"ssh-dss",
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
]
)


def is_ssh_key(key: bytes) -> bool:
if any(string_value in key for string_value in _SSH_KEY_FORMATS):
return True

ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
if ssh_pubkey_match:
key_type = ssh_pubkey_match.group(1)
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
return True

return False
return key.startswith(_SSH_KEY_FORMATS)
18 changes: 17 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from jwt.utils import force_bytes, from_base64url_uint, to_base64url_uint
from jwt.utils import force_bytes, from_base64url_uint, is_ssh_key, to_base64url_uint


@pytest.mark.parametrize(
Expand Down Expand Up @@ -37,3 +37,19 @@ def test_from_base64url_uint(inputval, expected):
def test_force_bytes_raises_error_on_invalid_object():
with pytest.raises(TypeError):
force_bytes({}) # type: ignore[arg-type]


@pytest.mark.parametrize(
"key_format",
(
b"ssh-ed25519",
b"ssh-rsa",
b"ssh-dss",
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
),
)
def test_is_ssh_key(key_format):
assert is_ssh_key(key_format + b" any") is True
assert is_ssh_key(b"not a ssh key") is False

0 comments on commit 2afbe32

Please sign in to comment.