From 09beec3f33d233f4c35d0cb4972a3ec73b5ba608 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 29 Aug 2022 18:44:51 +0100 Subject: [PATCH] xla test --- src/Compiler/Transform.idr | 4 +++- test/Unit/TestTensor.idr | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/Compiler/Transform.idr b/src/Compiler/Transform.idr index f5ab24792..3bfbc116d 100644 --- a/src/Compiler/Transform.idr +++ b/src/Compiler/Transform.idr @@ -124,7 +124,7 @@ vmap n (MkFn params (Reshape from to y)) x = vmap n (MkFn params (Slice starts stops strides y)) x = Slice (0 :: starts) (n :: stops) (1 :: strides) (vmap n (MkFn params y) x) vmap n (MkFn params (DynamicSlice starts sizes y)) x = - -- DynamicSlice takes scalar arguments `starts` + -- takes scalar arguments let starts = (FromLiteral {dtype=U64} (Scalar Z) :: starts) in DynamicSlice starts (n :: sizes) (vmap n (MkFn params y) x) vmap n (MkFn params (Concat axis y z)) x = @@ -138,8 +138,10 @@ vmap n (MkFn params (Broadcast {dtype} from to y)) x = Broadcast {dtype} (n :: from) (n :: to) (vmap n (MkFn params y) x) vmap n (MkFn params (Map f operands dimensions)) x = ?vmap_map vmap n (MkFn params (Reduce f neutral axes y)) x = + -- takes scalar arguments Reduce f neutral [| S axes |] (vmap n (MkFn params y) x) vmap n (MkFn params (Sort f dimension isStable ys)) x = + -- takes scalar arguments Sort f (S dimension) isStable (map (\op => vmap n (MkFn params op) x) ys) vmap n (MkFn params (Reverse axes y)) x = Reverse [| S axes |] (vmap n (MkFn params y) x) vmap n (MkFn params (Eq y z)) x = Eq (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr index 01aabd5bc..6f2105821 100644 --- a/test/Unit/TestTensor.idr +++ b/test/Unit/TestTensor.idr @@ -1352,6 +1352,12 @@ normalIsReproducible = withTests 20 . property $ do sample ===# sample' +xlaGraphs : Property +xlaGraphs = fixedProperty $ do + let x = fromLiteral {dtype=S32} [0, 1, 2] + y = map (\x => fromLiteral (toLiteral x)) x + y ===# x + export partial group : Group group = MkGroup "Tensor" $ [ @@ -1415,4 +1421,5 @@ group = MkGroup "Tensor" $ [ , ("normal", normal) , ("normal updates seed", normalSeedIsUpdated) , ("normal produces same samples for same seed", normalIsReproducible) + , ("XLA Graph edge cases", xlaGraphs) ]