From 63a512e66e9eec260b73c4fdde58505b392643af Mon Sep 17 00:00:00 2001
From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com>
Date: Sun, 21 Jan 2024 21:32:59 +0000
Subject: [PATCH] add iota for range-like tensors (#385)

Expose XLA's Iota operation end to end: a C shim (`Iota`) in the
backend, the `prim__iota` FFI binding and its `iota` wrapper, an `Iota`
constructor in `Expr` with its case in the evaluator, and the public
`Tensor.iota`, which produces values incrementing from zero along a
given axis.

Also add `broadcastableByLeading`, redefine `scalarToAnyOk` in terms of
it, drop the `idris` tag from doc-comment code fences, bump the backend
version to 0.0.9, and add tests plus a `Utils.Proof` test helper.
---
 backend/VERSION                               |  2 +-
 .../compiler/xla/client/xla_builder.cpp       |  7 ++
 .../compiler/xla/client/xla_builder.h         |  2 +
 src/Compiler/Eval.idr                         |  1 +
 src/Compiler/Expr.idr                         |  1 +
 .../Compiler/Xla/Client/XlaBuilder.idr        |  4 ++
 .../Compiler/Xla/Client/XlaBuilder.idr        |  7 ++
 src/Tensor.idr                                | 72 ++++++++++++++-----
 test.ipkg                                     |  1 +
 test/Unit/TestTensor.idr                      | 55 ++++++++++++++
 test/Utils/Proof.idr                          | 28 ++++++++
 11 files changed, 163 insertions(+), 17 deletions(-)
 create mode 100644 test/Utils/Proof.idr

diff --git a/backend/VERSION b/backend/VERSION
index d169b2f2d..c5d54ec32 100644
--- a/backend/VERSION
+++ b/backend/VERSION
@@ -1 +1 @@
-0.0.8
+0.0.9
diff --git a/backend/src/tensorflow/compiler/xla/client/xla_builder.cpp b/backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
index d2bf1d091..e4baf9d4b 100644
--- a/backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
+++ b/backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
@@ -350,6 +350,13 @@ extern "C" {
 
     XlaOp* Pow(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Pow, lhs, rhs); }
 
+    XlaOp* Iota(XlaBuilder* builder, Shape& shape, int iota_dimension) {
+        auto builder_ = reinterpret_cast<xla::XlaBuilder*>(builder);
+        auto& shape_ = reinterpret_cast<xla::Shape&>(shape);
+        xla::XlaOp res = xla::Iota(builder_, shape_, iota_dimension);
+        return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
+    }
+
     XlaOp* ConvertElementType(XlaOp& operand, int new_element_type) {
         auto& operand_ = reinterpret_cast<xla::XlaOp&>(operand);
         auto new_element_type_ = (xla::PrimitiveType) new_element_type;
diff --git a/backend/src/tensorflow/compiler/xla/client/xla_builder.h b/backend/src/tensorflow/compiler/xla/client/xla_builder.h
index 1fd3a0a90..bc1ac6a88 100644
--- a/backend/src/tensorflow/compiler/xla/client/xla_builder.h
+++ b/backend/src/tensorflow/compiler/xla/client/xla_builder.h
@@ -144,6 +144,8 @@ extern "C" {
 
     XlaOp* Pow(XlaOp& lhs, XlaOp& rhs);
 
+    XlaOp* Iota(XlaBuilder* builder, Shape& shape, int iota_dimension);
+
     XlaOp* ConvertElementType(XlaOp& operand, int new_element_type);
 
     XlaOp* Neg(XlaOp& operand);
diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr
index cd65dbb3d..8960d864f 100644
--- a/src/Compiler/Eval.idr
+++ b/src/Compiler/Eval.idr
@@ -113,6 +113,7 @@ interpret xlaBuilder (MkFn params root env) = do
   interpretE (MinFiniteValue {dtype}) = minFiniteValue {dtype} xlaBuilder
   interpretE (MaxFiniteValue {dtype}) = maxFiniteValue {dtype} xlaBuilder
   interpretE (ConvertElementType x) = convertElementType {dtype = F64} !(get x)
+  interpretE (Iota {dtype} shape dim) = iota xlaBuilder !(mkShape {dtype} shape) dim
   interpretE (Reshape from to x) = reshape !(get x) (range $ length from) to
   interpretE (Slice starts stops strides x) = slice !(get x) starts stops strides
   interpretE (DynamicSlice starts sizes x) =
diff --git a/src/Compiler/Expr.idr b/src/Compiler/Expr.idr
index d72d9448c..f4882beba 100644
--- a/src/Compiler/Expr.idr
+++ b/src/Compiler/Expr.idr
@@ -122,6 +122,7 @@ data Expr : Type where
   MaxValue : Primitive dtype => Expr
   MinFiniteValue : Primitive dtype => Expr
   MaxFiniteValue : Primitive dtype => Expr
+  Iota : Primitive dtype => Shape -> Nat -> Expr
   ConvertElementType : Primitive dtype => Nat -> Expr
   Reshape : Shape -> Shape -> Nat -> Expr
   Slice : List Nat -> List Nat -> List Nat -> Nat -> Expr
diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
index cb19972a6..297537b93 100644
--- a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
+++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
@@ -225,6 +225,10 @@ export
 %foreign (libxla "Pow")
 prim__pow : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr
 
+export
+%foreign (libxla "Iota")
+prim__iota : GCAnyPtr -> GCAnyPtr -> Int -> PrimIO AnyPtr
+
 export
 %foreign (libxla "ConvertElementType")
 prim__convertElementType : GCAnyPtr -> Int -> PrimIO AnyPtr
diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
index 3bd95894f..f73d73c0d 100644
--- a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
+++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
@@ -355,6 +355,13 @@ export
 pow : HasIO io => XlaOp -> XlaOp -> io XlaOp
 pow = binaryOp prim__pow
 
+export
+iota : HasIO io => XlaBuilder -> Xla.Shape -> Nat -> io XlaOp
+iota (MkXlaBuilder xlaBuilder) (MkShape shape) iota_dimension = do
+  opPtr <- primIO $ prim__iota xlaBuilder shape (cast iota_dimension)
+  opPtr <- onCollectAny opPtr XlaOp.delete
+  pure (MkXlaOp opPtr)
+
 export
 convertElementType : (HasIO io, Primitive dtype) => XlaOp -> io XlaOp
 convertElementType (MkXlaOp operand) = do
diff --git a/src/Tensor.idr b/src/Tensor.idr
index 942758578..ddbb017fb 100644
--- a/src/Tensor.idr
+++ b/src/Tensor.idr
@@ -193,7 +193,7 @@ namespace Squeezable
 
 ||| Remove dimensions of length one from a `Tensor` such that it has the desired shape. For example:
 |||
-||| ```idris
+||| ```
 ||| x : Graph $ Tensor [2, 1, 3, 1] S32
 ||| x = tensor [[[[4], [5], [6]]],
 |||             [[[7], [8], [9]]]]
@@ -202,7 +202,7 @@ namespace Squeezable
 ||| y = squeeze !x
 ||| ```
 ||| is
-||| ```idris
+||| ```
 ||| y : Graph $ Tensor [2, 1, 3] S32
 ||| y = tensor [[[4, 5, 6]],
 |||             [[7, 8, 9]]]
@@ -586,16 +586,30 @@ namespace Broadcastable
     ||| [3] to [5, 3]
     Nest : Broadcastable f t -> Broadcastable f (_ :: t)
 
+||| A shape can be extended with any number of leading dimensions.
+|||
+||| @leading The leading dimensions.
+export
+broadcastableByLeading : (leading : List Nat) -> Broadcastable shape (leading ++ shape)
+broadcastableByLeading [] = Same
+broadcastableByLeading (l :: ls) = Nest (broadcastableByLeading ls)
+
+||| A scalar can be broadcast to any shape.
+%hint
+export
+scalarToAnyOk : (to : Shape) -> Broadcastable [] to
+scalarToAnyOk to = rewrite sym $ appendNilRightNeutral to in broadcastableByLeading to
+
 ||| Broadcast a `Tensor` to a new compatible shape. For example,
 |||
-||| ```idris
+||| ```
 ||| x : Graph $ Tensor [2, 3] S32
 ||| x = broadcast !(tensor [4, 5, 6])
 ||| ```
 |||
 ||| is
 |||
-||| ```idris
+||| ```
 ||| x : Graph $ Tensor [2, 3] S32
 ||| x = tensor [[4, 5, 6], [4, 5, 6]]
 ||| ```
@@ -608,20 +622,14 @@ broadcast :
   Graph $ Tensor to dtype
 broadcast $ MkTensor {shape = _} x = addTensor $ Broadcast {dtype} from to x
 
-%hint
-export
-scalarToAnyOk : (to : Shape) -> Broadcastable [] to
-scalarToAnyOk [] = Same
-scalarToAnyOk (_ :: xs) = Nest (scalarToAnyOk xs)
-
 ||| A `Tensor` where every element has the specified value. For example,
 |||
-||| ```idris
+||| ```
 ||| fives : Graph $ Tensor [2, 3] S32
 ||| fives = fill 5
 ||| ```
 ||| is
-||| ```idris
+||| ```
 ||| fives : Graph $ Tensor [2, 3] S32
 ||| fives = tensor [[5, 5, 5],
 |||                 [5, 5, 5]]
@@ -630,11 +638,43 @@ export
 fill : PrimitiveRW dtype ty => {shape : _} -> ty -> Graph $ Tensor shape dtype
 fill x = broadcast {shapesOK=scalarToAnyOk shape} !(tensor (Scalar x))
 
+||| A constant where values increment from zero along the specified `axis`. For example,
+||| ```
+||| x : Graph $ Tensor [3, 5] S32
+||| x = iota 1
+||| ```
+||| is the same as
+||| ```
+||| x : Graph $ Tensor [3, 5] S32
+||| x = tensor [[0, 1, 2, 3, 4],
+|||             [0, 1, 2, 3, 4],
+|||             [0, 1, 2, 3, 4]]
+||| ```
+||| and
+||| ```
+||| x : Graph $ Tensor [3, 5] S32
+||| x = iota 0
+||| ```
+||| is the same as
+||| ```
+||| x : Graph $ Tensor [3, 5] S32
+||| x = tensor [[0, 0, 0, 0, 0],
+|||             [1, 1, 1, 1, 1],
+|||             [2, 2, 2, 2, 2]]
+||| ```
+export
+iota : Primitive.Num dtype =>
+       {shape : _} ->
+       (axis : Nat) ->
+       {auto 0 inBounds : InBounds axis shape} ->
+       Graph $ Tensor shape dtype
+iota dimension = addTensor $ Iota shape {dtype} dimension
+
 ----------------------------- generic operations ----------------------------
 
 ||| Lift a unary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
 ||| For example,
-||| ```idris
+||| ```
 ||| recip : Tensor [] F64 -> Graph $ Tensor [] F64
 ||| recip x = 1.0 / pure x
 ||| ```
@@ -652,7 +692,7 @@ map f $ MkTensor {shape = _} x = do
 
 ||| Lift a binary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
 ||| For example,
-||| ```idris
+||| ```
 ||| addRecip : Tensor [] F64 -> Tensor [] F64 -> Graph $ Tensor [] F64
 ||| addRecip x y = pure x + 1.0 / pure y
 ||| ```
@@ -947,7 +987,7 @@ namespace Matrix
   ||| Matrix multiplication with a matrix or vector. Contraction is along the last axis of the first
   ||| and the first axis of the last. For example:
   |||
-  ||| ```idris
+  ||| ```
   ||| x : Graph $ Tensor [2, 3] S32
   ||| x = tensor [[-1, -2, -3],
   |||             [ 0,  1,  2]]
@@ -961,7 +1001,7 @@ namespace Matrix
   |||
   ||| is
   |||
-  ||| ```idris
+  ||| ```
   ||| z : Graph $ Tensor [2, 1] S32
   ||| z = tensor [-19, 10]
   ||| ```
diff --git a/test.ipkg b/test.ipkg
index 9d203ec1f..7542183b3 100644
--- a/test.ipkg
+++ b/test.ipkg
@@ -26,6 +26,7 @@ modules =
 
   Utils.Cases,
   Utils.Comparison,
+  Utils.Proof,
   Utils.TestComparison,
 
   Main,
diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr
index c59e9a9a7..509fc8eea 100644
--- a/test/Unit/TestTensor.idr
+++ b/test/Unit/TestTensor.idr
@@ -31,6 +31,7 @@ import Tensor
 import Utils
 import Utils.Comparison
 import Utils.Cases
+import Utils.Proof
 
 partial
 tensorThenEval : Property
@@ -114,6 +115,58 @@ boundedNonFinite = fixedProperty $ do
   unsafeEval {dtype=F64} (Types.min @{NonFinite}) === -inf
   unsafeEval {dtype=F64} (Types.max @{NonFinite}) === inf
 
+partial
+iota : Property
+iota = property $ do
+  init <- forAll shapes
+  mid <- forAll dims
+  tail <- forAll shapes
+
+  let broadcastTail : Primitive dtype =>
+                      {n : _} ->
+                      (tail : Shape) ->
+                      Tensor [n] dtype ->
+                      Graph $ Tensor (n :: tail) dtype
+      broadcastTail [] x = pure x
+      broadcastTail (d :: ds) x = do
+        x <- broadcastTail ds x
+        broadcast !(expand 1 x)
+
+  let rangeFull = do
+        rangeV <- tensor {dtype = U64} $ cast (Vect.range mid)
+        rangeVTail <- broadcastTail tail rangeV
+        broadcast {shapesOK = broadcastableByLeading init} rangeVTail
+      inBounds = appendNonEmptyLengthInBounds init mid tail
+      actual : Graph (Tensor (init ++ mid :: tail) U64) = iota {inBounds} (length init)
+
+  actual ===# rangeFull
+
+  let actual : Graph (Tensor (init ++ mid :: tail) F64) = iota {inBounds} (length init)
+
+  actual ===# (do castDtype !rangeFull)
+
+partial
+iotaExamples : Property
+iotaExamples = fixedProperty $ do
+  iota 0 ===# tensor {dtype = S32} [0, 1, 2, 3]
+  iota 1 ===# tensor {dtype = S32} [[0], [0], [0], [0]]
+
+  iota 1 ===# tensor {dtype = S32} [[0, 1, 2, 3, 4],
+                                    [0, 1, 2, 3, 4],
+                                    [0, 1, 2, 3, 4]]
+
+  iota 0 ===# tensor {dtype = S32} [[0, 0, 0, 0, 0],
+                                    [1, 1, 1, 1, 1],
+                                    [2, 2, 2, 2, 2]]
+
+  iota 1 ===# tensor {dtype = F64} [[0.0, 1.0, 2.0, 3.0, 4.0],
+                                    [0.0, 1.0, 2.0, 3.0, 4.0],
+                                    [0.0, 1.0, 2.0, 3.0, 4.0]]
+
+  iota 0 ===# tensor {dtype = F64} [[0.0, 0.0, 0.0, 0.0, 0.0],
+                                    [1.0, 1.0, 1.0, 1.0, 1.0],
+                                    [2.0, 2.0, 2.0, 2.0, 2.0]]
+
 partial
 show : Property
 show = fixedProperty $ do
@@ -351,6 +404,8 @@ group = MkGroup "Tensor" $ [
       ("eval . tensor", tensorThenEval)
     , ("can read/write finite numeric bounds to/from XLA", canConvertAtXlaNumericBounds)
     , ("bounded non-finite", boundedNonFinite)
+    , ("iota", iota)
+    , ("iota examples", iotaExamples)
     , ("show", show)
     , ("cast", cast)
     , ("identity", identity)
diff --git a/test/Utils/Proof.idr b/test/Utils/Proof.idr
new file mode 100644
index 000000000..4681ba701
--- /dev/null
+++ b/test/Utils/Proof.idr
@@ -0,0 +1,28 @@
+{--
+Copyright 2024 Joel Berkeley
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--}
+module Utils.Proof
+
+import Util
+
+import Data.List
+
+export
+appendNonEmptyLengthInBounds : (xs : List a) ->
+                               (y : a) ->
+                               (ys : List a) ->
+                               InBounds (length xs) (xs ++ y :: ys)
+appendNonEmptyLengthInBounds [] _ _ = InFirst
+appendNonEmptyLengthInBounds (x :: xs) y ys = InLater $ appendNonEmptyLengthInBounds xs y ys