diff --git a/src/pyotp/hotp.py b/src/pyotp/hotp.py index be9d536..84b6e9b 100644 --- a/src/pyotp/hotp.py +++ b/src/pyotp/hotp.py @@ -29,6 +29,11 @@ def __init__( """ if digest is None: digest = hashlib.sha1 + elif digest in [ + hashlib.md5, + hashlib.shake_128 + ]: + raise ValueError("selected digest function must generate digest size greater than or equals to 18 bytes") self.initial_count = initial_count super().__init__(s=s, digits=digits, digest=digest, name=name, issuer=issuer) diff --git a/src/pyotp/otp.py b/src/pyotp/otp.py index 9f0d1af..b339b8f 100644 --- a/src/pyotp/otp.py +++ b/src/pyotp/otp.py @@ -21,6 +21,11 @@ def __init__( if digits > 10: raise ValueError("digits must be no greater than 10") self.digest = digest + if digest in [ + hashlib.md5, + hashlib.shake_128 + ]: + raise ValueError("selected digest function must generate digest size greater than or equals to 18 bytes") self.secret = s self.name = name or "Secret" self.issuer = issuer @@ -33,6 +38,8 @@ def generate_otp(self, input: int) -> str: if input < 0: raise ValueError("input must be positive integer") hasher = hmac.new(self.byte_secret(), self.int_to_bytestring(input), self.digest) + if hasher.digest_size < 18: + raise ValueError("digest size is lower than 18 bytes, which will trigger error on otp generation") hmac_hash = bytearray(hasher.digest()) offset = hmac_hash[-1] & 0xF code = ( diff --git a/src/pyotp/totp.py b/src/pyotp/totp.py index fd49ed0..ae077bc 100644 --- a/src/pyotp/totp.py +++ b/src/pyotp/totp.py @@ -32,6 +32,11 @@ def __init__( """ if digest is None: digest = hashlib.sha1 + elif digest in [ + hashlib.md5, + hashlib.shake_128 + ]: + raise ValueError("selected digest function must generate digest size greater than or equals to 18 bytes") self.interval = interval super().__init__(s=s, digits=digits, digest=digest, name=name, issuer=issuer) diff --git a/test.py b/test.py index 04f7374..8c902af 100755 --- a/test.py +++ b/test.py @@ -334,6 +334,16 @@ def test_valid_window(self): self.assertTrue(totp.verify("681610", 200, 1)) self.assertFalse(totp.verify("195979", 200, 1)) +class DigestFunctionTest(unittest.TestCase): + def test_md5(self): + with self.assertRaises(ValueError) as cm: + pyotp.OTP(s="secret", digest=hashlib.md5) + self.assertEqual("selected digest function must generate digest size greater than or equals to 18 bytes", str(cm.exception)) + + def test_shake128(self): + with self.assertRaises(ValueError) as cm: + pyotp.OTP(s="secret", digest=hashlib.shake_128) + self.assertEqual("selected digest function must generate digest size greater than or equals to 18 bytes", str(cm.exception)) class ParseUriTest(unittest.TestCase): def test_invalids(self):