Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] SIMD acceleration #50

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions base64.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ source-repository head
type: git
location: https://github.com/emilypi/base64.git

flag simd
description: use libbase64 simd library
default: True
manual: True

library
exposed-modules:
Data.Base64.Types
Expand Down Expand Up @@ -61,7 +66,13 @@ library

hs-source-dirs: src
default-language: Haskell2010
ghc-options: -Wall
ghc-options: -Wall -O2

if flag(simd)
build-depends:
libbase64-bindings
cpp-options:
-DSIMD

test-suite base64-tests
default-language: Haskell2010
Expand Down Expand Up @@ -99,4 +110,4 @@ benchmark bench
, random-bytestring
, text >=2.0

ghc-options: -Wall -rtsopts
ghc-options: -Wall -O2 -rtsopts
80 changes: 43 additions & 37 deletions default.nix
Original file line number Diff line number Diff line change
@@ -1,41 +1,47 @@
{ nixpkgs ? import <nixpkgs> {}, compiler ? "default", doBenchmark ? false }:

let
hs-libbase64 = import (builtins.fetchTarball {
url = "https://github.com/chessai/hs-libbase64-bindings/archive/e8a5194742f41ce4109b05098a2859e8052ad1c1.tar.gz";
sha256 = "1xqfjqb1ghh8idnindc6gfr62d78m5cc6jpbhv1hja8lkdrl8qf8";
}) {};
in
{
compiler ? "ghc944",
pkgs ? import <nixpkgs> {
config = {
allowBroken = false;
allowUnfree = false;
};

overlays = [ ];
/*(self: super: {
# not in nixpkgs yet
inherit (hs-libbase64) libbase64;

"haskell.packages.${compiler}" = haskell.packages.${compiler}.override {
overrides = hself: hsuper: {
inherit (hs-libbase64) libbase64-bindings;
};
};
})*/
},
returnShellEnv ? false,
}:

inherit (nixpkgs) pkgs;

f = { mkDerivation, base, base64-bytestring, bytestring
, criterion, deepseq, ghc-byteorder, QuickCheck, random-bytestring
, stdenv, tasty, tasty-hunit, tasty-quickcheck, text, text-short
}:
mkDerivation {
pname = "base64";
version = "0.4.2.2";
src = ./.;
libraryHaskellDepends = [
base bytestring deepseq ghc-byteorder text text-short
];
testHaskellDepends = [
base base64-bytestring bytestring QuickCheck random-bytestring
tasty tasty-hunit tasty-quickcheck text text-short
];
benchmarkHaskellDepends = [
base base64-bytestring bytestring criterion deepseq
random-bytestring text
];
homepage = "https://github.com/emilypi/base64";
description = "Fast RFC 4648-compliant Base64 encoding";
license = stdenv.lib.licenses.bsd3;
};

haskellPackages = if compiler == "default"
then pkgs.haskellPackages
else pkgs.haskell.packages.${compiler};

variant = if doBenchmark then pkgs.haskell.lib.doBenchmark else pkgs.lib.id;

drv = variant (haskellPackages.callPackage f {});

let
nix-gitignore = import (pkgs.fetchFromGitHub {
owner = "hercules-ci";
repo = "gitignore";
rev = "9e80c4d83026fa6548bc53b1a6fab8549a6991f6";
sha256 = "04n9chlpbifgc5pa3zx6ff3rji9am6msrbn1z3x1iinjz2xjfp4p";
}) {};
in
pkgs.haskell.packages.${compiler}.developPackage {
name = "base64";
root = nix-gitignore.gitignoreSource ./.;

overrides = self: super: {
inherit (hs-libbase64) libbase64-bindings;
};

if pkgs.lib.inNixShell then drv.env else drv
inherit returnShellEnv;
}
1 change: 1 addition & 0 deletions shell.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import ./default.nix { returnShellEnv = true; }
101 changes: 98 additions & 3 deletions src/Data/ByteString/Base64/Internal/Head.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module : Data.ByteString.Base64.Internal.Head
-- Copyright : (c) 2019-2022 Emily Pillmore
Expand Down Expand Up @@ -35,9 +36,50 @@ import GHC.Word

import System.IO.Unsafe ( unsafeDupablePerformIO )

#ifdef SIMD
import Foreign.C.Types (CChar, CInt, CSize)
import Foreign.Storable (peek)
import qualified Foreign.Marshal.Utils as Foreign
import qualified Data.Text as T
import LibBase64Bindings
#endif

