Skip to content

Commit

Permalink
add vmap for mapping a function over a leading dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Dec 10, 2023
1 parent 5d31e9c commit 638d857
Show file tree
Hide file tree
Showing 12 changed files with 638 additions and 242 deletions.
1 change: 1 addition & 0 deletions spidr.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ modules =
Compiler.Eval,
Compiler.Expr,
Compiler.LiteralRW,
Compiler.Transform,

Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Arithmetic,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Constants,
Expand Down
33 changes: 16 additions & 17 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ lookup : Nat -> Computation XlaOp
lookup n = do
case lookup n !get of
Nothing =>
lift $ left (IndexErr "Tried to look up value at index \{show n} but none was found.")
lift $ left (IndexErr "Tried to look up XlaOp at index \{show n} but none was found. Indices: \{show $ keys !get}")
Just op => pure op

interpret : XlaBuilder -> Nat -> Env -> Computation XlaOp
interpret : XlaBuilder -> Nat -> Program -> Computation XlaOp

buildSub : XlaBuilder -> String -> Fn arity -> Computation XlaComputation
buildSub builder name (MkFn params i env) = do
Expand All @@ -78,8 +78,8 @@ buildSub builder name (MkFn params i env) = do

where

interpretParameter : XlaBuilder -> (Nat, Nat, ShapeAndType) -> Computation ()
interpretParameter builder (position, i, MkShapeAndType shape dtype) = do
interpretParameter : XlaBuilder -> (Nat, Nat, FullShape) -> Computation ()
interpretParameter builder (position, i, shape ### dtype) = do
xlaShape <- mkShape {dtype} shape
param <- parameter builder position xlaShape name
put $ insert i param !get
Expand All @@ -104,17 +104,16 @@ enqueue _ (Diag x) = getMatrixDiagonal !(lookup x)
enqueue _ (Triangle tri x) = triangle !(lookup x) tri
enqueue _ (Transpose ordering x) = transpose !(lookup x) ordering
enqueue builder (Identity {dtype} n) = let n = cast n in identityMatrix {dtype} builder n n
enqueue builder (Broadcast {dtype} from to x) =
if elem 0 to && from /= to
then do
literal <- allocLiteral {dtype} to
constantLiteral builder literal
else
let broadcastDims = map (+ length to `minus` length from) $ range $ length from
in broadcastInDim !(lookup x) to broadcastDims
enqueue builder (Map f xs dims) = do
computation <- buildSub builder "computation" f
map builder (toList !(traverse lookup xs)) computation dims
enqueue builder (Broadcast from to x) =
let xlaOp = !(lookup x)
lenFrom = length from
in if elem 0 to && from /= to then
if elem 0 from then reshape xlaOp (range lenFrom) to else do
xlaOp <- slice xlaOp (replicate 0 lenFrom) (replicate 1 lenFrom) (replicate 1 lenFrom)
reshape xlaOp (range lenFrom) to
else
let broadcastDims = map (+ length to `minus` lenFrom) $ range lenFrom
in broadcastInDim xlaOp to broadcastDims
enqueue builder (Reduce f neutral axes x) = do
computation <- buildSub builder "computation" f
reduce !(lookup x) !(lookup neutral) computation axes
Expand Down Expand Up @@ -204,14 +203,14 @@ interpret builder root env = do
interpretExpr (n, expr) = put (insert n !(enqueue builder expr) !get)

export
toString : Nat -> Env -> EitherT Err IO String
toString : Nat -> Program -> EitherT Err IO String
toString root env = do
builder <- mkXlaBuilder "toString"
xlaOp <- evalStateT empty (interpret builder root env)
pure $ opToString builder xlaOp

export
run : PrimitiveRW dtype a => Nat -> Env -> {shape : _} -> EitherT Err IO (Literal shape a)
run : PrimitiveRW dtype a => Nat -> Program -> {shape : _} -> EitherT Err IO (Literal shape a)
run root env = do
builder <- mkXlaBuilder "root"
root <- evalStateT empty (interpret builder root env)
Expand Down
87 changes: 78 additions & 9 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,32 @@ import Primitive
import Types
import Util

infix 9 ###

public export
data ShapeAndType : Type where
MkShapeAndType : Shape -> (0 dtype : Type) -> Primitive dtype => ShapeAndType
data FullShape : Type where
(###) : Shape -> (0 dtype : Type) -> Primitive dtype => FullShape

export
new : Ref Nat
new = do
n <- get
put (S n)
pure n

data Expr : Type where

public export 0
Env : Type
Env = SortedMap Nat Expr
ProgramShape : Type
ProgramShape = SortedMap Nat Shape

public export 0
Program : Type
Program = SortedMap Nat Expr

public export
data Fn : Nat -> Type where
MkFn : {arity : _} -> Vect arity (Nat, ShapeAndType) -> Nat -> Env -> Fn arity
MkFn : {arity : _} -> Vect arity (Nat, FullShape) -> Nat -> Program -> Fn arity

public export
data BinaryOp =
Expand All @@ -58,6 +71,25 @@ data BinaryOp =
| Min
| Max

export
Show BinaryOp where
show Eq = "Eq"
show Ne = "Ne"
show Add = "Add"
show Sub = "Sub"
show Mul = "Mul"
show Div = "Div"
show Rem = "Rem"
show Pow = "Pow"
show Lt = "Lt"
show Gt = "Gt"
show Le = "Le"
show Ge = "Ge"
show And = "And"
show Or = "Or"
show Min = "Min"
show Max = "Max"

public export
data UnaryOp =
Not
Expand Down Expand Up @@ -85,6 +117,32 @@ data UnaryOp =
| Acosh
| Atanh

Show UnaryOp where
show Not = "Not"
show Neg = "Neg"
show Reciprocal = "Reciprocal"
show Ceil = "Ceil"
show Floor = "Floor"
show Abs = "Abs"
show Log = "Log"
show Exp = "Exp"
show Logistic = "Logistic"
show Erf = "Erf"
show Square = "Square"
show Sqrt = "Sqrt"
show Sin = "Sin"
show Cos = "Cos"
show Tan = "Tan"
show Asin = "Asin"
show Acos = "Acos"
show Atan = "Atan"
show Sinh = "Sinh"
show Cosh = "Cosh"
show Tanh = "Tanh"
show Asinh = "Asinh"
show Acosh = "Acosh"
show Atanh = "Atanh"

public export
data Expr : Type where
FromLiteral : PrimitiveRW dtype ty => {shape : _} -> Literal shape ty -> Expr
Expand All @@ -104,13 +162,12 @@ data Expr : Type where
Triangle : (lower : Bool) -> Nat -> Expr
Transpose : List Nat -> Nat -> Expr
Identity : Primitive dtype => Nat -> Expr
Broadcast : Primitive dtype => Shape -> Shape -> Nat -> Expr
Map : Fn n -> Vect n Nat -> Shape -> Expr
Broadcast : Shape -> Shape -> Nat -> Expr
Reduce : Fn 2 -> Nat -> List Nat -> Nat -> Expr
Sort : Fn 2 -> Nat -> Bool -> List Nat -> Expr
Reverse : List Nat -> Nat -> Expr
BinaryElementwise : BinaryOp -> Nat -> Nat -> Expr
UnaryElementwise : UnaryOp -> Nat -> Expr
UnaryElementwise : {shape : Shape} -> UnaryOp -> Nat -> Expr
BinaryElementwise : {shape : Shape} -> BinaryOp -> Nat -> Nat -> Expr
Argmin : Primitive out => Nat -> Nat -> Expr
Argmax : Primitive out => Nat -> Nat -> Expr
Select : Nat -> Nat -> Nat -> Expr
Expand All @@ -120,3 +177,15 @@ data Expr : Type where
TriangularSolve : Nat -> Nat -> Bool -> Expr
UniformFloatingPoint : Nat -> Nat -> Nat -> Nat -> Shape -> Expr
NormalFloatingPoint : Nat -> Nat -> Shape -> Expr

export
Show Expr where
show (FromLiteral {shape} _) = "FromLiteral {shape = \{show shape}}"
show (Arg i) = "Arg \{show i}"
show (Diag i) = "Diag \{show i}"
show (Reshape from to i) = "Reshape \{show from} \{show to} \{show i}"
show (Broadcast from to i) = "Broadcast \{show from} \{show to} \{show i}"
show (UnaryElementwise {shape} op i) = "UnaryElementwise \{show op} \{show i}"
show (BinaryElementwise {shape} op i j) = "BinaryElementwise \{show op} \{show i} \{show j}"
show (Concat axis x y) = "Concat \{show axis} \{show x} \{show y}"
show _ = "OtherExpr"
7 changes: 5 additions & 2 deletions src/Compiler/LiteralRW.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ limitations under the License.
--}
module Compiler.LiteralRW

import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Compiler.Xla.TensorFlow.Compiler.Xla.Shape
import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Literal
import Util

Expand Down Expand Up @@ -47,7 +49,8 @@ interface Primitive dtype => LiteralRW dtype ty where
export
write : (HasIO io, LiteralRW dtype a) => {shape : _} -> Literal shape a -> io Literal
write xs = liftIO $ do
literal <- allocLiteral {dtype} shape
xlaShape <- mkShape {dtype} shape
literal <- allocLiteral xlaShape
sequence_ [| (\idxs => set {dtype} literal idxs) indexed xs |]
pure literal

Expand Down
Loading

0 comments on commit 638d857

Please sign in to comment.