diff --git a/lib/Crypto/Cipher/PKCS1_OAEP.py b/lib/Crypto/Cipher/PKCS1_OAEP.py index a4a9132e..e87487e0 100644 --- a/lib/Crypto/Cipher/PKCS1_OAEP.py +++ b/lib/Crypto/Cipher/PKCS1_OAEP.py @@ -23,11 +23,13 @@ from Crypto.Signature.pss import MGF1 import Crypto.Hash.SHA1 -from Crypto.Util.py3compat import bord, _copy_bytes +from Crypto.Util.py3compat import _copy_bytes import Crypto.Util.number -from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes -from Crypto.Util.strxor import strxor +from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes +from Crypto.Util.strxor import strxor from Crypto import Random +from ._pkcs1_oaep_decode import oaep_decode + class PKCS1OAEP_Cipher: """Cipher object for PKCS#1 v1.5 OAEP. @@ -68,7 +70,7 @@ def __init__(self, key, hashAlgo, mgfunc, label, randfunc): if mgfunc: self._mgf = mgfunc else: - self._mgf = lambda x,y: MGF1(x,y,self._hashObj) + self._mgf = lambda x, y: MGF1(x, y, self._hashObj) self._label = _copy_bytes(None, None, label) self._randfunc = randfunc @@ -105,7 +107,7 @@ def encrypt(self, message): # See 7.1.1 in RFC3447 modBits = Crypto.Util.number.size(self._key.n) - k = ceil_div(modBits, 8) # Convert from bits to bytes + k = ceil_div(modBits, 8) # Convert from bits to bytes hLen = self._hashObj.digest_size mLen = len(message) @@ -159,11 +161,11 @@ def decrypt(self, ciphertext): # See 7.1.2 in RFC3447 modBits = Crypto.Util.number.size(self._key.n) - k = ceil_div(modBits,8) # Convert from bits to bytes + k = ceil_div(modBits, 8) # Convert from bits to bytes hLen = self._hashObj.digest_size # Step 1b and 1c - if len(ciphertext) != k or k k: - size = _pkcs1_decode(em, b'', expected_pt_len, output) + size = pkcs1_decode(em, b'', expected_pt_len, output) if size < 0: return sentinel else: return output[size:] # Step 3 (somewhat constant time) - size = _pkcs1_decode(em, sentinel, expected_pt_len, output) + size = pkcs1_decode(em, sentinel, expected_pt_len, output) return output[size:] diff --git a/lib/Crypto/Cipher/_pkcs1_oaep_decode.py b/lib/Crypto/Cipher/_pkcs1_oaep_decode.py new file mode 100644 index 00000000..fc075282 --- /dev/null +++ b/lib/Crypto/Cipher/_pkcs1_oaep_decode.py @@ -0,0 +1,41 @@ +from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t, + c_uint8_ptr) + + +_raw_pkcs1_decode = load_pycryptodome_raw_lib("Crypto.Cipher._pkcs1_decode", + """ + int pkcs1_decode(const uint8_t *em, size_t len_em, + const uint8_t *sentinel, size_t len_sentinel, + size_t expected_pt_len, + uint8_t *output); + + int oaep_decode(const uint8_t *em, + size_t em_len, + const uint8_t *lHash, + size_t hLen, + const uint8_t *db, + size_t db_len); + """) + + +def pkcs1_decode(em, sentinel, expected_pt_len, output): + if len(em) != len(output): + raise ValueError("Incorrect output length") + + ret = _raw_pkcs1_decode.pkcs1_decode(c_uint8_ptr(em), + c_size_t(len(em)), + c_uint8_ptr(sentinel), + c_size_t(len(sentinel)), + c_size_t(expected_pt_len), + c_uint8_ptr(output)) + return ret + + +def oaep_decode(em, lHash, db): + ret = _raw_pkcs1_decode.oaep_decode(c_uint8_ptr(em), + c_size_t(len(em)), + c_uint8_ptr(lHash), + c_size_t(len(lHash)), + c_uint8_ptr(db), + c_size_t(len(db))) + return ret diff --git a/src/pkcs1_decode.c b/src/pkcs1_decode.c index d065cea0..5bf3a5ed 100644 --- a/src/pkcs1_decode.c +++ b/src/pkcs1_decode.c @@ -130,7 +130,7 @@ STATIC size_t safe_select_idx(size_t in1, size_t in2, uint8_t choice) * - in1[] is NOT equal to in2[] where neq_mask[] is 0xFF. * Return non-zero otherwise. */ -STATIC uint8_t safe_cmp(const uint8_t *in1, const uint8_t *in2, +STATIC uint8_t safe_cmp_masks(const uint8_t *in1, const uint8_t *in2, const uint8_t *eq_mask, const uint8_t *neq_mask, size_t len) { @@ -187,7 +187,7 @@ STATIC size_t safe_search(const uint8_t *in1, uint8_t c, size_t len) return result; } -#define EM_PREFIX_LEN 10 +#define PKCS1_PREFIX_LEN 10 /* * Decode and verify the PKCS#1 padding, then put either the plaintext @@ -222,13 +222,13 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output, if (NULL == em || NULL == output || NULL == sentinel) { return -1; } - if (len_em_output < (EM_PREFIX_LEN + 2)) { + if (len_em_output < (PKCS1_PREFIX_LEN + 2)) { return -1; } if (len_sentinel > len_em_output) { return -1; } - if (expected_pt_len > 0 && expected_pt_len > (len_em_output - EM_PREFIX_LEN - 1)) { + if (expected_pt_len > 0 && expected_pt_len > (len_em_output - PKCS1_PREFIX_LEN - 1)) { return -1; } @@ -240,7 +240,7 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output, memcpy(padded_sentinel + (len_em_output - len_sentinel), sentinel, len_sentinel); /** The first 10 bytes must follow the pattern **/ - match = safe_cmp(em, + match = safe_cmp_masks(em, (const uint8_t*)"\x00\x02" "\x00\x00\x00\x00\x00\x00\x00\x00", (const uint8_t*)"\xFF\xFF" "\x00\x00\x00\x00\x00\x00\x00\x00", (const uint8_t*)"\x00\x00" "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF", @@ -283,3 +283,72 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output, free(padded_sentinel); return result; } + +/* + * Decode and verify the OAEP padding in constant time. + * + * The function returns the number of bytes to ignore at the beginning + * of db (the rest is the plaintext), or -1 in case of problems. + */ + +EXPORT_SYM int oaep_decode(const uint8_t *em, + size_t em_len, + const uint8_t *lHash, + size_t hLen, + const uint8_t *db, + size_t db_len) /* em_len - 1 - hLen */ +{ + int result; + size_t one_pos, search_len, i; + uint8_t wrong_padding; + uint8_t *eq_mask = NULL; + uint8_t *neq_mask = NULL; + uint8_t *target_db = NULL; + + if (NULL == em || NULL == lHash || NULL == db) { + return -1; + } + + if (em_len < 2*hLen+2 || db_len != em_len-1-hLen) { + return -1; + } + + /* Allocate */ + eq_mask = (uint8_t*) calloc(1, db_len); + neq_mask = (uint8_t*) calloc(1, db_len); + target_db = (uint8_t*) calloc(1, db_len); + if (NULL == eq_mask || NULL == neq_mask || NULL == target_db) { + result = -1; + goto cleanup; + } + + /* Step 3g */ + search_len = db_len - hLen; + + one_pos = safe_search(db + hLen, 0x01, search_len); + if (SIZE_T_MAX == one_pos) { + result = -1; + goto cleanup; + } + + memset(eq_mask, 0xAA, db_len); + memcpy(target_db, lHash, hLen); + memset(eq_mask, 0xFF, hLen); + + for (i=0; i