-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add
vmap
for mapping a function over a leading dimension
- Loading branch information
1 parent
8d3ed2d
commit 1533c20
Showing
6 changed files
with
298 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
{-- | ||
Copyright 2022 Joel Berkeley | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
--} | ||
module Compiler.Transform | ||
|
||
import Compiler.Expr | ||
import Compiler.LiteralRW | ||
import Literal | ||
import Primitive | ||
import Types | ||
|
||
partial | ||
call : Fn arity Expr -> Vect arity Expr -> Expr | ||
call (MkFn _ x@(FromLiteral _)) _ = x | ||
call (MkFn params x@(Parameter _ _ _)) args = | ||
let Just idx = findIndex (== x) params | Nothing => ?call_parameterNotFound | ||
in index idx args | ||
call (MkFn params (Tuple xs)) args = Tuple (map (\x => call (MkFn params x) args) xs) | ||
call (MkFn params (GetTupleElement k x)) args = GetTupleElement k (call (MkFn params x) args) | ||
call (MkFn _ x@MinFiniteValue) _ = x | ||
call (MkFn _ x@MaxFiniteValue) _ = x | ||
call (MkFn params (ConvertElementType {dtype} x)) args = | ||
ConvertElementType {dtype} (call (MkFn params x) args) | ||
call (MkFn params (Reshape from to x)) args = Reshape from to (call (MkFn params x) args) | ||
call (MkFn params (Slice starts stops strides x)) args = | ||
Slice starts stops strides (call (MkFn params x) args) | ||
call (MkFn params (DynamicSlice starts sizes x)) args = | ||
DynamicSlice (map (\x => (call (MkFn params x) args)) starts) sizes (call (MkFn params x) args) | ||
call (MkFn params (Concat k x y)) args = | ||
Concat k (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Diag x)) args = Diag (call (MkFn params x) args) | ||
call (MkFn params (Triangle lower x)) args = Triangle lower (call (MkFn params x) args) | ||
call (MkFn params (Transpose axes x)) args = Transpose axes (call (MkFn params x) args) | ||
call (MkFn _ x@(Identity _)) _ = x | ||
call (MkFn params (Broadcast {dtype} from to x)) args = | ||
Broadcast {dtype} from to (call (MkFn params x) args) | ||
call (MkFn params (Map x xs ys)) args = ?call_map | ||
call (MkFn params (Reduce x y xs z)) args = ?call_reduce | ||
call (MkFn params (Sort x k y xs)) args = ?call_sort | ||
call (MkFn params (Reverse axes x)) args = Reverse axes (call (MkFn params x) args) | ||
call (MkFn params (Eq x y)) args = Eq (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Ne x y)) args = Ne (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Add x y)) args = Add (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Sub x y)) args = Sub (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Mul x y)) args = Mul (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Div x y)) args = Div (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Pow x y)) args = Pow (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Lt x y)) args = Lt (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Gt x y)) args = Gt (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Le x y)) args = Le (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Ge x y)) args = Ge (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (And x y)) args = And (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Or x y)) args = Or (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Min x y)) args = Min (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Max x y)) args = Max (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Not x)) args = Not (call (MkFn params x) args) | ||
call (MkFn params (Neg x)) args = Neg (call (MkFn params x) args) | ||
call (MkFn params (Reciprocal x)) args = Reciprocal (call (MkFn params x) args) | ||
call (MkFn params (Abs x)) args = Abs (call (MkFn params x) args) | ||
call (MkFn params (Ceil x)) args = Ceil (call (MkFn params x) args) | ||
call (MkFn params (Floor x)) args = Floor (call (MkFn params x) args) | ||
call (MkFn params (Log x)) args = Log (call (MkFn params x) args) | ||
call (MkFn params (Exp x)) args = Exp (call (MkFn params x) args) | ||
call (MkFn params (Logistic x)) args = Logistic (call (MkFn params x) args) | ||
call (MkFn params (Erf x)) args = Erf (call (MkFn params x) args) | ||
call (MkFn params (Square x)) args = Square (call (MkFn params x) args) | ||
call (MkFn params (Sqrt x)) args = Sqrt (call (MkFn params x) args) | ||
call (MkFn params (Sin x)) args = Sin (call (MkFn params x) args) | ||
call (MkFn params (Cos x)) args = Cos (call (MkFn params x) args) | ||
call (MkFn params (Tan x)) args = Tan (call (MkFn params x) args) | ||
call (MkFn params (Asin x)) args = Asin (call (MkFn params x) args) | ||
call (MkFn params (Acos x)) args = Acos (call (MkFn params x) args) | ||
call (MkFn params (Atan x)) args = Atan (call (MkFn params x) args) | ||
call (MkFn params (Sinh x)) args = Sinh (call (MkFn params x) args) | ||
call (MkFn params (Cosh x)) args = Cosh (call (MkFn params x) args) | ||
call (MkFn params (Tanh x)) args = Tanh (call (MkFn params x) args) | ||
call (MkFn params (Asinh x)) args = Asinh (call (MkFn params x) args) | ||
call (MkFn params (Acosh x)) args = Acosh (call (MkFn params x) args) | ||
call (MkFn params (Atanh x)) args = Atanh (call (MkFn params x) args) | ||
call (MkFn params (Select pred t f)) args = | ||
Select (call (MkFn params pred) args) (call (MkFn params t) args) (call (MkFn params f) args) | ||
call (MkFn params (Cond pred ft t ff f)) args = ?call_cond | ||
call (MkFn params (Dot x y)) args = Dot (call (MkFn params x) args) (call (MkFn params y) args) | ||
call (MkFn params (Cholesky x)) args = Cholesky (call (MkFn params x) args) | ||
call (MkFn params (TriangularSolve x y lower)) args = | ||
TriangularSolve (call (MkFn params x) args) (call (MkFn params y) args) lower | ||
call (MkFn params (UniformFloatingPoint key state minval maxval shape)) args = | ||
UniformFloatingPoint (call (MkFn params key) args) | ||
(call (MkFn params state) args) | ||
(call (MkFn params minval) args) | ||
(call (MkFn params maxval) args) shape | ||
call (MkFn params (NormalFloatingPoint key state shape)) args = | ||
NormalFloatingPoint (call (MkFn params key) args) (call (MkFn params state) args) shape | ||
|
||
-- what if there are parameters in the expression from the surrounding scope? like if we use `vmap` in `sort`? | ||
-- | ||
-- there's a traversal going on here that we can abstract. What is it? | ||
export partial | ||
vmap : Nat -> Fn arity Expr -> Vect arity Expr -> Expr | ||
vmap n (MkFn _ res@(FromLiteral {shape} {dtype} _)) _ = Broadcast {dtype} shape (n :: shape) res | ||
vmap n (MkFn params res@(Parameter _ _ _)) args = | ||
let Just idx = findIndex (== res) params | Nothing => ?vmap_parameterNotFound | ||
in index idx args | ||
vmap n (MkFn params (Tuple xs)) x = ?vmap_tuple | ||
vmap n (MkFn params (GetTupleElement k y)) x = ?vmap_getTupleElement | ||
vmap n (MkFn _ (MinFiniteValue {dtype})) _ = Broadcast {dtype} [] [n] (MinFiniteValue {dtype}) | ||
vmap n (MkFn _ (MaxFiniteValue {dtype})) _ = Broadcast {dtype} [] [n] (MinFiniteValue {dtype}) | ||
vmap n (MkFn params (ConvertElementType {dtype} y)) x = | ||
ConvertElementType {dtype} (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Reshape from to y)) x = | ||
Reshape (n :: from) (n :: to) (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Slice starts stops strides y)) x = | ||
Slice (0 :: starts) (n :: stops) (1 :: strides) (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (DynamicSlice starts sizes y)) x = | ||
-- takes scalar arguments | ||
let starts = (FromLiteral {dtype=U64} (Scalar Z) :: starts) | ||
in DynamicSlice starts (n :: sizes) (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Concat axis y z)) x = | ||
Concat (S axis) (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Diag y)) x = Diag (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Triangle lower y)) x = Triangle lower (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Transpose axes y)) x = Transpose (0 :: [| S axes |]) (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Identity {dtype} k)) _ = | ||
Broadcast {dtype} [k, k] [n, k, k] (Identity {dtype} k) | ||
vmap n (MkFn params (Broadcast {dtype} from to y)) x = | ||
Broadcast {dtype} (n :: from) (n :: to) (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Map f operands dimensions)) x = ?vmap_map | ||
vmap n (MkFn params (Reduce f neutral axes y)) x = | ||
-- takes scalar arguments | ||
Reduce f neutral [| S axes |] (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Sort f dimension isStable ys)) x = | ||
-- takes scalar arguments | ||
Sort f (S dimension) isStable (map (\op => vmap n (MkFn params op) x) ys) | ||
vmap n (MkFn params (Reverse axes y)) x = Reverse [| S axes |] (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Eq y z)) x = Eq (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Ne y z)) x = Ne (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Add y z)) x = Add (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Sub y z)) x = Sub (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Mul y z)) x = Mul (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Div y z)) x = Div (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Pow y z)) x = Pow (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Lt y z)) x = Lt (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Gt y z)) x = Gt (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Le y z)) x = Le (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Ge y z)) x = Ge (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (And y z)) x = And (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Or y z)) x = Or (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Min y z)) x = Min (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Max y z)) x = Max (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) | ||
vmap n (MkFn params (Not y)) x = Not (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Neg y)) x = Neg (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Reciprocal y)) x = Reciprocal (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Abs y)) x = Abs (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Ceil y)) x = Ceil (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Floor y)) x = Floor (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Log y)) x = Log (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Exp y)) x = Exp (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Logistic y)) x = Logistic (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Erf y)) x = Erf (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Square y)) x = Square (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Sqrt y)) x = Sqrt (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Sin y)) x = Sin (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Cos y)) x = Cos (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Tan y)) x = Tan (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Asin y)) x = Asin (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Acos y)) x = Acos (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Atan y)) x = Atan (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Sinh y)) x = Sinh (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Cosh y)) x = Cosh (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Tanh y)) x = Tanh (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Asinh y)) x = Asinh (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Acosh y)) x = Acosh (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Atanh y)) x = Atanh (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (Select pred t f)) x = | ||
Select (vmap n (MkFn params pred) x) (vmap n (MkFn params t) x) (vmap n (MkFn params f) x) | ||
vmap n (MkFn params (Cond pred ft t ff f)) x = | ||
let condition = (vmap n (MkFn params $ Broadcast {dtype=PRED} [] ?vmap_cond_pred_shape pred) x) | ||
in Select condition (vmap n (MkFn params $ call ft [t]) x) (vmap n (MkFn params $ call ff [f]) x) | ||
vmap n (MkFn params (Dot y z)) x = ?vmap_dot | ||
vmap n (MkFn params (Cholesky y)) x = Cholesky (vmap n (MkFn params y) x) | ||
vmap n (MkFn params (TriangularSolve a b lower)) x = | ||
TriangularSolve (vmap n (MkFn params a) x) (vmap n (MkFn params b) x) lower | ||
vmap n (MkFn params (UniformFloatingPoint y z w v xs)) x = ?vmap_uniformFloatingPoint | ||
vmap n (MkFn params (NormalFloatingPoint y z xs)) x = ?vmap_normalFloatingPoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters