Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak UTF-8 validation 'roll back' logic #621

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions cbits/aarch64/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -260,20 +260,24 @@ int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
//'Roll back' our pointer a little to prepare for a slow search of the rest.
uint32_t token;
vst1q_lane_u32(&token, vreinterpretq_u32_u8(prev_input), 3);
// We cast this pointer to avoid a redundant check against < 127, as any such
// value would be negative in signed form.
int8_t const *token_ptr = (int8_t const *)&token;
ptrdiff_t lookahead = 0;
if (token_ptr[3] > (int8_t)0xBF) {
lookahead = 1;
} else if (token_ptr[2] > (int8_t)0xBF) {
lookahead = 2;
} else if (token_ptr[1] > (int8_t)0xBF) {
lookahead = 3;
uint8_t const *token_ptr = (uint8_t const *)&token;
ptrdiff_t rollback = 0;
// We must not roll back if no big blocks were processed, as then
// the fallback function would examine out-of-bounds data (#620).
// In that case, prev_input contains only nulls and we skip the if body.
if (token_ptr[3] >= 0x80u) {
// Look for an incomplete multi-byte code point
if (token_ptr[3] >= 0xC0u) {
rollback = 1;
} else if (token_ptr[2] >= 0xE0u) {
rollback = 2;
} else if (token_ptr[1] >= 0xF0u) {
rollback = 3;
}
}
// Finish the job.
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
uint8_t const *const small_ptr = ptr - rollback;
size_t const small_len = remaining + rollback;
return is_valid_utf8_fallback(small_ptr, small_len);
}

