Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Jun 3, 2023
1 parent 696ecde commit 62bf752
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 133 deletions.
10 changes: 5 additions & 5 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ lookup n = do
lift $ left (IndexErr "Tried to look up XlaOp at index \{show n} but none was found. Indices: \{show $ keys !get}")
Just op => pure op

interpret : XlaBuilder -> Nat -> Env -> Computation XlaOp
interpret : XlaBuilder -> Nat -> Program -> Computation XlaOp

buildSub : XlaBuilder -> String -> Fn arity -> Computation XlaComputation
buildSub builder name (MkFn params i env) = do
Expand All @@ -78,8 +78,8 @@ buildSub builder name (MkFn params i env) = do

where

interpretParameter : XlaBuilder -> (Nat, Nat, ShapeAndType) -> Computation ()
interpretParameter builder (position, i, MkShapeAndType shape dtype) = do
interpretParameter : XlaBuilder -> (Nat, Nat, FullShape) -> Computation ()
interpretParameter builder (position, i, shape ### dtype) = do
xlaShape <- mkShape {dtype} shape
param <- parameter builder position xlaShape name
put $ insert i param !get
Expand Down Expand Up @@ -202,14 +202,14 @@ interpret builder root env = do
interpretExpr (n, expr) = put (insert n !(enqueue builder expr) !get)

export
toString : Nat -> Env -> EitherT Err IO String
toString : Nat -> Program -> EitherT Err IO String
toString root env = do
builder <- mkXlaBuilder "toString"
xlaOp <- evalStateT empty (interpret builder root env)
pure $ opToString builder xlaOp

export
run : PrimitiveRW dtype a => Nat -> Env -> {shape : _} -> EitherT Err IO (Literal shape a)
run : PrimitiveRW dtype a => Nat -> Program -> {shape : _} -> EitherT Err IO (Literal shape a)
run root env = do
builder <- mkXlaBuilder "root"
root <- evalStateT empty (interpret builder root env)
Expand Down
16 changes: 11 additions & 5 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import Primitive
import Types
import Util

infix 9 ###

public export
data ShapeAndType : Type where
MkShapeAndType : Shape -> (0 dtype : Type) -> Primitive dtype => ShapeAndType
data FullShape : Type where
(###) : Shape -> (0 dtype : Type) -> Primitive dtype => FullShape

export
new : Ref Nat
Expand All @@ -39,12 +41,16 @@ new = do
data Expr : Type where

public export 0
Env : Type
Env = SortedMap Nat Expr
ProgramShape : Type
ProgramShape = SortedMap Nat Shape

public export 0
Program : Type
Program = SortedMap Nat Expr

public export
data Fn : Nat -> Type where
MkFn : {arity : _} -> Vect arity (Nat, ShapeAndType) -> Nat -> Env -> Fn arity
MkFn : {arity : _} -> Vect arity (Nat, FullShape) -> Nat -> Program -> Fn arity

public export
data BinaryOp =
Expand Down
37 changes: 29 additions & 8 deletions src/Compiler/Transform.idr
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,34 @@ record Acc where
||| Keys are indices of nodes in the original
metadata : SortedMap Nat Value
||| the resulting graph
graph : Env

graph : Program

||| Traverse the `program` in sorted order. For each `Expr` in the graph, inspect the nodes it is
||| built from. Each node it is built from either
||| * does not exist in `program`. This means that it comes from the global scope, is therefore
||| constant with respect to the `vmap` argument, and we simply broadcast the value using the
||| shape extracted from `programShape`.
||| * exists in `program`, in which case ...
||| If a node is built from only constant nodes, it is also constant.
|||
||| @res A pointer to the return value of the original function.
||| @n The size of the vmap-ed dimension.
||| @param A pointer to the parameter in the `vmap`-ed function.
||| @arg A pointer to the argument to `vmap`.
||| @to The return shape of the function to vmap.
||| @localProgram The program to vmap. We vecotrize the whole of this, so this should not include
||| the whole global program, just the local program containing all values dependent on the value
||| we vmap over.
||| @globalProgramShape The shape of the whole global program.
export partial
vmap : (res, n, param, arg : Nat) -> (to : Shape) -> (graph : Env) -> Ref (Env, Nat)
vmap res n param arg to original = do
vmap : (res, n, param, arg : Nat) ->
(to : Shape) ->
(localProgram : Program) ->
(globalProgramShape : ProgramShape) ->
Ref (Program, Nat)
vmap res n param arg to localProgram globalProgramShape = do
foo <- runEitherT $ do
acc <- recurse (toList original) impl (MkAcc empty empty)
acc <- recurse (toList localProgram) impl (MkAcc empty empty)
case lookup res acc.metadata `or` idris_crash "\{show res} \{show (keys acc.metadata)}" of
Var i => pure (acc.graph, i)
Const => lift new <&> \j => (insert j (Broadcast to (n :: to) res) acc.graph, j)
Expand Down Expand Up @@ -140,14 +161,14 @@ vmap res n param arg to original = do
||| @n The size of the extra dimensions we're mapping over.
||| @arg The index of the argument to replace
export covering
vmap : (res, n, arg : Nat) -> (unvmapped : Env) -> Expr -> Ref (Env, Nat)
vmap : (res, n, arg : Nat) -> (unvmapped : Program) -> Expr -> Ref (Program, Nat)
vmap res n arg unvmapped expr = runStateT empty (impl expr)
where
impl : Expr -> StateT Env Ref Nat
impl : Expr -> StateT Program Ref Nat
recurse : Shape -> Nat -> StateT Env Ref Nat
recurse : Shape -> Nat -> StateT Program Ref Nat
recurse shape j =
case lookup j unvmapped of
Just expr => impl expr
Expand Down
Loading

0 comments on commit 62bf752

Please sign in to comment.