diff --git a/src/Data/ByteString/Base64/Internal/Head.hs b/src/Data/ByteString/Base64/Internal/Head.hs index 61d9c3d..85d49ad 100644 --- a/src/Data/ByteString/Base64/Internal/Head.hs +++ b/src/Data/ByteString/Base64/Internal/Head.hs @@ -1,6 +1,7 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} -- | -- Module : Data.ByteString.Base64.Internal.Head -- Copyright : (c) 2019-2022 Emily Pillmore @@ -42,15 +43,28 @@ import GHC.Word import System.IO.Unsafe #ifdef SIMD -import Foreign.C.Types (CSize, CChar) +import Foreign.C.Types (CChar, CInt, CSize) import Foreign.Storable (peek) import qualified Foreign.Marshal.Utils as Foreign +import qualified Data.Text as T import LibBase64Bindings #endif encodeBase64_ :: EncodingTable -> ByteString -> ByteString #ifdef SIMD -encodeBase64_ _ (PS !sfp !soff !slen) = +encodeBase64_ table b@(PS _ _ !slen) + | slen < threshold = encodeBase64Loop_ table b + | otherwise = encodeBase64Simd_ b + where + !threshold = 1000 -- 1k: inputs whose length is below this go to encodeBase64Loop_, all others to encodeBase64Simd_ +#else +encodeBase64_ table b = encodeBase64Loop_ table b +#endif +{-# inline encodeBase64_ #-} + +#ifdef SIMD +encodeBase64Simd_ :: ByteString -> ByteString +encodeBase64Simd_ (PS !sfp !soff !slen) = unsafeDupablePerformIO $ do dfp <- mallocPlainForeignPtrBytes dlen dlenFinal <- do @@ -68,14 +82,10 @@ encodeBase64_ _ (PS !sfp !soff !slen) = where !dlen = 4 * ((slen + 2) `div` 3) !base64Flags = 0 +#endif - intToCSize :: Int -> CSize - intToCSize = fromIntegral - - cSizeToInt :: CSize -> Int - cSizeToInt = fromIntegral -#else -encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) = +encodeBase64Loop_ :: EncodingTable -> ByteString -> ByteString +encodeBase64Loop_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) = unsafeDupablePerformIO $ do dfp <- mallocPlainForeignPtrBytes dlen withForeignPtr dfp $ \dptr -> @@ -90,7 +100,6 @@ encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) = (loopTail dfp aptr dptr (castPtr end)) where 
!dlen = 4 * ((slen + 2) `div` 3) -#endif encodeBase64Nopad_ :: EncodingTable -> ByteString -> ByteString encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) = @@ -109,6 +118,33 @@ encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) = where !dlen = 4 * ((slen + 2) `div` 3) +#ifdef SIMD +decodeBase64Simd_ :: ByteString -> IO (Either Text ByteString) +decodeBase64Simd_ (PS !sfp !soff !slen) = do + withForeignPtr sfp $ \src -> do + dfp <- mallocPlainForeignPtrBytes dlen + edlenFinal :: Either Text CSize <- do + withForeignPtr dfp $ \out -> do + Foreign.with (intToCSize dlen) $ \outlen -> do + decodeResult <- base64_decode + (plusPtr (castPtr src :: Ptr CChar) soff) + (intToCSize slen) + out + outlen + base64Flags + case decodeResult of + 1 -> Right <$> peek outlen -- success: the value read from outlen becomes the length of the result + 0 -> pure (Left "SIMD: Invalid input") + (-1) -> pure (Left "Invalid Codec") + x -> pure (Left ("Unexpected result from libbase64 base64_decode: " <> T.pack (show (cIntToInt x)))) + pure $ fmap + (\dlenFinal -> PS (castForeignPtr dfp) 0 (cSizeToInt dlenFinal)) + edlenFinal + where + !dlen = ((slen + 3) `quot` 4) * 3 -- rounded up: slen is not validated here and base64_decode is given no output capacity, so a trailing partial quantum must still fit in the buffer + !base64Flags = 0 +#endif + -- | The main decode function. Takes a padding flag, a decoding table, and -- the input value, producing either an error string on the left, or a -- decoded value. 
@@ -123,7 +159,22 @@ decodeBase64_ :: ForeignPtr Word8 -> ByteString -> IO (Either Text ByteString) -decodeBase64_ !dtfp (PS !sfp !soff !slen) = +#ifdef SIMD +decodeBase64_ dtfp b@(PS _ _ !slen) + | slen < threshold = decodeBase64Loop_ dtfp b + | otherwise = decodeBase64Simd_ b + where + !threshold = 250 -- inputs whose length is below this go to decodeBase64Loop_, all others to decodeBase64Simd_ +#else +decodeBase64_ dtfp b = decodeBase64Loop_ dtfp b +#endif +{-# inline decodeBase64_ #-} + +decodeBase64Loop_ + :: ForeignPtr Word8 + -> ByteString + -> IO (Either Text ByteString) +decodeBase64Loop_ !dtfp (PS !sfp !soff !slen) = withForeignPtr dtfp $ \dtable -> withForeignPtr sfp $ \sptr -> do dfp <- mallocPlainForeignPtrBytes dlen @@ -134,7 +185,7 @@ decodeBase64_ !dtfp (PS !sfp !soff !slen) = dptr end dfp where !dlen = (slen `quot` 4) * 3 -{-# inline decodeBase64_ #-} +{-# inline decodeBase64Loop_ #-} decodeBase64Lenient_ :: ForeignPtr Word8 -> ByteString -> ByteString decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $ @@ -150,3 +201,14 @@ decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $ dfp where !dlen = ((slen + 3) `div` 4) * 3 + +#ifdef SIMD +intToCSize :: Int -> CSize +intToCSize = fromIntegral -- plain fromIntegral conversion, no range check + +cSizeToInt :: CSize -> Int +cSizeToInt = fromIntegral -- plain fromIntegral conversion, no range check + +cIntToInt :: CInt -> Int +cIntToInt = fromIntegral -- used when rendering an unexpected base64_decode result code as text +#endif