Skip to content

Commit

Permalink
add vmap for mapping a function over a leading dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Feb 4, 2024
1 parent bded472 commit b129f31
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 109 deletions.
1 change: 1 addition & 0 deletions spidr.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ modules =
Compiler.Eval,
Compiler.Expr,
Compiler.LiteralRW,
Compiler.Transform,

Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Arithmetic,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Constants,
Expand Down
91 changes: 77 additions & 14 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,18 @@ import Primitive
import Types
import Util

infix 9 ###

public export
data ShapeAndType : Type where
MkShapeAndType : Shape -> (0 dtype : Type) -> Primitive dtype => ShapeAndType
data FullShape : Type where
(###) : Shape -> (0 dtype : Type) -> Primitive dtype => FullShape

export
new : Ref Nat
new = do
n <- get
put (S n)
pure n

public export
data Expr : Type where
Expand Down Expand Up @@ -61,11 +70,15 @@ data Fn : Nat -> Type where
||| @result The position of the function result in the graph.
||| @env The function graph. Includes only nodes in this scope, not outer or inner scope.
MkFn : {arity : _} ->
(params : Vect arity (Nat, ShapeAndType)) ->
(params : Vect arity (Nat, FullShape)) ->
(result : Nat) ->
(env : Env) ->
Fn arity

public export 0
ProgramShape : Type
ProgramShape = SortedMap Nat Shape

public export
data BinaryOp =
Eq
Expand All @@ -85,6 +98,25 @@ data BinaryOp =
| Min
| Max

export
Show BinaryOp where
show Eq = "Eq"
show Ne = "Ne"
show Add = "Add"
show Sub = "Sub"
show Mul = "Mul"
show Div = "Div"
show Rem = "Rem"
show Pow = "Pow"
show Lt = "Lt"
show Gt = "Gt"
show Le = "Le"
show Ge = "Ge"
show And = "And"
show Or = "Or"
show Min = "Min"
show Max = "Max"

public export
data UnaryOp =
Not
Expand Down Expand Up @@ -112,6 +144,32 @@ data UnaryOp =
| Acosh
| Atanh

Show UnaryOp where
show Not = "Not"
show Neg = "Neg"
show Reciprocal = "Reciprocal"
show Ceil = "Ceil"
show Floor = "Floor"
show Abs = "Abs"
show Log = "Log"
show Exp = "Exp"
show Logistic = "Logistic"
show Erf = "Erf"
show Square = "Square"
show Sqrt = "Sqrt"
show Sin = "Sin"
show Cos = "Cos"
show Tan = "Tan"
show Asin = "Asin"
show Acos = "Acos"
show Atan = "Atan"
show Sinh = "Sinh"
show Cosh = "Cosh"
show Tanh = "Tanh"
show Asinh = "Asinh"
show Acosh = "Acosh"
show Atanh = "Atanh"

public export
data Expr : Type where
FromLiteral : PrimitiveRW dtype ty => {shape : _} -> Literal shape ty -> Expr
Expand All @@ -133,18 +191,11 @@ data Expr : Type where
Transpose : List Nat -> Nat -> Expr
Identity : Primitive dtype => Nat -> Expr
Broadcast : Primitive dtype => Shape -> Shape -> Nat -> Expr

||| Apply function `f` with given `arity` over `args`.
|||
||| @f The function to apply.
||| @args The arguments to apply `f` to.
Map : (f : Fn arity) -> (args : Vect arity Nat) -> Shape -> Expr

Reduce : Fn 2 -> Nat -> List Nat -> Nat -> Expr
Sort : Fn 2 -> Nat -> Bool -> List Nat -> Expr
Reverse : List Nat -> Nat -> Expr
BinaryElementwise : BinaryOp -> Nat -> Nat -> Expr
UnaryElementwise : UnaryOp -> Nat -> Expr
UnaryElementwise : {shape : Shape} -> UnaryOp -> Nat -> Expr
BinaryElementwise : {shape : Shape} -> BinaryOp -> Nat -> Nat -> Expr
Argmin : Primitive out => Nat -> Nat -> Expr
Argmax : Primitive out => Nat -> Nat -> Expr
Select : Nat -> Nat -> Nat -> Expr
Expand All @@ -166,7 +217,7 @@ applyN f [] = f
applyN f (x :: xs) = applyN (f x) xs

export
addFn : {arity : _} -> Vect arity ShapeAndType -> FnExpr arity -> State Env (Fn arity)
addFn : {arity : _} -> Vect arity FullShape -> FnExpr arity -> State Env (Fn arity)
addFn params f = do
MkEnv next env <- get
let (subEnv@(MkEnv next _), params, result) = runState (MkEnv next []) $ do
Expand All @@ -177,8 +228,20 @@ addFn params f = do
pure (MkFn params result subEnv)

where
addArg : ShapeAndType -> State Env Nat
addArg : FullShape -> State Env Nat
addArg st = do
MkEnv next env <- get
put (MkEnv (S next) ((next, Arg next) :: env))
pure next

export
Show Expr where
show (FromLiteral {shape} _) = "FromLiteral {shape = \{show shape}}"
show (Arg i) = "Arg \{show i}"
show (Diag i) = "Diag \{show i}"
show (Reshape from to i) = "Reshape \{show from} \{show to} \{show i}"
show (Broadcast from to i) = "Broadcast \{show from} \{show to} \{show i}"
show (UnaryElementwise {shape} op i) = "UnaryElementwise \{show op} \{show i}"
show (BinaryElementwise {shape} op i j) = "BinaryElementwise \{show op} \{show i} \{show j}"
show (Concat axis x y) = "Concat \{show axis} \{show x} \{show y}"
show _ = "OtherExpr"
5 changes: 3 additions & 2 deletions src/Compiler/LiteralRW.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ limitations under the License.
--}
module Compiler.LiteralRW

import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import public Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Compiler.Xla.TensorFlow.Compiler.Xla.Shape
import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Literal
import Util

Expand Down
Loading

0 comments on commit b129f31

Please sign in to comment.