Expand Down
58 changes: 36 additions & 22 deletions cbits/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ static int8_t const range_max_lookup[16] = {
// +------------+---------------+------------------+----------------+
// | F0 | 3 | 3 | 6 |
// +------------+---------------+------------------+----------------+
// | F4 | 4 | 4 | 8 |
// | F4 | 3 | 4 | 7 |
// +------------+---------------+------------------+----------------+
// index1 -> E0, index14 -> ED
static int8_t const df_ee_lookup[16] = {
Expand Down Expand Up @@ -498,20 +498,27 @@ is_valid_utf8_ssse3(uint8_t const *const src, size_t const len) {
return 0;
}
// 'Roll back' our pointer a little to prepare for a slow search of the rest.
int16_t tokens[2];
uint16_t tokens[2];
tokens[0] = _mm_extract_epi16(prev_input, 6);
tokens[1] = _mm_extract_epi16(prev_input, 7);
int8_t const *token_ptr = (int8_t const *)tokens;
ptrdiff_t lookahead = 0;
if (token_ptr[3] > (int8_t)0xBF) {
lookahead = 1;
} else if (token_ptr[2] > (int8_t)0xBF) {
lookahead = 2;
} else if (token_ptr[1] > (int8_t)0xBF) {
lookahead = 3;
uint8_t const *token_ptr = (uint8_t const *)tokens;
ptrdiff_t rollback = 0;
// We must not roll back if no big blocks were processed, as then
// the fallback function would examine out-of-bounds data (#620).
// In that case, prev_input contains only nulls and we skip the if body.
if (token_ptr[3] >= 0x80u) {
// Look for an incomplete multi-byte code point
if (token_ptr[3] >= 0xC0u) {
rollback = 1;
} else if (token_ptr[2] >= 0xE0u) {
rollback = 2;
} else if (token_ptr[1] >= 0xF0u) {
rollback = 3;
}
}
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
// Finish the job.
uint8_t const *const small_ptr = ptr - rollback;
size_t const small_len = remaining + rollback;
return is_valid_utf8_fallback(small_ptr, small_len);
}

Expand Down Expand Up @@ -704,17 +711,24 @@ is_valid_utf8_avx2(uint8_t const *const src, size_t const len) {
}
// 'Roll back' our pointer a little to prepare for a slow search of the rest.
uint32_t tokens_blob = _mm256_extract_epi32(prev_input, 7);
int8_t const *tokens = (int8_t const *)&tokens_blob;
ptrdiff_t lookahead = 0;
if (tokens[3] > (int8_t)0xBF) {
lookahead = 1;
} else if (tokens[2] > (int8_t)0xBF) {
lookahead = 2;
} else if (tokens[1] > (int8_t)0xBF) {
lookahead = 3;
uint8_t const *token_ptr = (uint8_t const *)&tokens_blob;
ptrdiff_t rollback = 0;
// We must not roll back if no big blocks were processed, as then
// the fallback function would examine out-of-bounds data (#620).
// In that case, prev_input contains only nulls and we skip the if body.
if (token_ptr[3] >= 0x80u) {
// Look for an incomplete multi-byte code point
if (token_ptr[3] >= 0xC0u) {
rollback = 1;
} else if (token_ptr[2] >= 0xE0u) {
rollback = 2;
} else if (token_ptr[1] >= 0xF0u) {
rollback = 3;
}
}
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
// Finish the job.
uint8_t const *const small_ptr = ptr - rollback;
size_t const small_len = remaining + rollback;
return is_valid_utf8_fallback(small_ptr, small_len);
}

Expand Down
63 changes: 36 additions & 27 deletions tests/IsValidUtf8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ checkRegressions = [
testProperty "Three invalid bytes between spaces" $
not $ B.isValidUtf8 threeBytesBetweenSpaces,
testProperty "ASCII stride and invalid multibyte sequence" $
not $ B.isValidUtf8 asciiAndInvalidMultiByte
not $ B.isValidUtf8 asciiAndInvalidMultiByte,
testProperty "Splitting valid in two" splitValid
]
where
tooHigh :: ByteString
Expand All @@ -68,13 +69,21 @@ checkRegressions = [
threeBytesBetweenSpaces = fromList $ replicate 125 32 ++ [242, 134, 159] ++ replicate 128 32

badBlockEnd :: Property
badBlockEnd =
forAllShrinkShow genBadBlock shrinkBadBlock showBadBlock $ \(BadBlock bs) ->
badBlockEnd =
forAllShrinkShow genBadBlock shrinkBadBlock showBadBlock $ \(BadBlock bs) ->
not . B.isValidUtf8 $ bs

asciiAndInvalidMultiByte :: ByteString
asciiAndInvalidMultiByte = fromList $ replicate 32 48 ++ [235, 185]

splitValid :: Property
splitValid = forAll genValidUtf8 $ \bs ->
forAll (choose (0, B.length bs)) $ \k ->
case B.splitAt k bs of
-- q may have non-zero offset, which
-- allows this property test to tickle #620
(p, q) -> B.isValidUtf8 p == B.isValidUtf8 q

-- Helpers

-- A 128-byte sequence with a single bad byte at the end, with the rest being
Expand All @@ -98,7 +107,7 @@ showBadBlock :: BadBlock -> String
showBadBlock (BadBlock bs) = let asList = toList bs in
foldr showHex "" asList

data Utf8Sequence =
data Utf8Sequence =
One Word8 |
Two Word8 Word8 |
Three Word8 Word8 Word8 |
Expand All @@ -116,7 +125,7 @@ instance Arbitrary Utf8Sequence where
genThree :: Gen Utf8Sequence
genThree = do
w1 <- elements [0xE0 .. 0xED]
w2 <- elements $ case w1 of
w2 <- elements $ case w1 of
0xE0 -> [0xA0 .. 0xBF]
0xED -> [0x80 .. 0x9F]
_ -> [0x80 .. 0xBF]
Expand All @@ -125,54 +134,54 @@ instance Arbitrary Utf8Sequence where
genFour :: Gen Utf8Sequence
genFour = do
w1 <- elements [0xF0 .. 0xF4]
w2 <- elements $ case w1 of
w2 <- elements $ case w1 of
0xF0 -> [0x90 .. 0xBF]
0xF4 -> [0x80 .. 0x8F]
_ -> [0x80 .. 0xBF]
w3 <- elements [0x80 .. 0xBF]
w4 <- elements [0x80 .. 0xBF]
pure . Four w1 w2 w3 $ w4
shrink = \case
One w1 -> One <$> case w1 of
One w1 -> One <$> case w1 of
0x00 -> []
_ -> [0x00 .. (w1 - 1)]
Two w1 w2 -> case (w1, w2) of
Two w1 w2 -> case (w1, w2) of
(0xC2, 0x80) -> allOnes
_ -> (Two <$> [0xC2 .. (w1 - 1)] <*> [0x80 .. (w2 - 1)]) ++ allOnes
Three w1 w2 w3 -> case (w1, w2, w3) of
Three w1 w2 w3 -> case (w1, w2, w3) of
(0xE0, 0xA0, 0x80) -> allTwos ++ allOnes
(0xE0, 0xA0, _) -> (Three 0xE0 0xA0 <$> [0x80 .. (w3 - 1)]) ++ allTwos ++ allOnes
(0xE0, _, _) ->
(0xE0, _, _) ->
(Three 0xE0 <$> [0xA0 .. (w2 - 1)] <*> [0x80 .. (w3 - 1)]) ++ allTwos ++ allOnes
_ -> do
w1' <- [0xE0 .. (w1 - 1)]
case w1' of
0xE0 -> (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
case w1' of
0xE0 -> (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
allOnes
_ -> (Three w1' <$> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
_ -> (Three w1' <$> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allTwos ++
allOnes
Four w1 w2 w3 w4 -> case (w1, w2, w3, w4) of
Four w1 w2 w3 w4 -> case (w1, w2, w3, w4) of
(0xF0, 0x90, 0x80, 0x80) -> allThrees ++ allTwos ++ allOnes
(0xF0, 0x90, 0x80, _) ->
(Four 0xF0 0x90 0x80 <$> [0x80 .. (w4 - 1)]) ++
(0xF0, 0x90, 0x80, _) ->
(Four 0xF0 0x90 0x80 <$> [0x80 .. (w4 - 1)]) ++
allThrees ++
allTwos ++
allOnes
(0xF0, 0x90, _, _) ->
(0xF0, 0x90, _, _) ->
(Four 0xF0 0x90 <$> [0x80 .. (w3 - 1)] <*> [0x80 .. (w4 - 1)]) ++
allThrees ++
allTwos ++
allOnes
(0xF0, _, _, _) ->
(0xF0, _, _, _) ->
(Four 0xF0 <$> [0x90 .. (w2 - 1)] <*> [0x80 .. (w3 - 1)] <*> [0x80 .. (w4 - 1)]) ++
allThrees ++
allTwos ++
allOnes
_ -> do
w1' <- [0xF0 .. (w1 - 1)]
case w1' of
case w1' of
0xF0 -> (Four 0xF0 <$> [0x90 .. 0xBF] <*> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allThrees ++
allTwos ++
Expand All @@ -189,7 +198,7 @@ allTwos :: [Utf8Sequence]
allTwos = Two <$> [0xC2 .. 0xDF] <*> [0x80 .. 0xBF]

allThrees :: [Utf8Sequence]
allThrees = (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
allThrees = (Three 0xE0 <$> [0xA0 .. 0xBF] <*> [0x80 .. 0xBF]) ++
(Three 0xED <$> [0x80 .. 0x9F] <*> [0x80 .. 0xBF]) ++
(Three <$> [0xE1 .. 0xEC] <*> [0x80 .. 0xBF] <*> [0x80 .. 0xBF]) ++
(Three <$> [0xEE .. 0xEF] <*> [0x80 .. 0xBF] <*> [0x80 .. 0xBF])
Expand Down Expand Up @@ -233,7 +242,7 @@ instance Arbitrary InvalidUtf8 where
, InvalidUtf8 <$> genValidUtf8 <*> genInvalidUtf8 <*> pure mempty
, InvalidUtf8 <$> genValidUtf8 <*> genInvalidUtf8 <*> genValidUtf8
]
shrink (InvalidUtf8 p i s) =
shrink (InvalidUtf8 p i s) =
(InvalidUtf8 p i <$> shrinkValidBS s) ++
((\p' -> InvalidUtf8 p' i s) <$> shrinkValidBS p)

Expand Down Expand Up @@ -262,7 +271,7 @@ genInvalidUtf8 = B.pack <$> oneof [
-- overlong encoding
, do k <- choose (0, 0xFFFF)
let c = chr k
case k of
case k of
_ | k < 0x80 -> oneof [ let (w, x) = ord2 c in pure [w, x]
, let (w, x, y) = ord3 c in pure [w, x, y]
, let (w, x, y, z) = ord4 c in pure [w, x, y, z] ]
Expand All @@ -279,7 +288,7 @@ genInvalidUtf8 = B.pack <$> oneof [
vectorOf k gen

genValidUtf8 :: Gen ByteString
genValidUtf8 = sized $ \size ->
genValidUtf8 = sized $ \size ->
if size <= 0
then pure mempty
else oneof [
Expand All @@ -300,7 +309,7 @@ genValidUtf8 = sized $ \size ->
gen3Byte :: Gen ByteString
gen3Byte = do
b1 <- elements [0xE0 .. 0xED]
b2 <- elements $ case b1 of
b2 <- elements $ case b1 of
0xE0 -> [0xA0 .. 0xBF]
0xED -> [0x80 .. 0x9F]
_ -> [0x80 .. 0xBF]
Expand All @@ -309,7 +318,7 @@ genValidUtf8 = sized $ \size ->
gen4Byte :: Gen ByteString
gen4Byte = do
b1 <- elements [0xF0 .. 0xF4]
b2 <- elements $ case b1 of
b2 <- elements $ case b1 of
0xF0 -> [0x90 .. 0xBF]
0xF4 -> [0x80 .. 0x8F]
_ -> [0x80 .. 0xBF]
Expand Down
Loading