Skip to content

Commit

Permalink
Tweak UTF-8 validation 'roll back' logic (#621)
Browse files Browse the repository at this point in the history
Fixes #620.  If the last SIMD block ends with an incomplete multi-byte
code point, we must roll back a few bytes before handing the remainder
to the fallback.  The old logic would roll back by one byte even when
zero SIMD blocks had been processed, which is exactly the bug.
  • Loading branch information
clyring authored Oct 27, 2023
1 parent 6c880f3 commit 03733af
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 61 deletions.
28 changes: 16 additions & 12 deletions cbits/aarch64/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -260,20 +260,24 @@ int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
// 'Roll back' our pointer a little to prepare for a slow search of the rest.
uint32_t token;
vst1q_lane_u32(&token, vreinterpretq_u32_u8(prev_input), 3);
// We cast this pointer to avoid a redundant check against < 127, as any such
// value would be negative in signed form.
int8_t const *token_ptr = (int8_t const *)&token;
ptrdiff_t lookahead = 0;
if (token_ptr[3] > (int8_t)0xBF) {
lookahead = 1;
} else if (token_ptr[2] > (int8_t)0xBF) {
lookahead = 2;
} else if (token_ptr[1] > (int8_t)0xBF) {
lookahead = 3;
uint8_t const *token_ptr = (uint8_t const *)&token;
ptrdiff_t rollback = 0;
// We must not roll back if no big blocks were processed, as then
// the fallback function would examine out-of-bounds data (#620).
// In that case, prev_input contains only nulls and we skip the if body.
if (token_ptr[3] >= 0x80u) {
// Look for an incomplete multi-byte code point
if (token_ptr[3] >= 0xC0u) {
rollback = 1;
} else if (token_ptr[2] >= 0xE0u) {
rollback = 2;
} else if (token_ptr[1] >= 0xF0u) {
rollback = 3;
}
}
// Finish the job.
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
uint8_t const *const small_ptr = ptr - rollback;
size_t const small_len = remaining + rollback;
return is_valid_utf8_fallback(small_ptr, small_len);
}

Expand Down
58 changes: 36 additions & 22 deletions cbits/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ static int8_t const range_max_lookup[16] = {
// +------------+---------------+------------------+----------------+
// | F0 | 3 | 3 | 6 |
// +------------+---------------+------------------+----------------+
// | F4 | 4 | 4 | 8 |
// | F4 | 3 | 4 | 7 |
// +------------+---------------+------------------+----------------+
// index1 -> E0, index14 -> ED
static int8_t const df_ee_lookup[16] = {
Expand Down Expand Up @@ -498,20 +498,27 @@ is_valid_utf8_ssse3(uint8_t const *const src, size_t const len) {
return 0;
}
// 'Roll back' our pointer a little to prepare for a slow search of the rest.
int16_t tokens[2];
uint16_t tokens[2];
tokens[0] = _mm_extract_epi16(prev_input, 6);
tokens[1] = _mm_extract_epi16(prev_input, 7);
int8_t const *token_ptr = (int8_t const *)tokens;
ptrdiff_t lookahead = 0;
if (token_ptr[3] > (int8_t)0xBF) {
lookahead = 1;
} else if (token_ptr[2] > (int8_t)0xBF) {
lookahead = 2;
} else if (token_ptr[1] > (int8_t)0xBF) {
lookahead = 3;
uint8_t const *token_ptr = (uint8_t const *)tokens;
ptrdiff_t rollback = 0;
// We must not roll back if no big blocks were processed, as then
// the fallback function would examine out-of-bounds data (#620).
// In that case, prev_input contains only nulls and we skip the if body.
if (token_ptr[3] >= 0x80u) {
// Look for an incomplete multi-byte code point
if (token_ptr[3] >= 0xC0u) {
rollback = 1;
} else if (token_ptr[2] >= 0xE0u) {
rollback = 2;
} else if (token_ptr[1] >= 0xF0u) {
rollback = 3;
}
}
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
// Finish the job.
uint8_t const *const small_ptr = ptr - rollback;
size_t const small_len = remaining + rollback;
return is_valid_utf8_fallback(small_ptr, small_len);
}

Expand Down Expand Up @@ -704,17 +711,24 @@ is_valid_utf8_avx2(uint8_t const *const src, size_t const len) {
}
// 'Roll back' our pointer a little to prepare for a slow search of the rest.
uint32_t tokens_blob = _mm256_extract_epi32(prev_input, 7);
int8_t const *tokens = (int8_t const *)&tokens_blob;
ptrdiff_t lookahead = 0;
if (tokens[3] > (int8_t)0xBF) {
lookahead = 1;
} else if (tokens[2] > (int8_t)0xBF) {
lookahead = 2;
} else if (tokens[1] > (int8_t)0xBF) {
lookahead = 3;
uint8_t const *token_ptr = (uint8_t const *)&tokens_blob;
ptrdiff_t rollback = 0;
// We must not roll back if no big blocks were processed, as then
// the fallback function would examine out-of-bounds data (#620).
// In that case, prev_input contains only nulls and we skip the if body.
if (token_ptr[3] >= 0x80u) {
// Look for an incomplete multi-byte code point
if (token_ptr[3] >= 0xC0u) {
rollback = 1;
} else if (token_ptr[2] >= 0xE0u) {
rollback = 2;
} else if (token_ptr[1] >= 0xF0u) {
rollback = 3;
}
}
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
// Finish the job.
uint8_t const *const small_ptr = ptr - rollback;
size_t const small_len = remaining + rollback;
return is_valid_utf8_fallback(small_ptr, small_len);
}

Expand Down
63 changes: 36 additions & 27 deletions tests/IsValidUtf8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ checkRegressions = [
testProperty "Three invalid bytes between spaces" $
not $ B.isValidUtf8 threeBytesBetweenSpaces,
testProperty "ASCII stride and invalid multibyte sequence" $
not $ B.isValidUtf8 asciiAndInvalidMultiByte
not $ B.isValidUtf8 asciiAndInvalidMultiByte,
testProperty "Splitting valid in two" splitValid
]
where
tooHigh :: ByteString
Expand All @@ -68,13 +69,21 @@ checkRegressions = [
threeBytesBetweenSpaces = fromList $ replicate 125 32 ++ [242, 134, 159] ++ replicate 128 32

badBlockEnd :: Property
badBlockEnd =
forAllShrinkShow genBadBlock shrinkBadBlock showBadBlock $ \(BadBlock bs) ->
badBlockEnd =
forAllShrinkShow genBadBlock shrinkBadBlock showBadBlock $ \(BadBlock bs) ->
not . B.isValidUtf8 $ bs

-- 32 ASCII bytes (value 48) followed by the two bytes 235, 185: a
-- three-byte lead byte with only one continuation byte after it.
asciiAndInvalidMultiByte :: ByteString
asciiAndInvalidMultiByte = fromList (asciiStride ++ truncatedSequence)
  where
    asciiStride = replicate 32 48
    truncatedSequence = [235, 185]

-- Split a generated valid UTF-8 string at a random byte index and check
-- that the two pieces agree on whether they are valid UTF-8.
splitValid :: Property
splitValid =
  forAll genValidUtf8 $ \bs ->
    forAll (choose (0, B.length bs)) $ \k ->
      -- The second piece may have a non-zero offset, which
      -- allows this property test to tickle #620.
      let (prefix, suffix) = B.splitAt k bs
       in B.isValidUtf8 prefix == B.isValidUtf8 suffix

-- Helpers

-- A 128-byte sequence with a single bad byte at the end, with the rest being
Expand All @@ -98,7 +107,7 @@ showBadBlock :: BadBlock -> String
showBadBlock (BadBlock bs) = let asList = toList bs in
foldr showHex "" asList

data Utf8Sequence =
data Utf8Sequence =
One Word8 |
Two Word8 Word8 |
Three Word8 Word8 Word8 |
Expand All @@ -116,7 +125,7 @@ instance Arbitrary Utf8Sequence where
genThree :: Gen Utf8Sequence
genThree = do
w1 <- elements [0xE0 .. 0xED]
w2 <- elements $ case w1 of
w2 <- elements $ case w1 of
0xE0 -> [0xA0 .. 0xBF]
0xED -> [0x80 .. 0x9F]
_ -> [0x80 .. 0xBF]
Expand All @@ -125,54 +134,54 @@ instance Arbitrary Utf8Sequence where
genFour :: Gen Utf8Sequence
genFour = do
w1 <- elements [0xF0 .. 0xF4]
w2 <- elements $ case w1 of
w2 <- elements $ case w1 of
0xF0 -> [0x90 .. 0xBF]
0xF4 -> [0x80 .. 0x8F]
_ -> [0x80 .. 0xBF]
w3 <- elements [0x80 .. 0xBF]
w4 <- elements [0x80 .. 0xBF]
pure . Four w1 w2 w3 $ w4
shrink = \case
One w1 -> One <$> case w1 of
One w1 -> One <$> case w1 of
0x00 -> []
_ -> [0x00 .. (w1 - 1)]
Two w1 w2 -> case (w1, w2) of
Two w1 w2 -> case (w1, w2) of
(0xC2, 0x80) -> allOnes
_ -> (Two <$> [0xC2 .. (w1 - 1)] <*> [0x80 .. (w2 - 1)]) ++ allOnes
Three w1 w2 w3 -> case (w1, w2, w3) of
Three w1 w2 w3 -> case (w1, w2, w3) of
(0xE0, 0xA0, 0x80) -> allTwos ++ allOnes
(0xE0, 0xA0, _) -> (Three 0xE0 0xA0 <$> [0x80 .. (w3 - 1)]) ++ allTwos ++ allOnes
(0xE0, _, _) ->
(0xE0, _, _) ->
(Three 0xE0 <$> [0xA0 .. (w2 - 1)] <*> [0x80 .. (w3 - 1)]) ++ allTwos ++ allOnes
_ -> do
w1' <- [0xE0 .. (w1 - 1)]
case w1' of
0xE0 -> (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
case w1' of
0xE0 -> (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
allOnes
_ -> (Three w1' <$> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
_ -> (Three w1' <$> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
allOnes
Four w1 w2 w3 w4 -> case (w1, w2, w3, w4) of
Four w1 w2 w3 w4 -> case (w1, w2, w3, w4) of
(0xF0, 0x90, 0x80, 0x80) -> allThrees ++ allTwos ++ allOnes
(0xF0, 0x90, 0x80, _) ->
(Four 0xF0 0x90 0x80 <$> [0x80 .. (w4 - 1)]) ++
(0xF0, 0x90, 0x80, _) ->
(Four 0xF0 0x90 0x80 <$> [0x80 .. (w4 - 1)]) ++
allThrees ++
allTwos ++
allOnes
(0xF0, 0x90, _, _) ->
(0xF0, 0x90, _, _) ->
(Four 0xF0 0x90 <$> [0x80 .. (w3 - 1)] <*> [0x80 .. (w4 - 1)]) ++
allThrees ++
allTwos ++
allOnes
(0xF0, _, _, _) ->
(0xF0, _, _, _) ->
(Four 0xF0 <$> [0x90 .. (w2 - 1)] <*> [0x80 .. (w3 - 1)] <*> [0x80 .. (w4 - 1)]) ++
allThrees ++
allTwos ++
allOnes
_ -> do
w1' <- [0xF0 .. (w1 - 1)]
case w1' of
case w1' of
0xF0 -> (Four 0xF0 <$> [0x90 .. 0xBF] <*> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allThrees ++
allTwos ++
Expand All @@ -189,7 +198,7 @@ allTwos :: [Utf8Sequence]
allTwos = Two <$> [0xC2 .. 0xDF] <*> [0x80 .. 0xBF]

allThrees :: [Utf8Sequence]
allThrees = (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allThrees = (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
(Three 0xED <$> [0x80 .. 0x9F] <*> [0x80 .. 0xBF]) ++
(Three <$> [0xE1 .. 0xEC] <*> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
(Three <$> [0xEE .. 0xEF] <*> [0x80 .. 0xBF] <*> [0x80 .. 0xBF])
Expand Down Expand Up @@ -233,7 +242,7 @@ instance Arbitrary InvalidUtf8 where
, InvalidUtf8 <$> genValidUtf8 <*> genInvalidUtf8 <*> pure mempty
, InvalidUtf8 <$> genValidUtf8 <*> genInvalidUtf8 <*> genValidUtf8
]
shrink (InvalidUtf8 p i s) =
shrink (InvalidUtf8 p i s) =
(InvalidUtf8 p i <$> shrinkValidBS s) ++
((\p' -> InvalidUtf8 p' i s) <$> shrinkValidBS p)

Expand Down Expand Up @@ -262,7 +271,7 @@ genInvalidUtf8 = B.pack <$> oneof [
-- overlong encoding
, do k <- choose (0, 0xFFFF)
let c = chr k
case k of
case k of
_ | k < 0x80 -> oneof [ let (w, x) = ord2 c in pure [w, x]
, let (w, x, y) = ord3 c in pure [w, x, y]
, let (w, x, y, z) = ord4 c in pure [w, x, y, z] ]
Expand All @@ -279,7 +288,7 @@ genInvalidUtf8 = B.pack <$> oneof [
vectorOf k gen

genValidUtf8 :: Gen ByteString
genValidUtf8 = sized $ \size ->
genValidUtf8 = sized $ \size ->
if size <= 0
then pure mempty
else oneof [
Expand All @@ -300,7 +309,7 @@ genValidUtf8 = sized $ \size ->
gen3Byte :: Gen ByteString
gen3Byte = do
b1 <- elements [0xE0 .. 0xED]
b2 <- elements $ case b1 of
b2 <- elements $ case b1 of
0xE0 -> [0xA0 .. 0xBF]
0xED -> [0x80 .. 0x9F]
_ -> [0x80 .. 0xBF]
Expand All @@ -309,7 +318,7 @@ genValidUtf8 = sized $ \size ->
gen4Byte :: Gen ByteString
gen4Byte = do
b1 <- elements [0xF0 .. 0xF4]
b2 <- elements $ case b1 of
b2 <- elements $ case b1 of
0xF0 -> [0x90 .. 0xBF]
0xF4 -> [0x80 .. 0x8F]
_ -> [0x80 .. 0xBF]
Expand Down

0 comments on commit 03733af

Please sign in to comment.