Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Aug 29, 2022
1 parent 01e7582 commit 744bd67
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 27 deletions.
73 changes: 50 additions & 23 deletions src/Compiler/Transform.idr
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,21 @@ call (MkFn params (Tuple xs)) args = Tuple (map (\x => call (MkFn params x) args
call (MkFn params (GetTupleElement k x)) args = GetTupleElement k (call (MkFn params x) args)
call (MkFn _ x@MinFiniteValue) _ = x
call (MkFn _ x@MaxFiniteValue) _ = x
call (MkFn params (ConvertElementType {dtype} x)) args = ConvertElementType {dtype} (call (MkFn params x) args)
call (MkFn params (ConvertElementType {dtype} x)) args =
ConvertElementType {dtype} (call (MkFn params x) args)
call (MkFn params (Reshape from to x)) args = Reshape from to (call (MkFn params x) args)
call (MkFn params (Slice starts stops strides x)) args = Slice starts stops strides (call (MkFn params x) args)
call (MkFn params (DynamicSlice starts sizes x)) args = DynamicSlice (map (\x => (call (MkFn params x) args)) starts) sizes (call (MkFn params x) args)
call (MkFn params (Concat k x y)) args = Concat k (call (MkFn params x) args) (call (MkFn params y) args)
call (MkFn params (Slice starts stops strides x)) args =
Slice starts stops strides (call (MkFn params x) args)
call (MkFn params (DynamicSlice starts sizes x)) args =
DynamicSlice (map (\x => (call (MkFn params x) args)) starts) sizes (call (MkFn params x) args)
call (MkFn params (Concat k x y)) args =
Concat k (call (MkFn params x) args) (call (MkFn params y) args)
call (MkFn params (Diag x)) args = Diag (call (MkFn params x) args)
call (MkFn params (Triangle lower x)) args = Triangle lower (call (MkFn params x) args)
call (MkFn params (Transpose axes x)) args = Transpose axes (call (MkFn params x) args)
call (MkFn _ x@(Identity _)) _ = x
call (MkFn params (Broadcast {dtype} from to x)) args = Broadcast {dtype} from to (call (MkFn params x) args)
call (MkFn params (Broadcast {dtype} from to x)) args =
Broadcast {dtype} from to (call (MkFn params x) args)
call (MkFn params (Map x xs ys)) args = ?call_map
call (MkFn params (Reduce x y xs z)) args = ?call_reduce
call (MkFn params (Sort x k y xs)) args = ?call_sort
Expand Down Expand Up @@ -84,18 +89,25 @@ call (MkFn params (Tanh x)) args = Tanh (call (MkFn params x) args)
call (MkFn params (Asinh x)) args = Asinh (call (MkFn params x) args)
call (MkFn params (Acosh x)) args = Acosh (call (MkFn params x) args)
call (MkFn params (Atanh x)) args = Atanh (call (MkFn params x) args)
call (MkFn params (Select pred t f)) args = Select (call (MkFn params pred) args) (call (MkFn params t) args) (call (MkFn params f) args)
call (MkFn params (Select pred t f)) args =
Select (call (MkFn params pred) args) (call (MkFn params t) args) (call (MkFn params f) args)
call (MkFn params (Cond pred ft t ff f)) args = ?call_cond
call (MkFn params (Dot x y)) args = Dot (call (MkFn params x) args) (call (MkFn params y) args)
call (MkFn params (Cholesky x)) args = Cholesky (call (MkFn params x) args)
call (MkFn params (TriangularSolve x y lower)) args = TriangularSolve (call (MkFn params x) args) (call (MkFn params y) args) lower
call (MkFn params (UniformFloatingPoint key state minval maxval shape)) args = UniformFloatingPoint (call (MkFn params key) args) (call (MkFn params state) args) (call (MkFn params minval) args) (call (MkFn params maxval) args) shape
call (MkFn params (NormalFloatingPoint key state shape)) args = NormalFloatingPoint (call (MkFn params key) args) (call (MkFn params state) args) shape
call (MkFn params (TriangularSolve x y lower)) args =
TriangularSolve (call (MkFn params x) args) (call (MkFn params y) args) lower
call (MkFn params (UniformFloatingPoint key state minval maxval shape)) args =
UniformFloatingPoint (call (MkFn params key) args)
(call (MkFn params state) args)
(call (MkFn params minval) args)
(call (MkFn params maxval) args) shape
call (MkFn params (NormalFloatingPoint key state shape)) args =
NormalFloatingPoint (call (MkFn params key) args) (call (MkFn params state) args) shape

-- what if there are parameters in the expression from the surrounding scope? like if we use `vmap` in `sort`?
--
-- there's a traversal going on here that we can abstract. What is it?
export covering
export partial
vmap : Nat -> Fn arity Expr -> Vect arity Expr -> Expr
vmap n (MkFn _ res@(FromLiteral {shape} {dtype} _)) _ = Broadcast {dtype} shape (n :: shape) res
vmap n (MkFn params res@(Parameter _ _ _)) args =
Expand All @@ -105,19 +117,30 @@ vmap n (MkFn params (Tuple xs)) x = ?vmap_tuple
vmap n (MkFn params (GetTupleElement k y)) x = ?vmap_getTupleElement
vmap n (MkFn _ (MinFiniteValue {dtype})) _ = Broadcast {dtype} [] [n] (MinFiniteValue {dtype})
vmap n (MkFn _ (MaxFiniteValue {dtype})) _ = Broadcast {dtype} [] [n] (MinFiniteValue {dtype})
vmap n (MkFn params (ConvertElementType {dtype} y)) x = ConvertElementType {dtype} (vmap n (MkFn params y) x)
vmap n (MkFn params (Reshape from to y)) x = Reshape (n :: from) (n :: to) (vmap n (MkFn params y) x)
vmap n (MkFn params (Slice starts stops strides y)) x = Slice (0 :: starts) (n :: stops) (1 :: strides) (vmap n (MkFn params y) x)
vmap n (MkFn params (DynamicSlice starts sizes y)) x = DynamicSlice (FromLiteral {dtype=U64} (Scalar Z) :: starts) (n :: sizes) (vmap n (MkFn params y) x)
vmap n (MkFn params (Concat axis y z)) x = Concat (S axis) (vmap n (MkFn params y) x) (vmap n (MkFn params z) x)
vmap n (MkFn params (ConvertElementType {dtype} y)) x =
ConvertElementType {dtype} (vmap n (MkFn params y) x)
vmap n (MkFn params (Reshape from to y)) x =
Reshape (n :: from) (n :: to) (vmap n (MkFn params y) x)
vmap n (MkFn params (Slice starts stops strides y)) x =
Slice (0 :: starts) (n :: stops) (1 :: strides) (vmap n (MkFn params y) x)
vmap n (MkFn params (DynamicSlice starts sizes y)) x =
-- DynamicSlice takes scalar arguments `starts`
let starts = (FromLiteral {dtype=U64} (Scalar Z) :: starts)
in DynamicSlice starts (n :: sizes) (vmap n (MkFn params y) x)
vmap n (MkFn params (Concat axis y z)) x =
Concat (S axis) (vmap n (MkFn params y) x) (vmap n (MkFn params z) x)
vmap n (MkFn params (Diag y)) x = Diag (vmap n (MkFn params y) x)
vmap n (MkFn params (Triangle lower y)) x = Triangle lower (vmap n (MkFn params y) x)
vmap n (MkFn params (Transpose axes y)) x = Transpose (0 :: [| S axes |]) (vmap n (MkFn params y) x)
vmap n (MkFn params (Identity {dtype} k)) _ = Broadcast {dtype} [k, k] [n, k, k] (Identity {dtype} k)
vmap n (MkFn params (Broadcast {dtype} from to y)) x = Broadcast {dtype} (n :: from) (n :: to) (vmap n (MkFn params y) x)
vmap n (MkFn params (Identity {dtype} k)) _ =
Broadcast {dtype} [k, k] [n, k, k] (Identity {dtype} k)
vmap n (MkFn params (Broadcast {dtype} from to y)) x =
Broadcast {dtype} (n :: from) (n :: to) (vmap n (MkFn params y) x)
vmap n (MkFn params (Map f operands dimensions)) x = ?vmap_map
vmap n (MkFn params (Reduce f neutral axes y)) x = Reduce f neutral [| S axes |] (vmap n (MkFn params y) x)
vmap n (MkFn params (Sort f dimension isStable ys)) x = Sort f (S dimension) isStable (map (\op => vmap n (MkFn params op) x) ys)
vmap n (MkFn params (Reduce f neutral axes y)) x =
Reduce f neutral [| S axes |] (vmap n (MkFn params y) x)
vmap n (MkFn params (Sort f dimension isStable ys)) x =
Sort f (S dimension) isStable (map (\op => vmap n (MkFn params op) x) ys)
vmap n (MkFn params (Reverse axes y)) x = Reverse [| S axes |] (vmap n (MkFn params y) x)
vmap n (MkFn params (Eq y z)) x = Eq (vmap n (MkFn params y) x) (vmap n (MkFn params z) x)
vmap n (MkFn params (Ne y z)) x = Ne (vmap n (MkFn params y) x) (vmap n (MkFn params z) x)
Expand Down Expand Up @@ -158,10 +181,14 @@ vmap n (MkFn params (Tanh y)) x = Tanh (vmap n (MkFn params y) x)
vmap n (MkFn params (Asinh y)) x = Asinh (vmap n (MkFn params y) x)
vmap n (MkFn params (Acosh y)) x = Acosh (vmap n (MkFn params y) x)
vmap n (MkFn params (Atanh y)) x = Atanh (vmap n (MkFn params y) x)
vmap n (MkFn params (Select pred t f)) x = Select (vmap n (MkFn params pred) x) (vmap n (MkFn params t) x) (vmap n (MkFn params f) x)
vmap n (MkFn params (Cond pred ft t ff f)) x = Select (vmap n (MkFn params $ Broadcast {dtype=PRED} [] ?vmap_cond_pred_shape pred) x) (vmap n (MkFn params $ call ft [t]) x) (vmap n (MkFn params $ call ff [f]) x)
vmap n (MkFn params (Dot y z)) x = ?vmap_dot -- need DotGeneral
vmap n (MkFn params (Select pred t f)) x =
Select (vmap n (MkFn params pred) x) (vmap n (MkFn params t) x) (vmap n (MkFn params f) x)
vmap n (MkFn params (Cond pred ft t ff f)) x =
let condition = (vmap n (MkFn params $ Broadcast {dtype=PRED} [] ?vmap_cond_pred_shape pred) x)
in Select condition (vmap n (MkFn params $ call ft [t]) x) (vmap n (MkFn params $ call ff [f]) x)
vmap n (MkFn params (Dot y z)) x = ?vmap_dot
vmap n (MkFn params (Cholesky y)) x = Cholesky (vmap n (MkFn params y) x)
vmap n (MkFn params (TriangularSolve a b lower)) x = TriangularSolve (vmap n (MkFn params a) x) (vmap n (MkFn params b) x) lower
vmap n (MkFn params (TriangularSolve a b lower)) x =
TriangularSolve (vmap n (MkFn params a) x) (vmap n (MkFn params b) x) lower
vmap n (MkFn params (UniformFloatingPoint y z w v xs)) x = ?vmap_uniformFloatingPoint
vmap n (MkFn params (NormalFloatingPoint y z xs)) x = ?vmap_normalFloatingPoint
9 changes: 5 additions & 4 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,11 @@ vmap = fixedProperty $ do

vmap (\x => reduce @{Sum} [0] x) xs ===# fromLiteral [[2, 4], [10, 8]]

-- let preds = fromLiteral [True, False]
-- vmap (\x => cond x id 1 id 0) preds ===# fromLiteral [1, 0]
-- vmap (\x => cond (fromLiteral True) id x id (fill {shape=[2, 2]} 0)) xs ===# xs
-- vmap (\x => cond (fromLiteral True) (const x) (fill {shape=[]} {dtype=U32} 1) id (fill 0)) xs ===# xs
let preds = fromLiteral [True, False]
vmap (\x => cond x id 1 id 0) preds ===# fromLiteral [1, 0]
vmap (\x => cond (fromLiteral True) id x id (fill {shape=[2, 2]} 0)) xs ===# xs
vmap (\x => cond (fromLiteral True) (const x) (fill {shape=[]} {dtype=U32} 1) id (fill 0)) xs
===# xs

vmap diag xs ===# fromLiteral [[0, 3], [4, 3]]
-- [[2, 3], [0, 1]] + [[0, 3], [4, -2]]
Expand Down

0 comments on commit 744bd67

Please sign in to comment.