diff --git a/src/decode.rs b/src/decode.rs index 5230fd3..0f66c74 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -9,18 +9,20 @@ use std::error; #[derive(Clone, Debug, PartialEq, Eq)] pub enum DecodeError { /// An invalid byte was found in the input. The offset and offending byte are provided. - /// Padding characters (`=`) interspersed in the encoded form will be treated as invalid bytes. + /// + /// Padding characters (`=`) interspersed in the encoded form are invalid, as they may only + /// be present as the last 0-2 bytes of input. + /// + /// This error may also indicate that extraneous trailing input bytes are present, causing + /// otherwise valid padding to no longer be the last bytes of input. InvalidByte(usize, u8), - /// The length of the input is invalid. - /// A typical cause of this is stray trailing whitespace or other separator bytes. - /// In the case where excess trailing bytes have produced an invalid length *and* the last byte - /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte` - /// will be emitted instead of `InvalidLength` to make the issue easier to debug. - InvalidLength, + /// The length of the input, as measured in valid base64 symbols, is invalid. + /// There must be 2-4 symbols in the last input quad. + InvalidLength(usize), /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. /// This is indicative of corrupted or truncated Base64. - /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for - /// symbols that are in the alphabet but represent nonsensical encodings. + /// Unlike [DecodeError::InvalidByte], which reports symbols that aren't in the alphabet, + /// this error is for symbols that are in the alphabet but represent nonsensical encodings. InvalidLastSymbol(usize, u8), /// The nature of the padding was not as configured: absent or incorrect when it must be /// canonical, or present when it must be absent, etc. @@ -30,8 +32,10 @@ pub enum DecodeError { impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Self::InvalidByte(index, byte) => write!(f, "Invalid byte {}, offset {}.", byte, index), - Self::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), + Self::InvalidByte(index, byte) => { + write!(f, "Invalid symbol {}, offset {}.", byte, index) + } + Self::InvalidLength(len) => write!(f, "Invalid input length: {}", len), Self::InvalidLastSymbol(index, byte) => { write!(f, "Invalid last symbol {}, offset {}.", byte, index) } diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 21a386f..31c289e 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -3,45 +3,25 @@ use crate::{ DecodeError, PAD_BYTE, }; -// decode logic operates on chunks of 8 input bytes without padding -const INPUT_CHUNK_LEN: usize = 8; -const DECODED_CHUNK_LEN: usize = 6; - -// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last -// 2 bytes of any output u64 should not be counted as written to (but must be available in a -// slice). -const DECODED_CHUNK_SUFFIX: usize = 2; - -// how many u64's of input to handle at a time -const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; - -const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; - -// includes the trailing 2 bytes for the final u64 write -const DECODED_BLOCK_LEN: usize = - CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; - #[doc(hidden)] pub struct GeneralPurposeEstimate { - /// Total number of decode chunks, including a possibly partial last chunk - num_chunks: usize, - decoded_len_estimate: usize, + rem: usize, + conservative_len: usize, } impl GeneralPurposeEstimate { pub(crate) fn new(encoded_len: usize) -> Self { - // Formulas that won't overflow + let rem = encoded_len % 4; Self { - num_chunks: encoded_len / INPUT_CHUNK_LEN - + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, - decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, + rem, + conservative_len: (encoded_len / 4 + (rem > 0) as usize) * 3, } } } impl DecodeEstimate for GeneralPurposeEstimate { fn decoded_len_estimate(&self) -> usize { - self.decoded_len_estimate + self.conservative_len } } @@ -59,264 +39,237 @@ pub(crate) fn decode_helper( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { - let remainder_len = input.len() % INPUT_CHUNK_LEN; - - // Because the fast decode loop writes in groups of 8 bytes (unrolled to - // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of - // which only 6 are valid data), we need to be sure that we stop using the fast decode loop - // soon enough that there will always be 2 more bytes of valid data written after that loop. - let trailing_bytes_to_skip = match remainder_len { - // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, - // and the fast decode logic cannot handle padding - 0 => INPUT_CHUNK_LEN, - // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte - 1 | 5 => { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } - } - - return Err(DecodeError::InvalidLength); + // detect a trailing invalid byte, like a newline, as a user convenience + if estimate.rem == 1 { + let last_byte = input[input.len() - 1]; + // exclude pad bytes; might be part of padding that extends from earlier in the input + if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); } - // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes - // written by the fast decode loop. So, we have to ignore both these 2 bytes and the - // previous chunk. - 2 => INPUT_CHUNK_LEN + 2, - // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this - // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail - // with an error, not panic from going past the bounds of the output slice, so we let it - // use stage 3 + 4. - 3 => INPUT_CHUNK_LEN + 3, - // This can also decode to one output byte because it may be 2 input chars + 2 padding - // chars, which would decode to 1 byte. - 4 => INPUT_CHUNK_LEN + 4, - // Everything else is a legal decode len (given that we don't require padding), and will - // decode to at least 2 bytes of output. - _ => remainder_len, - }; - - // rounded up to include partial chunks - let mut remaining_chunks = estimate.num_chunks; - - let mut input_index = 0; - let mut output_index = 0; + } + // skip last quad, even if it's complete, as it may have padding + let input_complete_nonterminal_quads_len = input + .len() + .saturating_sub(estimate.rem) + // if rem was 0, subtract 4 to avoid padding + .saturating_sub((estimate.rem == 0) as usize * 4); + debug_assert!( + input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) + ); + + const UNROLLED_INPUT_CHUNK_SIZE: usize = 32; + const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3; + + let input_complete_quads_after_unrolled_chunks_len = + input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE; + + let input_unrolled_loop_len = + input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len; + + // chunks of 32 bytes + for (chunk_index, chunk) in input[..input_unrolled_loop_len] + .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE) + .enumerate() { - let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); - - // Fast loop, stage 1 - // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { - while input_index <= max_start_index { - let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; - let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; - - decode_chunk( - &input_slice[0..], - input_index, - decode_table, - &mut output_slice[0..], - )?; - decode_chunk( - &input_slice[8..], - input_index + 8, - decode_table, - &mut output_slice[6..], - )?; - decode_chunk( - &input_slice[16..], - input_index + 16, - decode_table, - &mut output_slice[12..], - )?; - decode_chunk( - &input_slice[24..], - input_index + 24, - decode_table, - &mut output_slice[18..], - )?; - - input_index += INPUT_BLOCK_LEN; - output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; - remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; - } - } - - // Fast loop, stage 2 (aka still pretty fast loop) - // 8 bytes at a time for whatever we didn't do in stage 1. - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { - while input_index < max_start_index { - decode_chunk( - &input[input_index..(input_index + INPUT_CHUNK_LEN)], - input_index, - decode_table, - &mut output - [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], - )?; - - output_index += DECODED_CHUNK_LEN; - input_index += INPUT_CHUNK_LEN; - remaining_chunks -= 1; - } - } - } + let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE; + let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE + ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE]; - // Stage 3 - // If input length was such that a chunk had to be deferred until after the fast loop - // because decoding it would have produced 2 trailing bytes that wouldn't then be - // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 - // trailing bytes. - // However, we still need to avoid the last chunk (partial or complete) because it could - // have padding, so we always do 1 fewer to avoid the last chunk. - for _ in 1..remaining_chunks { - decode_chunk_precise( - &input[input_index..], + decode_chunk_8( + &chunk[0..8], input_index, decode_table, - &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], + &mut chunk_output[0..6], + )?; + decode_chunk_8( + &chunk[8..16], + input_index + 8, + decode_table, + &mut chunk_output[6..12], + )?; + decode_chunk_8( + &chunk[16..24], + input_index + 16, + decode_table, + &mut chunk_output[12..18], + )?; + decode_chunk_8( + &chunk[24..32], + input_index + 24, + decode_table, + &mut chunk_output[18..24], )?; - - input_index += INPUT_CHUNK_LEN; - output_index += DECODED_CHUNK_LEN; } - // always have one more (possibly partial) block of 8 input - debug_assert!(input.len() - input_index > 1 || input.is_empty()); - debug_assert!(input.len() - input_index <= 8); + // remaining quads, except for the last possibly partial one, as it may have padding + let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3; + let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3; + { + let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len]; + + for (chunk_index, chunk) in input + [input_unrolled_loop_len..input_complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3]; + + decode_chunk_4( + chunk, + input_unrolled_loop_len + chunk_index * 4, + decode_table, + chunk_output, + )?; + } + } super::decode_suffix::decode_suffix( input, - input_index, + input_complete_nonterminal_quads_len, output, - output_index, + output_complete_quad_len, decode_table, decode_allow_trailing_bits, padding_mode, ) } -/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the -/// first 6 of those contain meaningful data. +/// Decode 8 bytes of input into 6 bytes of output. /// -/// `input` is the bytes to decode, of which the first 8 bytes will be processed. +/// `input` is the 8 bytes to decode. /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors /// accurately) /// `decode_table` is the lookup table for the particular base64 alphabet. -/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded -/// data. +/// `output` will have its first 6 bytes overwritten // yes, really inline (worth 30-50% speedup) #[inline(always)] -fn decode_chunk( +fn decode_chunk_8( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let morsel = decode_table[input[0] as usize]; + let morsel = decode_table[usize::from(input[0])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); } - let mut accum = (morsel as u64) << 58; + let mut accum = u64::from(morsel) << 58; - let morsel = decode_table[input[1] as usize]; + let morsel = decode_table[usize::from(input[1])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 1, input[1], )); } - accum |= (morsel as u64) << 52; + accum |= u64::from(morsel) << 52; - let morsel = decode_table[input[2] as usize]; + let morsel = decode_table[usize::from(input[2])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 2, input[2], )); } - accum |= (morsel as u64) << 46; + accum |= u64::from(morsel) << 46; - let morsel = decode_table[input[3] as usize]; + let morsel = decode_table[usize::from(input[3])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 3, input[3], )); } - accum |= (morsel as u64) << 40; + accum |= u64::from(morsel) << 40; - let morsel = decode_table[input[4] as usize]; + let morsel = decode_table[usize::from(input[4])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 4, input[4], )); } - accum |= (morsel as u64) << 34; + accum |= u64::from(morsel) << 34; - let morsel = decode_table[input[5] as usize]; + let morsel = decode_table[usize::from(input[5])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 5, input[5], )); } - accum |= (morsel as u64) << 28; + accum |= u64::from(morsel) << 28; - let morsel = decode_table[input[6] as usize]; + let morsel = decode_table[usize::from(input[6])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 6, input[6], )); } - accum |= (morsel as u64) << 22; + accum |= u64::from(morsel) << 22; - let morsel = decode_table[input[7] as usize]; + let morsel = decode_table[usize::from(input[7])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 7, input[7], )); } - accum |= (morsel as u64) << 16; + accum |= u64::from(morsel) << 16; - write_u64(output, accum); + output[..6].copy_from_slice(&accum.to_be_bytes()[..6]); Ok(()) } -/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 -/// trailing garbage bytes. -#[inline] -fn decode_chunk_precise( +/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output. +#[inline(always)] +fn decode_chunk_4( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let mut tmp_buf = [0_u8; 8]; + let morsel = decode_table[usize::from(input[0])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = u32::from(morsel) << 26; - decode_chunk( - input, - index_at_start_of_input, - decode_table, - &mut tmp_buf[..], - )?; + let morsel = decode_table[usize::from(input[1])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= u32::from(morsel) << 20; + + let morsel = decode_table[usize::from(input[2])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= u32::from(morsel) << 14; + + let morsel = decode_table[usize::from(input[3])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= u32::from(morsel) << 8; - output[0..6].copy_from_slice(&tmp_buf[0..6]); + output[..3].copy_from_slice(&accum.to_be_bytes()[..3]); Ok(()) } -#[inline] -fn write_u64(output: &mut [u8], value: u64) { - output[..8].copy_from_slice(&value.to_be_bytes()); -} - #[cfg(test)] mod tests { use super::*; @@ -324,37 +277,36 @@ mod tests { use crate::engine::general_purpose::STANDARD; #[test] - fn decode_chunk_precise_writes_only_6_bytes() { + fn decode_chunk_8_writes_only_6_bytes() { let input = b"Zm9vYmFy"; // "foobar" let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; - decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); } #[test] - fn decode_chunk_writes_8_bytes() { - let input = b"Zm9vYmFy"; // "foobar" - let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + fn decode_chunk_4_writes_only_3_bytes() { + let input = b"Zm9v"; // "foobar" + let mut output = [0_u8, 1, 2, 3]; - decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); - assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); + decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', 3], &output); } #[test] fn estimate_short_lengths() { - for (range, (num_chunks, decoded_len_estimate)) in [ - (0..=0, (0, 0)), - (1..=4, (1, 3)), - (5..=8, (1, 6)), - (9..=12, (2, 9)), - (13..=16, (2, 12)), - (17..=20, (3, 15)), + for (range, decoded_len_estimate) in [ + (0..=0, 0), + (1..=4, 3), + (5..=8, 6), + (9..=12, 9), + (13..=16, 12), + (17..=20, 15), ] { for encoded_len in range { let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!(num_chunks, estimate.num_chunks); - assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); + assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate()); } } } @@ -369,15 +321,7 @@ mod tests { let len_128 = encoded_len as u128; let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!( - ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) - as usize, - estimate.num_chunks - ); - assert_eq!( - ((len_128 + 3) / 4 * 3) as usize, - estimate.decoded_len_estimate - ); + assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_len as u128); }) } } diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index 9fbb0d5..3d52ae5 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -3,7 +3,7 @@ use crate::{ DecodeError, PAD_BYTE, }; -/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided +/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided /// parameters. /// /// Returns the decode metadata representing the total number of bytes decoded, including the ones @@ -17,16 +17,18 @@ pub(crate) fn decode_suffix( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { + debug_assert!((input.len() - input_index) <= 4); + // Decode any leftovers that might not be a complete input chunk of 8 bytes. // Use a u64 as a stack-resident 8 byte buffer. let mut morsels_in_leftover = 0; - let mut padding_bytes = 0; - let mut first_padding_index: usize = 0; + let mut padding_bytes_count = 0; + // offset from input_index + let mut first_padding_offset: usize = 0; let mut last_symbol = 0_u8; - let start_of_leftovers = input_index; - let mut morsels = [0_u8; 8]; + let mut morsels = [0_u8; 4]; - for (i, &b) in input[start_of_leftovers..].iter().enumerate() { + for (leftover_index, &b) in input[input_index..].iter().enumerate() { // '=' padding if b == PAD_BYTE { // There can be bad padding bytes in a few ways: @@ -41,30 +43,30 @@ pub(crate) fn decode_suffix( // Per config, non-canonical but still functional non- or partially-padded base64 // may be treated as an error condition. - if i % 4 < 2 { + if leftover_index < 2 { // Check for case #2. - let bad_padding_index = start_of_leftovers - + if padding_bytes > 0 { + let bad_padding_index = input_index + + if padding_bytes_count > 0 { // If we've already seen padding, report the first padding index. // This is to be consistent with the normal decode logic: it will report an // error on the first padding character (since it doesn't expect to see // anything but actual encoded data). // This could only happen if the padding started in the previous quad since - // otherwise this case would have been hit at i % 4 == 0 if it was the same + // otherwise this case would have been hit at i == 4 if it was the same // quad. - first_padding_index + first_padding_offset } else { // haven't seen padding before, just use where we are now - i + leftover_index }; return Err(DecodeError::InvalidByte(bad_padding_index, b)); } - if padding_bytes == 0 { - first_padding_index = i; + if padding_bytes_count == 0 { + first_padding_offset = leftover_index; } - padding_bytes += 1; + padding_bytes_count += 1; continue; } @@ -72,9 +74,9 @@ pub(crate) fn decode_suffix( // To make '=' handling consistent with the main loop, don't allow // non-suffix '=' in trailing chunk either. Report error as first // erroneous padding. - if padding_bytes > 0 { + if padding_bytes_count > 0 { return Err(DecodeError::InvalidByte( - start_of_leftovers + first_padding_index, + input_index + first_padding_offset, PAD_BYTE, )); } @@ -85,22 +87,31 @@ pub(crate) fn decode_suffix( // Pack the leftovers from left to right. let morsel = decode_table[b as usize]; if morsel == INVALID_VALUE { - return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); + return Err(DecodeError::InvalidByte(input_index + leftover_index, b)); } morsels[morsels_in_leftover] = morsel; morsels_in_leftover += 1; } + // If there was 1 trailing byte, and it was valid, and we got to this point without hitting + // an invalid byte, now we can report invalid length + if !input.is_empty() && morsels_in_leftover < 2 { + return Err(DecodeError::InvalidLength( + input_index + morsels_in_leftover, + )); + } + match padding_mode { DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } DecodePaddingMode::RequireCanonical => { - if (padding_bytes + morsels_in_leftover) % 4 != 0 { + // allow empty input + if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { return Err(DecodeError::InvalidPadding); } } DecodePaddingMode::RequireNone => { - if padding_bytes > 0 { + if padding_bytes_count > 0 { // check at the end to make sure we let the cases of padding that should be InvalidByte // get hit return Err(DecodeError::InvalidPadding); @@ -120,27 +131,21 @@ pub(crate) fn decode_suffix( // useless since there are no more symbols to provide the necessary 4 additional bits // to finish the second original byte. - // TODO how do we know this? - debug_assert!(morsels_in_leftover != 1 && morsels_in_leftover != 5); let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; // Put the up to 6 complete bytes as the high bytes. // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. - let mut leftover_num = ((u64::from(morsels[0]) << 58) - | (u64::from(morsels[1]) << 52) - | (u64::from(morsels[2]) << 46) - | (u64::from(morsels[3]) << 40)) - | ((u64::from(morsels[4]) << 34) - | (u64::from(morsels[5]) << 28) - | (u64::from(morsels[6]) << 22) - | (u64::from(morsels[7]) << 16)); + let mut leftover_num = (u32::from(morsels[0]) << 26) + | (u32::from(morsels[1]) << 20) + | (u32::from(morsels[2]) << 14) + | (u32::from(morsels[3]) << 8); // if there are bits set outside the bits we care about, last symbol encodes trailing bits that // will not be included in the output - let mask = !0 >> (leftover_bytes_to_append * 8); + let mask = !0_u32 >> (leftover_bytes_to_append * 8); if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { // last morsel is at `morsels_in_leftover` - 1 return Err(DecodeError::InvalidLastSymbol( - start_of_leftovers + morsels_in_leftover - 1, + input_index + morsels_in_leftover - 1, last_symbol, )); } @@ -148,16 +153,17 @@ pub(crate) fn decode_suffix( // Strangely, this approach benchmarks better than writing bytes one at a time, // or copy_from_slice into output. for _ in 0..leftover_bytes_to_append { - let hi_byte = (leftover_num >> 56) as u8; + let hi_byte = (leftover_num >> 24) as u8; leftover_num <<= 8; + // TODO use checked writes output[output_index] = hi_byte; output_index += 1; } Ok(DecodeMetadata::new( output_index, - if padding_bytes > 0 { - Some(input_index + first_padding_index) + if padding_bytes_count > 0 { + Some(input_index + first_padding_offset) } else { None }, diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 6a50cbe..2546a6f 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -115,15 +115,12 @@ impl Engine for Naive { if estimate.rem == 1 { // trailing whitespace is so common that it's worth it to check the last byte to // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE - && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE - { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } + let last_byte = input[input.len() - 1]; + if last_byte != PAD_BYTE + && self.decode_table[usize::from(last_byte)] == general_purpose::INVALID_VALUE + { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); } - - return Err(DecodeError::InvalidLength); } let mut input_index = 0_usize; diff --git a/src/engine/tests.rs b/src/engine/tests.rs index b048005..b73f108 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -365,26 +365,49 @@ fn decode_detect_invalid_last_symbol(engine_wrapper: E) { } #[apply(all_engines)] -fn decode_detect_invalid_last_symbol_when_length_is_also_invalid( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // check across enough lengths that it would likely cover any implementation's various internal - // small/large input division +fn decode_detect_1_valid_symbol_in_last_quad_invalid_length(engine_wrapper: E) { for len in (0_usize..256).map(|len| len * 4 + 1) { - let engine = E::random_alphabet(&mut rng, &STANDARD); + for mode in all_pad_modes() { + let mut input = vec![b'A'; len]; - let mut input = vec![b'A'; len]; + let engine = E::standard_with_pad_mode(true, mode); - // with a valid last char, it's InvalidLength - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input)); - // after mangling the last char, it's InvalidByte - input[len - 1] = b'"'; - assert_eq!( - Err(DecodeError::InvalidByte(len - 1, b'"')), - engine.decode(&input) - ); + assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input)); + // if we add padding, then the first pad byte in the quad is invalid because it should + // be the second symbol + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(len, PAD_BYTE)), + engine.decode(&input) + ); + } + } + } +} + +#[apply(all_engines)] +fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte(engine_wrapper: E) { + for prefix_len in (0_usize..256).map(|len| len * 4) { + for mode in all_pad_modes() { + let mut input = vec![b'A'; prefix_len]; + input.push(b'*'); + + let engine = E::standard_with_pad_mode(true, mode); + + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + // adding padding doesn't matter + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + } + } } } @@ -471,8 +494,10 @@ fn decode_detect_invalid_last_symbol_every_possible_three_symbols(engine_wrapper: E) { /// Any amount of padding anywhere before the final non padding character = invalid byte at first /// pad byte. -/// From this, we know padding must extend to the end of the input. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_before_final_non_padding_char_error_invalid_byte( +/// From this and [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix_all_modes], +/// we know padding must extend contiguously to the end of the input. +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes< + E: EngineWrapper, +>( engine_wrapper: E, ) { - let mut rng = seeded_rng(); + // Different amounts of padding, w/ offset from end for the last non-padding char. + // Only canonical padding, so Canonical mode will work. + let suffixes = &[("AA==", 2), ("AAA=", 1), ("AAAA", 0)]; - // the different amounts of proper padding, w/ offset from end for the last non-padding char - let suffixes = [("/w==", 2), ("iYu=", 1), ("zzzz", 0)]; + for mode in pad_modes_allowing_padding() { + // We don't encode, so we don't care about encode padding. + let engine = E::standard_with_pad_mode(true, mode); - let prefix_quads_range = distributions::Uniform::from(0..=256); + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ); + } +} - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); +/// See [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes] +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix< + E: EngineWrapper, +>( + engine_wrapper: E, +) { + // Different amounts of padding, w/ offset from end for the last non-padding char, and + // non-canonical padding. + let suffixes = [ + ("AA==", 2), + ("AA=", 1), + ("AA", 0), + ("AAA=", 1), + ("AAA", 0), + ("AAAA", 0), + ]; - for _ in 0..100_000 { - for (suffix, offset) in suffixes.iter() { - let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng)); - s.push_str(suffix); - let mut encoded = s.into_bytes(); + // We don't encode, so we don't care about encode padding. + // Decoding is indifferent so that we don't get caught by missing padding on the last quad + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); + + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ) +} + +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine: impl Engine, + suffixes: &[(&str, usize)], +) { + let mut rng = seeded_rng(); - // calculate a range to write padding into that leaves at least one non padding char - let last_non_padding_offset = encoded.len() - 1 - offset; + let prefix_quads_range = distributions::Uniform::from(0..=256); - // don't include last non padding char as it must stay not padding - let padding_end = rng.gen_range(0..last_non_padding_offset); + for _ in 0..100_000 { + for (suffix, suffix_offset) in suffixes.iter() { + let mut s = "AAAA".repeat(prefix_quads_range.sample(&mut rng)); + s.push_str(suffix); + let mut encoded = s.into_bytes(); - // don't use more than 100 bytes of padding, but also use shorter lengths when - // padding_end is near the start of the encoded data to avoid biasing to padding - // the entire prefix on short lengths - let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); - let padding_start = padding_end.saturating_sub(padding_len); + // calculate a range to write padding into that leaves at least one non padding char + let last_non_padding_offset = encoded.len() - 1 - suffix_offset; - encoded[padding_start..=padding_end].fill(PAD_BYTE); + // don't include last non padding char as it must stay not padding + let padding_end = rng.gen_range(0..last_non_padding_offset); - assert_eq!( - Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), - engine.decode(&encoded), - ); - } + // don't use more than 100 bytes of padding, but also use shorter lengths when + // padding_end is near the start of the encoded data to avoid biasing to padding + // the entire prefix on short lengths + let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); + let padding_start = padding_end.saturating_sub(padding_len); + + encoded[padding_start..=padding_end].fill(PAD_BYTE); + + // should still have non-padding before any final padding + assert_ne!(PAD_BYTE, encoded[last_non_padding_offset]); + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + "len: {}, input: {}", + encoded.len(), + String::from_utf8(encoded).unwrap() + ); } } } -/// Any amount of padding before final chunk that crosses over into final chunk with 2-4 bytes = +/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes = /// invalid byte at first pad byte. -/// From this and [decode_padding_starts_before_final_chunk_error_invalid_length] we know the -/// padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_byte( +/// From this we know the padding must start in the final chunk. +#[apply(all_engines)] +fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad( engine_wrapper: E, ) { let mut rng = seeded_rng(); // must have at least one prefix quad let prefix_quads_range = distributions::Uniform::from(1..256); - // excluding 1 since we don't care about invalid length in this test - let suffix_pad_len_range = distributions::Uniform::from(2..=4); - for mode in all_pad_modes() { + let suffix_pad_len_range = distributions::Uniform::from(1..=4); + // don't use no-padding mode, as the reader decode might decode a block that ends with + // valid padding, which should then be referenced when encountering the later invalid byte + for mode in pad_modes_allowing_padding() { // we don't encode so we don't care about encode padding let engine = E::standard_with_pad_mode(true, mode); for _ in 0..100_000 { let suffix_len = suffix_pad_len_range.sample(&mut rng); - let mut encoded = "ABCD" + // all 0 bits so we don't hit InvalidLastSymbol with the reader decoder + let mut encoded = "AAAA" .repeat(prefix_quads_range.sample(&mut rng)) .into_bytes(); encoded.resize(encoded.len() + suffix_len, PAD_BYTE); @@ -705,40 +774,6 @@ fn decode_padding_starts_before_final_chunk_error_invalid_byte } } -/// Any amount of padding before final chunk that crosses over into final chunk with 1 byte = -/// invalid length. -/// From this we know the padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_length( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // must have at least one prefix quad - let prefix_quads_range = distributions::Uniform::from(1..256); - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - for _ in 0..100_000 { - let mut encoded = "ABCD" - .repeat(prefix_quads_range.sample(&mut rng)) - .into_bytes(); - encoded.resize(encoded.len() + 1, PAD_BYTE); - - // amount of padding must be long enough to extend back from suffix into previous - // quads - let padding_len = rng.gen_range(1 + 1..encoded.len()); - // no non-padding after padding in this test, so padding goes to the end - let padding_start = encoded.len() - padding_len; - encoded[padding_start..].fill(PAD_BYTE); - - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } - } -} - /// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding /// is not valid data (consistent with error for pad bytes in earlier chunks). /// From this we know there must be 2-3 bytes of data before padding @@ -756,29 +791,22 @@ fn decode_too_little_data_before_padding_error_invalid_byte(en let suffix_data_len = suffix_data_len_range.sample(&mut rng); let prefix_quad_len = prefix_quads_range.sample(&mut rng); - // ensure there is a suffix quad - let min_padding = usize::from(suffix_data_len == 0); - // for all possible padding lengths - for padding_len in min_padding..=(4 - suffix_data_len) { + for padding_len in 1..=(4 - suffix_data_len) { let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes(); encoded.resize(encoded.len() + suffix_data_len, b'A'); encoded.resize(encoded.len() + padding_len, PAD_BYTE); - if suffix_data_len + padding_len == 1 { - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } else { - assert_eq!( - Err(DecodeError::InvalidByte( - prefix_quad_len * 4 + suffix_data_len, - PAD_BYTE, - )), - engine.decode(&encoded), - "suffix data len {} pad len {}", - suffix_data_len, - padding_len - ); - } + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.decode(&encoded), + "suffix data len {} pad len {}", + suffix_data_len, + padding_len + ); } } } @@ -918,258 +946,64 @@ fn decode_pad_mode_indifferent_padding_accepts_anything(engine ); } -//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3 -// DecoderReader pseudo-engine finds the first padding, but doesn't report it as an error, -// because in the next decode it finds more padding, which is reported as InvalidByte, just -// with an offset at its position in the second decode, rather than being linked to the start -// of the padding that was first seen in the previous decode. -#[apply(all_engines_except_decoder_reader)] -fn decode_pad_byte_in_penultimate_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave room for at least one pad byte in penultimate quad - for num_valid_bytes_penultimate_quad in 0..4 { - // can't have 1 or it would be invalid length - for num_pad_bytes_in_final_quad in 2..=4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // varying amounts of padding in the penultimate quad - for _ in 0..num_valid_bytes_penultimate_quad { - s.push('A'); - } - // finish penultimate quad with padding - for _ in num_valid_bytes_penultimate_quad..4 { - s.push('='); - } - // and more padding in the final quad - for _ in 0..num_pad_bytes_in_final_quad { - s.push('='); - } - - // padding should be an invalid byte before the final quad. - // Could argue that the *next* padding byte (in the next quad) is technically the first - // erroneous one, but reporting that accurately is more complex and probably nobody cares - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + num_valid_bytes_penultimate_quad, - b'=', - ), - engine.decode(&s).unwrap_err(), - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_bytes_after_padding_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave at least one byte in the quad for padding - for bytes_after_padding in 1..4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding - for _ in 0..(3 - bytes_after_padding) { - s.push('A'); - } - s.push('='); - for _ in 0..bytes_after_padding { - s.push('A'); - } - - // First (and only) padding byte is invalid. - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + (3 - bytes_after_padding), - b'=' - ), - engine.decode(&s).unwrap_err() - ); - } - } - } -} - -#[apply(all_engines)] -fn decode_absurd_pad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("==Y=Wx===pY=2U====="); - - // first padding byte - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_too_much_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // add enough padding to ensure that we'll hit all decode stages at the different lengths - for pad_bytes in 1..=64 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - - if pad_bytes % 4 == 1 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_followed_by_non_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - for pad_bytes in 0..=32 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - s.push('E'); - - if pad_bytes % 4 == 0 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_one_char_in_final_quad_with_padding_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("E="); - - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - // more padding doesn't change the error - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -#[apply(all_engines)] -fn decode_too_few_symbols_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // <2 is invalid - for final_quad_symbols in 0..2 { - for padding_symbols in 0..=(4 - final_quad_symbols) { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - for _ in 0..final_quad_symbols { - s.push('A'); - } - for _ in 0..padding_symbols { - s.push('='); - } - - match final_quad_symbols + padding_symbols { - 0 => continue, - 1 => { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } - _ => { - // error reported at first padding byte - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + final_quad_symbols, - b'=', - ), - engine.decode(&s).unwrap_err() - ); - } - } - } - } - } - } -} - +/// 1 trailing byte that's not padding is detected as invalid byte even though there's padding +/// in the middle of the input. This is essentially mandating the eager check for 1 trailing byte +/// to catch the \n suffix case. // DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_bytes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_all_pad_modes_invalid_byte(engine_wrapper: E) { for mode in all_pad_modes() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } #[apply(all_engines)] -fn decode_invalid_trailing_bytes_all_modes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_invalid_byte(engine_wrapper: E) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } +fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { + for last_byte in [b'*', b'\n'] { + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("Cg=="); + let mut input = s.into_bytes(); + input.push(last_byte); + + // The case of trailing newlines is common enough to warrant a test for a good error + // message. + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + 4, + last_byte + )), + engine.decode(&input), + "mode: {:?}, input: {}", + mode, + String::from_utf8(input).unwrap() + ); + } + } +} +/// When there's 1 trailing byte, but it's padding, it's only InvalidByte if there isn't padding +/// earlier. #[apply(all_engines)] -fn decode_invalid_trailing_padding_as_invalid_length(engine_wrapper: E) { +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte( + engine_wrapper: E, +) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } @@ -1177,48 +1011,36 @@ fn decode_invalid_trailing_padding_as_invalid_length(engine_wr // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_padding_as_invalid_length_all_modes( +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes( engine_wrapper: E, ) { for mode in all_pad_modes() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } - -#[apply(all_engines)] -fn decode_wrong_length_error(engine_wrapper: E) { - let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); - +fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + engine: impl Engine, + mode: DecodePaddingMode, +) { for num_prefix_quads in 0..256 { - // at least one token, otherwise it wouldn't be a final quad - for num_tokens_final_quad in 1..=4 { - for num_padding in 0..=(4 - num_tokens_final_quad) { - let mut s: String = "IIII".repeat(num_prefix_quads); - for _ in 0..num_tokens_final_quad { - s.push('g'); - } - for _ in 0..num_padding { - s.push('='); - } + for (suffix, pad_offset) in [("AA===", 2), ("AAA==", 3), ("AAAA=", 4)] { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str(suffix); - let res = engine.decode(&s); - if num_tokens_final_quad >= 2 { - assert!(res.is_ok()); - } else if num_tokens_final_quad == 1 && num_padding > 0 { - // = is invalid if it's too early - assert_eq!( - Err(DecodeError::InvalidByte( - num_prefix_quads * 4 + num_tokens_final_quad, - 61 - )), - res - ); - } else if num_padding > 2 { - assert_eq!(Err(DecodeError::InvalidPadding), res); - } else { - assert_eq!(Err(DecodeError::InvalidLength), res); - } - } + assert_eq!( + // pad after `g`, not the last one + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + pad_offset, + PAD_BYTE + )), + engine.decode(&s), + "mode: {:?}, input: {}", + mode, + s + ); } } } @@ -1248,14 +1070,23 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); decode_buf.resize(input_len, 0); - // decode into the non-empty buf let decode_bytes_written = engine .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) .unwrap(); - assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); + + // TODO + // same for checked variant + // decode_buf.clear(); + // decode_buf.resize(input_len, 0); + // // decode into the non-empty buf + // let decode_bytes_written = engine + // .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]) + // .unwrap(); + // assert_eq!(orig_data.len(), decode_bytes_written); + // assert_eq!(orig_data, decode_buf); } } @@ -1355,38 +1186,6 @@ fn estimate_via_u128_inflation(engine_wrapper: E) { }) } -fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==\n"); - - // The case of trailing newlines is common enough to warrant a test for a good error - // message. - assert_eq!( - Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} - -fn do_invalid_trailing_padding_as_invalid_length(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==="); - - assert_eq!( - Err(DecodeError::InvalidLength), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} - /// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. /// /// Vecs provided should be empty. diff --git a/src/read/decoder.rs b/src/read/decoder.rs index b656ae3..125eeab 100644 --- a/src/read/decoder.rs +++ b/src/read/decoder.rs @@ -35,37 +35,39 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> { /// Where b64 data is read from inner: R, - // Holds b64 data read from the delegate reader. + /// Holds b64 data read from the delegate reader. b64_buffer: [u8; BUF_SIZE], - // The start of the pending buffered data in b64_buffer. + /// The start of the pending buffered data in `b64_buffer`. b64_offset: usize, - // The amount of buffered b64 data. + /// The amount of buffered b64 data after `b64_offset` in `b64_len`. b64_len: usize, - // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a - // decoded chunk in to, we have to be able to hang on to a few decoded bytes. - // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to - // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest - // into here, which seems like a lot of complexity for 1 extra byte of storage. - decoded_buffer: [u8; DECODED_CHUNK_SIZE], - // index of start of decoded data + /// Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a + /// decoded chunk in to, we have to be able to hang on to a few decoded bytes. + /// Technically we only need to hold 2 bytes, but then we'd need a separate temporary buffer to + /// decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest + /// into here, which seems like a lot of complexity for 1 extra byte of storage. + decoded_chunk_buffer: [u8; DECODED_CHUNK_SIZE], + /// Index of start of decoded data in `decoded_chunk_buffer` decoded_offset: usize, - // length of decoded data + /// Length of decoded data after `decoded_offset` in `decoded_chunk_buffer` decoded_len: usize, - // used to provide accurate offsets in errors - total_b64_decoded: usize, - // offset of previously seen padding, if any + /// Input length consumed so far. + /// Used to provide accurate offsets in errors + input_consumed_len: usize, + /// offset of previously seen padding, if any padding_offset: Option, } +// exclude b64_buffer as it's uselessly large impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("DecoderReader") .field("b64_offset", &self.b64_offset) .field("b64_len", &self.b64_len) - .field("decoded_buffer", &self.decoded_buffer) + .field("decoded_chunk_buffer", &self.decoded_chunk_buffer) .field("decoded_offset", &self.decoded_offset) .field("decoded_len", &self.decoded_len) - .field("total_b64_decoded", &self.total_b64_decoded) + .field("input_consumed_len", &self.input_consumed_len) .field("padding_offset", &self.padding_offset) .finish() } @@ -80,10 +82,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { b64_buffer: [0; BUF_SIZE], b64_offset: 0, b64_len: 0, - decoded_buffer: [0; DECODED_CHUNK_SIZE], + decoded_chunk_buffer: [0; DECODED_CHUNK_SIZE], decoded_offset: 0, decoded_len: 0, - total_b64_decoded: 0, + input_consumed_len: 0, padding_offset: None, } } @@ -100,7 +102,7 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { debug_assert!(copy_len <= self.decoded_len); buf[..copy_len].copy_from_slice( - &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len], + &self.decoded_chunk_buffer[self.decoded_offset..self.decoded_offset + copy_len], ); self.decoded_offset += copy_len; @@ -146,18 +148,22 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { ) .map_err(|e| match e { DecodeError::InvalidByte(offset, byte) => { - // This can be incorrect, but not in a way that probably matters to anyone: - // if there was padding handled in a previous decode, and we are now getting - // InvalidByte due to more padding, we should arguably report InvalidByte with - // PAD_BYTE at the original padding position (`self.padding_offset`), but we - // don't have a good way to tie those two cases together, so instead we - // just report the invalid byte as if the previous padding, and its possibly - // related downgrade to a now invalid byte, didn't happen. - DecodeError::InvalidByte(self.total_b64_decoded + offset, byte) + match (byte, self.padding_offset) { + // if there was padding in a previous block of decoding that happened to + // be correct, and we now find more padding that happens to be incorrect, + // to be consistent with non-reader decodes, record the error at the first + // padding + (PAD_BYTE, Some(first_pad_offset)) => { + DecodeError::InvalidByte(first_pad_offset, PAD_BYTE) + } + _ => DecodeError::InvalidByte(self.input_consumed_len + offset, byte), + } + } + DecodeError::InvalidLength(len) => { + DecodeError::InvalidLength(self.input_consumed_len + len) } - DecodeError::InvalidLength => DecodeError::InvalidLength, DecodeError::InvalidLastSymbol(offset, byte) => { - DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte) + DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte) } DecodeError::InvalidPadding => DecodeError::InvalidPadding, }) @@ -176,8 +182,8 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { self.padding_offset = self.padding_offset.or(decode_metadata .padding_offset - .map(|offset| self.total_b64_decoded + offset)); - self.total_b64_decoded += b64_len_to_decode; + .map(|offset| self.input_consumed_len + offset)); + self.input_consumed_len += b64_len_to_decode; self.b64_offset += b64_len_to_decode; self.b64_len -= b64_len_to_decode; @@ -283,7 +289,7 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> { let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE); let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?; - self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); + self.decoded_chunk_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); self.decoded_offset = 0; self.decoded_len = decoded;