From 1533c2093c3d675e082c9f77374a0f3b0b465384 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 31 Jul 2022 19:06:51 +0100 Subject: [PATCH] add `vmap` for mapping a function over a leading dimension --- spidr.ipkg | 1 + src/Compiler/Eval.idr | 1 + src/Compiler/Transform.idr | 196 +++++++++++++++++++++++++++++++++++++ src/Tensor.idr | 53 ++++++++++ test/Main.idr | 2 +- test/Unit/TestTensor.idr | 47 ++++++++- 6 files changed, 298 insertions(+), 2 deletions(-) create mode 100644 src/Compiler/Transform.idr diff --git a/spidr.ipkg b/spidr.ipkg index 2e7b8fff2..9cfd18c5b 100644 --- a/spidr.ipkg +++ b/spidr.ipkg @@ -11,6 +11,7 @@ modules = Compiler.Eval, Compiler.Expr, Compiler.LiteralRW, + Compiler.Transform, Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Constants, Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Math, diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr index 05437b190..8760c98bd 100644 --- a/src/Compiler/Eval.idr +++ b/src/Compiler/Eval.idr @@ -19,6 +19,7 @@ import Control.Monad.State import Data.List import Data.List.Elem import Data.SortedMap +import Debug.Trace import Decidable.Equality import Data.Hashable diff --git a/src/Compiler/Transform.idr b/src/Compiler/Transform.idr new file mode 100644 index 000000000..3bfbc116d --- /dev/null +++ b/src/Compiler/Transform.idr @@ -0,0 +1,196 @@ +{-- +Copyright 2022 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Compiler.Transform + +import Compiler.Expr +import Compiler.LiteralRW +import Literal +import Primitive +import Types + +partial +call : Fn arity Expr -> Vect arity Expr -> Expr +call (MkFn _ x@(FromLiteral _)) _ = x +call (MkFn params x@(Parameter _ _ _)) args = + let Just idx = findIndex (== x) params | Nothing => ?call_parameterNotFound + in index idx args +call (MkFn params (Tuple xs)) args = Tuple (map (\x => call (MkFn params x) args) xs) +call (MkFn params (GetTupleElement k x)) args = GetTupleElement k (call (MkFn params x) args) +call (MkFn _ x@MinFiniteValue) _ = x +call (MkFn _ x@MaxFiniteValue) _ = x +call (MkFn params (ConvertElementType {dtype} x)) args = + ConvertElementType {dtype} (call (MkFn params x) args) +call (MkFn params (Reshape from to x)) args = Reshape from to (call (MkFn params x) args) +call (MkFn params (Slice starts stops strides x)) args = + Slice starts stops strides (call (MkFn params x) args) +call (MkFn params (DynamicSlice starts sizes x)) args = + DynamicSlice (map (\x => (call (MkFn params x) args)) starts) sizes (call (MkFn params x) args) +call (MkFn params (Concat k x y)) args = + Concat k (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Diag x)) args = Diag (call (MkFn params x) args) +call (MkFn params (Triangle lower x)) args = Triangle lower (call (MkFn params x) args) +call (MkFn params (Transpose axes x)) args = Transpose axes (call (MkFn params x) args) +call (MkFn _ x@(Identity _)) _ = x +call (MkFn params (Broadcast {dtype} from to x)) args = + Broadcast {dtype} from to (call (MkFn params x) args) +call (MkFn params (Map x xs ys)) args = ?call_map +call (MkFn params (Reduce x y xs z)) args = ?call_reduce +call (MkFn params (Sort x k y xs)) args = ?call_sort +call (MkFn params (Reverse axes x)) args = Reverse axes (call (MkFn params x) args) +call (MkFn params (Eq x y)) args = Eq (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Ne x y)) args = Ne (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Add x y)) args = Add (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Sub x y)) args = Sub (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Mul x y)) args = Mul (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Div x y)) args = Div (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Pow x y)) args = Pow (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Lt x y)) args = Lt (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Gt x y)) args = Gt (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Le x y)) args = Le (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Ge x y)) args = Ge (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (And x y)) args = And (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Or x y)) args = Or (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Min x y)) args = Min (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Max x y)) args = Max (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Not x)) args = Not (call (MkFn params x) args) +call (MkFn params (Neg x)) args = Neg (call (MkFn params x) args) +call (MkFn params (Reciprocal x)) args = Reciprocal (call (MkFn params x) args) +call (MkFn params (Abs x)) args = Abs (call (MkFn params x) args) +call (MkFn params (Ceil x)) args = Ceil (call (MkFn params x) args) +call (MkFn params (Floor x)) args = Floor (call (MkFn params x) args) +call (MkFn params (Log x)) args = Log (call (MkFn params x) args) +call (MkFn params (Exp x)) args = Exp (call (MkFn params x) args) +call (MkFn params (Logistic x)) args = Logistic (call (MkFn params x) args) +call (MkFn params (Erf x)) args = Erf (call (MkFn params x) args) +call (MkFn params (Square x)) args = Square (call (MkFn params x) args) +call (MkFn params (Sqrt x)) args = Sqrt (call (MkFn params x) args) +call (MkFn params (Sin x)) args = Sin (call (MkFn params x) args) +call (MkFn params (Cos x)) args = Cos (call (MkFn params x) args) +call (MkFn params (Tan x)) args = Tan (call (MkFn params x) args) +call (MkFn params (Asin x)) args = Asin (call (MkFn params x) args) +call (MkFn params (Acos x)) args = Acos (call (MkFn params x) args) +call (MkFn params (Atan x)) args = Atan (call (MkFn params x) args) +call (MkFn params (Sinh x)) args = Sinh (call (MkFn params x) args) +call (MkFn params (Cosh x)) args = Cosh (call (MkFn params x) args) +call (MkFn params (Tanh x)) args = Tanh (call (MkFn params x) args) +call (MkFn params (Asinh x)) args = Asinh (call (MkFn params x) args) +call (MkFn params (Acosh x)) args = Acosh (call (MkFn params x) args) +call (MkFn params (Atanh x)) args = Atanh (call (MkFn params x) args) +call (MkFn params (Select pred t f)) args = + Select (call (MkFn params pred) args) (call (MkFn params t) args) (call (MkFn params f) args) +call (MkFn params (Cond pred ft t ff f)) args = ?call_cond +call (MkFn params (Dot x y)) args = Dot (call (MkFn params x) args) (call (MkFn params y) args) +call (MkFn params (Cholesky x)) args = Cholesky (call (MkFn params x) args) +call (MkFn params (TriangularSolve x y lower)) args = + TriangularSolve (call (MkFn params x) args) (call (MkFn params y) args) lower +call (MkFn params (UniformFloatingPoint key state minval maxval shape)) args = + UniformFloatingPoint (call (MkFn params key) args) + (call (MkFn params state) args) + (call (MkFn params minval) args) + (call (MkFn params maxval) args) shape +call (MkFn params (NormalFloatingPoint key state shape)) args = + NormalFloatingPoint (call (MkFn params key) args) (call (MkFn params state) args) shape + +-- what if there are parameters in the expression from the surrounding scope? like if we use `vmap` in `sort`? +-- +-- there's a traversal going on here that we can abstract. What is it? +export partial +vmap : Nat -> Fn arity Expr -> Vect arity Expr -> Expr +vmap n (MkFn _ res@(FromLiteral {shape} {dtype} _)) _ = Broadcast {dtype} shape (n :: shape) res +vmap n (MkFn params res@(Parameter _ _ _)) args = + let Just idx = findIndex (== res) params | Nothing => ?vmap_parameterNotFound + in index idx args +vmap n (MkFn params (Tuple xs)) x = ?vmap_tuple +vmap n (MkFn params (GetTupleElement k y)) x = ?vmap_getTupleElement +vmap n (MkFn _ (MinFiniteValue {dtype})) _ = Broadcast {dtype} [] [n] (MinFiniteValue {dtype}) +vmap n (MkFn _ (MaxFiniteValue {dtype})) _ = Broadcast {dtype} [] [n] (MinFiniteValue {dtype}) +vmap n (MkFn params (ConvertElementType {dtype} y)) x = + ConvertElementType {dtype} (vmap n (MkFn params y) x) +vmap n (MkFn params (Reshape from to y)) x = + Reshape (n :: from) (n :: to) (vmap n (MkFn params y) x) +vmap n (MkFn params (Slice starts stops strides y)) x = + Slice (0 :: starts) (n :: stops) (1 :: strides) (vmap n (MkFn params y) x) +vmap n (MkFn params (DynamicSlice starts sizes y)) x = + -- takes scalar arguments + let starts = (FromLiteral {dtype=U64} (Scalar Z) :: starts) + in DynamicSlice starts (n :: sizes) (vmap n (MkFn params y) x) +vmap n (MkFn params (Concat axis y z)) x = + Concat (S axis) (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Diag y)) x = Diag (vmap n (MkFn params y) x) +vmap n (MkFn params (Triangle lower y)) x = Triangle lower (vmap n (MkFn params y) x) +vmap n (MkFn params (Transpose axes y)) x = Transpose (0 :: [| S axes |]) (vmap n (MkFn params y) x) +vmap n (MkFn params (Identity {dtype} k)) _ = + Broadcast {dtype} [k, k] [n, k, k] (Identity {dtype} k) +vmap n (MkFn params (Broadcast {dtype} from to y)) x = + Broadcast {dtype} (n :: from) (n :: to) (vmap n (MkFn params y) x) +vmap n (MkFn params (Map f operands dimensions)) x = ?vmap_map +vmap n (MkFn params (Reduce f neutral axes y)) x = + -- takes scalar arguments + Reduce f neutral [| S axes |] (vmap n (MkFn params y) x) +vmap n (MkFn params (Sort f dimension isStable ys)) x = + -- takes scalar arguments + Sort f (S dimension) isStable (map (\op => vmap n (MkFn params op) x) ys) +vmap n (MkFn params (Reverse axes y)) x = Reverse [| S axes |] (vmap n (MkFn params y) x) +vmap n (MkFn params (Eq y z)) x = Eq (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Ne y z)) x = Ne (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Add y z)) x = Add (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Sub y z)) x = Sub (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Mul y z)) x = Mul (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Div y z)) x = Div (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Pow y z)) x = Pow (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Lt y z)) x = Lt (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Gt y z)) x = Gt (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Le y z)) x = Le (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Ge y z)) x = Ge (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (And y z)) x = And (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Or y z)) x = Or (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Min y z)) x = Min (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Max y z)) x = Max (vmap n (MkFn params y) x) (vmap n (MkFn params z) x) +vmap n (MkFn params (Not y)) x = Not (vmap n (MkFn params y) x) +vmap n (MkFn params (Neg y)) x = Neg (vmap n (MkFn params y) x) +vmap n (MkFn params (Reciprocal y)) x = Reciprocal (vmap n (MkFn params y) x) +vmap n (MkFn params (Abs y)) x = Abs (vmap n (MkFn params y) x) +vmap n (MkFn params (Ceil y)) x = Ceil (vmap n (MkFn params y) x) +vmap n (MkFn params (Floor y)) x = Floor (vmap n (MkFn params y) x) +vmap n (MkFn params (Log y)) x = Log (vmap n (MkFn params y) x) +vmap n (MkFn params (Exp y)) x = Exp (vmap n (MkFn params y) x) +vmap n (MkFn params (Logistic y)) x = Logistic (vmap n (MkFn params y) x) +vmap n (MkFn params (Erf y)) x = Erf (vmap n (MkFn params y) x) +vmap n (MkFn params (Square y)) x = Square (vmap n (MkFn params y) x) +vmap n (MkFn params (Sqrt y)) x = Sqrt (vmap n (MkFn params y) x) +vmap n (MkFn params (Sin y)) x = Sin (vmap n (MkFn params y) x) +vmap n (MkFn params (Cos y)) x = Cos (vmap n (MkFn params y) x) +vmap n (MkFn params (Tan y)) x = Tan (vmap n (MkFn params y) x) +vmap n (MkFn params (Asin y)) x = Asin (vmap n (MkFn params y) x) +vmap n (MkFn params (Acos y)) x = Acos (vmap n (MkFn params y) x) +vmap n (MkFn params (Atan y)) x = Atan (vmap n (MkFn params y) x) +vmap n (MkFn params (Sinh y)) x = Sinh (vmap n (MkFn params y) x) +vmap n (MkFn params (Cosh y)) x = Cosh (vmap n (MkFn params y) x) +vmap n (MkFn params (Tanh y)) x = Tanh (vmap n (MkFn params y) x) +vmap n (MkFn params (Asinh y)) x = Asinh (vmap n (MkFn params y) x) +vmap n (MkFn params (Acosh y)) x = Acosh (vmap n (MkFn params y) x) +vmap n (MkFn params (Atanh y)) x = Atanh (vmap n (MkFn params y) x) +vmap n (MkFn params (Select pred t f)) x = + Select (vmap n (MkFn params pred) x) (vmap n (MkFn params t) x) (vmap n (MkFn params f) x) +vmap n (MkFn params (Cond pred ft t ff f)) x = + let condition = (vmap n (MkFn params $ Broadcast {dtype=PRED} [] ?vmap_cond_pred_shape pred) x) + in Select condition (vmap n (MkFn params $ call ft [t]) x) (vmap n (MkFn params $ call ff [f]) x) +vmap n (MkFn params (Dot y z)) x = ?vmap_dot +vmap n (MkFn params (Cholesky y)) x = Cholesky (vmap n (MkFn params y) x) +vmap n (MkFn params (TriangularSolve a b lower)) x = + TriangularSolve (vmap n (MkFn params a) x) (vmap n (MkFn params b) x) lower +vmap n (MkFn params (UniformFloatingPoint y z w v xs)) x = ?vmap_uniformFloatingPoint +vmap n (MkFn params (NormalFloatingPoint y z xs)) x = ?vmap_normalFloatingPoint diff --git a/src/Tensor.idr b/src/Tensor.idr index 0ec6f88a3..8e6078a59 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -28,6 +28,7 @@ import Data.Hashable import Compiler.Eval import Compiler.Expr import Compiler.LiteralRW +import Compiler.Transform import Literal import public Primitive import public Types @@ -580,6 +581,58 @@ fill = broadcast {shapesOK=scalarToAnyOk shape} . fromLiteral . Scalar ----------------------------- generic operations ---------------------------- +||| Apply a function between `Tensor`s to the trailing dimensions of a `Tensor`. For example, for +||| ``` +||| x : Tensor [2, 3, 3] S32 +||| x = const [[[ 0, 1, 2], +||| [ 3, 4, 5], +||| [ 6, 7, 8]], +||| [[ 9, 10, 11], +||| [12, 13, 14], +||| [15, 16, 17]]] +||| ``` +||| `vmap diag x` is equivalent to `const [[0, 4, 8], [9, 13, 17]]`. +export partial +vmap : + Primitive a => + (Tensor from a -> Tensor to b) -> + Tensor (n :: from) a -> Tensor (n :: to) b +vmap f (MkTensor {shape=n :: from} expr) = + let param = Parameter 0 from {dtype=a} "" + MkTensor fres = f (MkTensor param) + in MkTensor (vmap n (MkFn [param] fres) [expr]) + +namespace Binary + ||| `vmap` for mapping over binary functions. + export partial + vmap : + (Primitive d0, Primitive d1) => + (Tensor s0 d0 -> Tensor s1 d1 -> Tensor s2 d2) -> + Tensor (n :: s0) d0 -> Tensor (n :: s1) d1 -> Tensor (n :: s2) d2 + vmap f (MkTensor {shape=n :: s0} expr0) (MkTensor {shape=n :: s1} expr1) = + let p0 = Parameter 0 s0 {dtype=d0} "" + p1 = Parameter 1 s1 {dtype=d1} "" + MkTensor fres = f (MkTensor p0) (MkTensor p1) + in MkTensor (vmap n (MkFn [p0, p1] fres) [expr0, expr1]) + +namespace Ternary + ||| `vmap` for mapping over ternary functions. + export partial + vmap : + (Primitive d0, Primitive d1, Primitive d2) => + (Tensor s0 d0 -> Tensor s1 d1 -> Tensor s2 d2 -> Tensor s3 d3) -> + Tensor (n :: s0) d0 -> Tensor (n :: s1) d1 -> Tensor (n :: s2) d2 -> Tensor (n :: s3) d3 + vmap + f + (MkTensor {shape=n :: s0} expr0) + (MkTensor {shape=n :: s1} expr1) + (MkTensor {shape=n :: s2} expr2) = + let p0 = Parameter 0 s0 {dtype=d0} "" + p1 = Parameter 1 s1 {dtype=d1} "" + p2 = Parameter 2 s2 {dtype=d2} "" + MkTensor fres = f (MkTensor p0) (MkTensor p1) (MkTensor p2) + in MkTensor (vmap n (MkFn [p0, p1, p2] fres) [expr0, expr1, expr2]) + ||| Lift a unary function on scalars to an element-wise function on `Tensor`s of arbitrary shape. ||| For example, ||| ```idris diff --git a/test/Main.idr b/test/Main.idr index 3417af665..1384cba80 100644 --- a/test/Main.idr +++ b/test/Main.idr @@ -28,7 +28,7 @@ import Unit.TestTensor import Unit.TestLiteral import Unit.TestUtil -covering +partial main : IO () main = test [ Utils.TestComparison.group diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr index b558838f2..6f2105821 100644 --- a/test/Unit/TestTensor.idr +++ b/test/Unit/TestTensor.idr @@ -528,6 +528,43 @@ transpose = fixedProperty $ do slice [all, at 1, at 0] (transpose [0, 2, 1, 3] x) ===# slice [all, at 0, at 1] x slice [at 2, at 4, at 0, at 1] (transpose [2, 3, 1, 0] x) ===# slice [at 1, at 0, at 2, at 4] x +partial +vmap : Property +vmap = fixedProperty $ do + let xs = fromLiteral {dtype=S32} [[[0, 1], [2, 3]], [[4, 5], [6, 3]]] + y = fromLiteral {dtype=S32} [[4, -2], [5, 1]] + vmap (\x => x - y) xs ===# fromLiteral [[[-4, 3], [-3, 2]], [[0, 7], [1, 2]]] + vmap (y -) xs ===# fromLiteral [[[4, -3], [3, -2]], [[0, -7], [-1, -2]]] + vmap (+ y) xs ===# fromLiteral [[[4, -1], [7, 4]], [[8, 3], [11, 4]]] + vmap (y +) xs ===# fromLiteral [[[4, -1], [7, 4]], [[8, 3], [11, 4]]] + vmap (const y) xs ===# broadcast y + + vmap (\x => concat 0 y x) xs ===# fromLiteral [ + [[4, -2], [5, 1], [0, 1], [2, 3]], [[4, -2], [5, 1], [4, 5], [6, 3]] + ] + vmap (\x => concat 1 x y) xs ===# fromLiteral [ + [[0, 1, 4, -2], [2, 3, 5, 1]], [[4, 5, 4, -2], [6, 3, 5, 1]] + ] + + vmap (\x => reduce @{Sum} [0] x) xs ===# fromLiteral [[2, 4], [10, 8]] + + let preds = fromLiteral [True, False] + vmap (\x => cond x id 1 id 0) preds ===# fromLiteral [1, 0] + vmap (\x => cond (fromLiteral True) id x id (fill {shape=[2, 2]} 0)) xs ===# xs + vmap (\x => cond (fromLiteral True) (const x) (fill {shape=[]} {dtype=U32} 1) id (fill 0)) xs + ===# xs + + vmap diag xs ===# fromLiteral [[0, 3], [4, 3]] + -- [[2, 3], [0, 1]] + [[0, 3], [4, -2]] + -- [[6, 3], [4, 5]] + [[4, 3]], [4, -2]] + vmap (\x => reverse [0] x + concat 0 (expand 0 (diag x)) (slice [0.to 1] y)) xs ===# + fromLiteral [[[2, 6], [4, -1]], [[10, 6], [8, 3]]] + + let a = fromLiteral [[[1.0, 0.0], [-3.0, 2.2]], [[-2.0, 0.0], [-2.5, 1.5]]] + x = fromLiteral [[1.1, -1.2], [2.0, 2.2]] + b = fromLiteral [[1.1, -5.94], [-4.0, -1.7]] + vmap (|\) a b ===# x + covering mapResult : Property mapResult = property $ do @@ -1315,7 +1352,13 @@ normalIsReproducible = withTests 20 . property $ do sample ===# sample' -export covering +xlaGraphs : Property +xlaGraphs = fixedProperty $ do + let x = fromLiteral {dtype=S32} [0, 1, 2] + y = map (\x => fromLiteral (toLiteral x)) x + y ===# x + +export partial group : Group group = MkGroup "Tensor" $ [ ("toLiteral . fromLiteral", fromLiteralThenToLiteral) @@ -1335,6 +1378,7 @@ group = MkGroup "Tensor" $ [ , ("squeeze", squeeze) , ("(.T)", (.T)) , ("transpose", transpose) + , ("vmap", vmap) , ("map", mapResult) , ("map with non-trivial function", mapNonTrivial) , ("map2", map2Result) @@ -1377,4 +1421,5 @@ group = MkGroup "Tensor" $ [ , ("normal", normal) , ("normal updates seed", normalSeedIsUpdated) , ("normal produces same samples for same seed", normalIsReproducible) + , ("XLA Graph edge cases", xlaGraphs) ]