encodeBase64_ :: EncodingTable -> ByteString -> ByteString
encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
#ifdef SIMD
encodeBase64_ table b@(PS _ _ !slen)
| slen < threshold = encodeBase64Loop_ table b
| otherwise = encodeBase64Simd_ b
where
!threshold = 1000 -- 1k
#else
encodeBase64_ table b = encodeBase64Loop_ table b
#endif
{-# inline encodeBase64_ #-}

#ifdef SIMD
encodeBase64Simd_ :: ByteString -> ByteString
encodeBase64Simd_ (PS !sfp !soff !slen) =
unsafeDupablePerformIO $ do
dfp <- mallocPlainForeignPtrBytes dlen
dlenFinal <- do
withForeignPtr dfp $ \out ->
withForeignPtr sfp $ \src -> do
Foreign.with (intToCSize dlen) $ \outlen -> do
base64_encode
(plusPtr (castPtr src :: Ptr CChar) soff)
(intToCSize slen)
out
outlen
base64Flags
peek outlen
pure (PS (castForeignPtr dfp) 0 (cSizeToInt dlenFinal))
where
!dlen = 4 * ((slen + 2) `div` 3)
!base64Flags = 0
#endif

encodeBase64Loop_ :: EncodingTable -> ByteString -> ByteString
encodeBase64Loop_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
unsafeDupablePerformIO $ do
dfp <- mallocPlainForeignPtrBytes dlen
withForeignPtr dfp $ \dptr ->
Expand Down Expand Up @@ -70,6 +112,33 @@ encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
where
!dlen = 4 * ((slen + 2) `div` 3)

#ifdef SIMD
decodeBase64Simd_ :: ByteString -> IO (Either Text ByteString)
decodeBase64Simd_ (PS !sfp !soff !slen) = do
withForeignPtr sfp $ \src -> do
dfp <- mallocPlainForeignPtrBytes dlen
edlenFinal :: Either Text CSize <- do
withForeignPtr dfp $ \out -> do
Foreign.with (intToCSize dlen) $ \outlen -> do
decodeResult <- base64_decode
(plusPtr (castPtr src :: Ptr CChar) soff)
(intToCSize slen)
out
outlen
base64Flags
case decodeResult of
1 -> Right <$> peek outlen
0 -> pure (Left "SIMD: Invalid input")
(-1) -> pure (Left "Invalid Codec")
x -> pure (Left ("Unexpected result from libbase64 base64_decode: " <> T.pack (show (cIntToInt x))))
pure $ fmap
(\dlenFinal -> PS (castForeignPtr dfp) 0 (cSizeToInt dlenFinal))
edlenFinal
where
!dlen = (slen `quot` 4) * 3
!base64Flags = 0
#endif

-- | The main decode function. Takes a padding flag, a decoding table, and
-- the input value, producing either an error string on the left, or a
-- decoded value.
Expand All @@ -84,7 +153,22 @@ decodeBase64_
:: ForeignPtr Word8
-> ByteString
-> IO (Either Text ByteString)
decodeBase64_ !dtfp (PS !sfp !soff !slen) =
#ifdef SIMD
decodeBase64_ dtfp b@(PS _ _ !slen)
| slen < threshold = decodeBase64Loop_ dtfp b
| otherwise = decodeBase64Simd_ b
where
!threshold = 250
#else
decodeBase64_ dtfp b = decodeBase64Loop_ dtfp b
#endif
{-# inline decodeBase64_ #-}

decodeBase64Loop_
:: ForeignPtr Word8
-> ByteString
-> IO (Either Text ByteString)
decodeBase64Loop_ !dtfp (PS !sfp !soff !slen) =
withForeignPtr dtfp $ \dtable ->
withForeignPtr sfp $ \sptr -> do
dfp <- mallocPlainForeignPtrBytes dlen
Expand All @@ -95,7 +179,7 @@ decodeBase64_ !dtfp (PS !sfp !soff !slen) =
dptr end dfp
where
!dlen = (slen `quot` 4) * 3
{-# inline decodeBase64_ #-}
{-# inline decodeBase64Loop_ #-}

-- | The main decode function for typed base64 values.
--
Expand Down Expand Up @@ -137,3 +221,14 @@ decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $
dfp
where
!dlen = ((slen + 3) `div` 4) * 3

#ifdef SIMD
intToCSize :: Int -> CSize
intToCSize = fromIntegral

cSizeToInt :: CSize -> Int
cSizeToInt = fromIntegral

cIntToInt :: CInt -> Int
cIntToInt = fromIntegral
#endif