Skip to content

Commit

Permalink
xla test
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Aug 31, 2022
1 parent 6157d7a commit f5fba08
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/Compiler/Transform.idr
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ vmap n (MkFn params (Reshape from to y)) x =
vmap n (MkFn params (Slice starts stops strides y)) x =
Slice (0 :: starts) (n :: stops) (1 :: strides) (vmap n (MkFn params y) x)
vmap n (MkFn params (DynamicSlice starts sizes y)) x =
-- DynamicSlice takes scalar arguments `starts`
-- takes scalar arguments
let starts = (FromLiteral {dtype=U64} (Scalar Z) :: starts)
in DynamicSlice starts (n :: sizes) (vmap n (MkFn params y) x)
vmap n (MkFn params (Concat axis y z)) x =
Expand All @@ -138,8 +138,10 @@ vmap n (MkFn params (Broadcast {dtype} from to y)) x =
Broadcast {dtype} (n :: from) (n :: to) (vmap n (MkFn params y) x)
vmap n (MkFn params (Map f operands dimensions)) x = ?vmap_map
vmap n (MkFn params (Reduce f neutral axes y)) x =
-- takes scalar arguments
Reduce f neutral [| S axes |] (vmap n (MkFn params y) x)
vmap n (MkFn params (Sort f dimension isStable ys)) x =
-- takes scalar arguments
Sort f (S dimension) isStable (map (\op => vmap n (MkFn params op) x) ys)
vmap n (MkFn params (Reverse axes y)) x = Reverse [| S axes |] (vmap n (MkFn params y) x)
vmap n (MkFn params (Eq y z)) x = Eq (vmap n (MkFn params y) x) (vmap n (MkFn params z) x)
Expand Down
7 changes: 7 additions & 0 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,12 @@ normalIsReproducible = withTests 20 . property $ do

sample ===# sample'

xlaGraphs : Property
xlaGraphs = fixedProperty $ do
let x = fromLiteral {dtype=S32} [0, 1, 2]
y = map (\x => fromLiteral (toLiteral x)) x
y ===# x

export partial
group : Group
group = MkGroup "Tensor" $ [
Expand Down Expand Up @@ -1415,4 +1421,5 @@ group = MkGroup "Tensor" $ [
, ("normal", normal)
, ("normal updates seed", normalSeedIsUpdated)
, ("normal produces same samples for same seed", normalIsReproducible)
, ("XLA Graph edge cases", xlaGraphs)
]

0 comments on commit f5fba08

Please sign in to comment.