Skip to content

Commit 7a69912

Browse files
Implement the vector indexing operation
1 parent 0dec9b4 commit 7a69912

File tree

7 files changed

+33
-13
lines changed

7 files changed

+33
-13
lines changed

accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ shfl sop tR val delta = go tR val
445445
repack :: Int32 -> CodeGen PTX (Operands (Vec m Int32))
446446
repack 0 = return $ ir v' (A.undef (VectorScalarType v'))
447447
repack i = do
448-
d <- instr $ ExtractElement (i-1) c
448+
d <- instr $ ExtractElement integralType c (constOp (i-1))
449449
e <- integral integralType d
450450
f <- repack (i-1)
451451
g <- instr $ InsertElement (i-1) (op v' f) (op integralType e)

accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
module Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
2121
where
2222

23+
import Data.Primitive.Vec
24+
2325
import Data.Array.Accelerate.AST ( PrimMaybe )
2426
import Data.Array.Accelerate.Analysis.Match
2527
import Data.Array.Accelerate.Representation.Tag
@@ -464,6 +466,17 @@ min ty x y
464466
| otherwise = do c <- unbool <$> lte ty x y
465467
binop (flip Select c) ty x y
466468

469+
-- Vector operators
470+
-- ----------------------
471+
472+
vecCreate :: VectorType (Vec n a) -> CodeGen arch (Operands (Vec n a))
473+
vecCreate = undefined
474+
475+
vecIndex :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a)
476+
vecIndex tv ti (OP_Vec v) i = do
477+
(OP_Int32 i') <- fromIntegral ti (IntegralNumType TypeInt32) i
478+
instr $ ExtractElement TypeInt32 v i'
479+
467480

468481
-- Logical operators
469482
-- -----------------

accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import LLVM.AST.Type.AddrSpace
2828
import LLVM.AST.Type.Instruction
2929
import LLVM.AST.Type.Instruction.Volatile
3030
import LLVM.AST.Type.Operand
31+
import LLVM.AST.Type.Constant
3132
import LLVM.AST.Type.Representation
3233

3334
import Data.Array.Accelerate.Representation.Array
@@ -205,16 +206,15 @@ store addrspace volatility e p v
205206
| SingleScalarType{} <- e = do_ $ Store volatility p v
206207
| VectorScalarType s <- e
207208
, VectorType n base <- s
208-
, m <- fromIntegral n
209-
= if popCount m == 1
209+
= if popCount n == 1
210210
then do_ $ Store volatility p v
211211
else do
212212
p' <- instr' $ PtrCast (PtrPrimType (ScalarPrimType (SingleScalarType base)) addrspace) p
213213
--
214-
let go i
215-
| i >= m = return ()
214+
let go i
215+
| i >= n = return ()
216216
| otherwise = do
217-
x <- instr' $ ExtractElement i v
217+
x <- instr' $ ExtractElement integralType v (constOp n)
218218
q <- instr' $ GetElementPtr p' [integral integralType i]
219219
_ <- instr' $ Store volatility q x
220220
go (i+1)

accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ scalar t = ConstantOperand . ScalarConstant t
6161
single :: SingleType a -> a -> Operand a
6262
single t = scalar (SingleScalarType t)
6363

64-
vector :: VectorType (Vec n a) -> (Vec n a) -> Operand (Vec n a)
64+
vector :: VectorType (Vec n a) -> Vec n a -> Operand (Vec n a)
6565
vector t = scalar (VectorScalarType t)
6666

6767
num :: NumType a -> a -> Operand a

accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop as L
4848
import Data.Primitive.Vec
4949

5050
import LLVM.AST.Type.Instruction
51-
import LLVM.AST.Type.Operand ( Operand )
51+
import LLVM.AST.Type.Operand ( Operand(..), constOp)
52+
import LLVM.AST.Type.Constant ( Constant(..), )
5253

5354
import Control.Applicative hiding ( Const )
5455
import Control.Monad
@@ -105,7 +106,7 @@ llvmOfOpenExp top env aenv = cvtE top
105106
llvmOfOpenExp body (env `pushE` (lhs, x)) aenv
106107
Evar (Var _ ix) -> return $ prj ix env
107108
Const tp c -> return $ ir tp $ scalar tp c
108-
PrimConst c -> let tp = (SingleScalarType $ primConstType c)
109+
PrimConst c -> let tp = primConstType c
109110
in return $ ir tp $ scalar tp $ primConst c
110111
PrimApp f x -> primFun f x
111112
Undef tp -> return $ ir tp $ undef tp
@@ -165,7 +166,7 @@ llvmOfOpenExp top env aenv = cvtE top
165166
go (VecRnil _) _ = internalError "index mismatch"
166167
go (VecRsucc vecr') i = do
167168
xs <- go vecr' (i - 1)
168-
x <- instr' $ ExtractElement (fromIntegral i - 1) vec
169+
x <- instr' $ ExtractElement TypeInt vec (constOp (i - 1))
169170
return $ OP_Pair xs (ir singleTp x)
170171

171172
singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason
@@ -307,6 +308,7 @@ llvmOfOpenExp top env aenv = cvtE top
307308
PrimEq t -> primbool $ A.uncurry (A.eq t) =<< cvtE x
308309
PrimNEq t -> primbool $ A.uncurry (A.neq t) =<< cvtE x
309310
PrimLNot -> primbool $ A.lnot =<< bool (cvtE x)
311+
PrimVectorIndex v i -> A.uncurry (A.vecIndex v i) =<< cvtE x
310312
-- no missing patterns, whoo!
311313

312314

accelerate-llvm/src/LLVM/AST/Type/Instruction.hs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ data Instruction a where
182182

183183
-- <http://llvm.org/docs/LangRef.html#extractelement-instruction>
184184
--
185-
ExtractElement :: Int32 -- TupleIdx (ProdRepr (Vec n a)) a
185+
ExtractElement :: IntegralType i -- TupleIdx (ProdRepr (Vec n a)) a
186186
-> Operand (Vec n a)
187+
-> Operand i
187188
-> Instruction a
188189

189190
-- <http://llvm.org/docs/LangRef.html#insertelement-instruction>
@@ -406,7 +407,7 @@ instance Downcast (Instruction a) LLVM.Instruction where
406407
BXor _ x y -> LLVM.Xor (downcast x) (downcast y) md
407408
LNot x -> LLVM.Xor (downcast x) (LLVM.ConstantOperand (LLVM.Int 1 1)) md
408409
InsertElement i v x -> LLVM.InsertElement (downcast v) (downcast x) (constant i) md
409-
ExtractElement i v -> LLVM.ExtractElement (downcast v) (constant i) md
410+
ExtractElement _ v i -> LLVM.ExtractElement (downcast v) (downcast i) md
410411
ExtractValue _ i s -> extractStruct i (downcast s)
411412
Load _ v p -> LLVM.Load (downcast v) (downcast p) atomicity alignment md
412413
Store v p x -> LLVM.Store (downcast v) (downcast p) (downcast x) atomicity alignment md
@@ -594,7 +595,7 @@ instance TypeOf Instruction where
594595
LAnd x _ -> typeOf x
595596
LOr x _ -> typeOf x
596597
LNot x -> typeOf x
597-
ExtractElement _ x -> typeOfVec x
598+
ExtractElement _ x _ -> typeOfVec x
598599
InsertElement _ x _ -> typeOf x
599600
ExtractValue t _ _ -> scalar t
600601
Load t _ _ -> scalar t

accelerate-llvm/src/LLVM/AST/Type/Operand.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
module LLVM.AST.Type.Operand (
1616

1717
Operand(..),
18+
constOp,
1819

1920
) where
2021

@@ -32,6 +33,9 @@ data Operand a where
3233
LocalReference :: Type a -> Name a -> Operand a
3334
ConstantOperand :: Constant a -> Operand a
3435

36+
constOp :: (IsScalar a) => a -> Operand a
37+
constOp x = ConstantOperand (ScalarConstant scalarType x)
38+
3539

3640
-- | Convert to llvm-hs
3741
--

0 commit comments

Comments
 (0)