Skip to content

Commit

Permalink
add simd decode; add thresholds for non-simd fallback in encode/decode
Browse files Browse the repository at this point in the history
  • Loading branch information
chessai committed Apr 20, 2023
1 parent 913eda0 commit fbd9196
Showing 1 changed file with 74 additions and 12 deletions.
86 changes: 74 additions & 12 deletions src/Data/ByteString/Base64/Internal/Head.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module : Data.ByteString.Base64.Internal.Head
-- Copyright : (c) 2019-2022 Emily Pillmore
Expand Down Expand Up @@ -42,15 +43,28 @@ import GHC.Word
import System.IO.Unsafe

#ifdef SIMD
import Foreign.C.Types (CSize, CChar)
import Foreign.C.Types (CChar, CInt, CSize)
import Foreign.Storable (peek)
import qualified Foreign.Marshal.Utils as Foreign
import qualified Data.Text as T
import LibBase64Bindings
#endif

encodeBase64_ :: EncodingTable -> ByteString -> ByteString
#ifdef SIMD
encodeBase64_ _ (PS !sfp !soff !slen) =
encodeBase64_ table b@(PS _ _ !slen)
| slen < threshold = encodeBase64Loop_ table b
| otherwise = encodeBase64Simd_ b
where
!threshold = 1000 -- 1k
#else
encodeBase64_ table b = encodeBase64Loop_ table b
#endif
{-# inline encodeBase64_ #-}

#ifdef SIMD
encodeBase64Simd_ :: ByteString -> ByteString
encodeBase64Simd_ (PS !sfp !soff !slen) =
unsafeDupablePerformIO $ do
dfp <- mallocPlainForeignPtrBytes dlen
dlenFinal <- do
Expand All @@ -68,14 +82,10 @@ encodeBase64_ _ (PS !sfp !soff !slen) =
where
!dlen = 4 * ((slen + 2) `div` 3)
!base64Flags = 0
#endif

intToCSize :: Int -> CSize
intToCSize = fromIntegral

cSizeToInt :: CSize -> Int
cSizeToInt = fromIntegral
#else
encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
encodeBase64Loop_ :: EncodingTable -> ByteString -> ByteString
encodeBase64Loop_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
unsafeDupablePerformIO $ do
dfp <- mallocPlainForeignPtrBytes dlen
withForeignPtr dfp $ \dptr ->
Expand All @@ -90,7 +100,6 @@ encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
(loopTail dfp aptr dptr (castPtr end))
where
!dlen = 4 * ((slen + 2) `div` 3)
#endif

encodeBase64Nopad_ :: EncodingTable -> ByteString -> ByteString
encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
Expand All @@ -109,6 +118,33 @@ encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
where
!dlen = 4 * ((slen + 2) `div` 3)

#ifdef SIMD
decodeBase64Simd_ :: ByteString -> IO (Either Text ByteString)
decodeBase64Simd_ (PS !sfp !soff !slen) = do
withForeignPtr sfp $ \src -> do
dfp <- mallocPlainForeignPtrBytes dlen
edlenFinal :: Either Text CSize <- do
withForeignPtr dfp $ \out -> do
Foreign.with (intToCSize dlen) $ \outlen -> do
decodeResult <- base64_decode
(plusPtr (castPtr src :: Ptr CChar) soff)
(intToCSize slen)
out
outlen
base64Flags
case decodeResult of
1 -> Right <$> peek outlen
0 -> pure (Left "SIMD: Invalid input")
(-1) -> pure (Left "Invalid Codec")
x -> pure (Left ("Unexpected result from libbase64 base64_decode: " <> T.pack (show (cIntToInt x))))
pure $ fmap
(\dlenFinal -> PS (castForeignPtr dfp) 0 (cSizeToInt dlenFinal))
edlenFinal
where
!dlen = (slen `quot` 4) * 3
!base64Flags = 0
#endif

-- | The main decode function. Takes a padding flag, a decoding table, and
-- the input value, producing either an error string on the left, or a
-- decoded value.
Expand All @@ -123,7 +159,22 @@ decodeBase64_
:: ForeignPtr Word8
-> ByteString
-> IO (Either Text ByteString)
decodeBase64_ !dtfp (PS !sfp !soff !slen) =
#ifdef SIMD
decodeBase64_ dtfp b@(PS _ _ !slen)
| slen < threshold = decodeBase64Loop_ dtfp b
| otherwise = decodeBase64Simd_ b
where
!threshold = 250
#else
decodeBase64_ dtfp b = decodeBase64Loop_ dtfp b
#endif
{-# inline decodeBase64_ #-}

decodeBase64Loop_
:: ForeignPtr Word8
-> ByteString
-> IO (Either Text ByteString)
decodeBase64Loop_ !dtfp (PS !sfp !soff !slen) =
withForeignPtr dtfp $ \dtable ->
withForeignPtr sfp $ \sptr -> do
dfp <- mallocPlainForeignPtrBytes dlen
Expand All @@ -134,7 +185,7 @@ decodeBase64_ !dtfp (PS !sfp !soff !slen) =
dptr end dfp
where
!dlen = (slen `quot` 4) * 3
{-# inline decodeBase64_ #-}
{-# inline decodeBase64Loop_ #-}

decodeBase64Lenient_ :: ForeignPtr Word8 -> ByteString -> ByteString
decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $
Expand All @@ -150,3 +201,14 @@ decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $
dfp
where
!dlen = ((slen + 3) `div` 4) * 3

#ifdef SIMD
intToCSize :: Int -> CSize
intToCSize = fromIntegral

cSizeToInt :: CSize -> Int
cSizeToInt = fromIntegral

cIntToInt :: CInt -> Int
cIntToInt = fromIntegral
#endif

0 comments on commit fbd9196

Please sign in to comment.