Skip to content

Commit

Permalink
add iota for range-like tensors (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley authored Jan 21, 2024
1 parent cc49d9b commit 63a512e
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 17 deletions.
2 changes: 1 addition & 1 deletion backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.8
0.0.9
7 changes: 7 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,13 @@ extern "C" {

XlaOp* Pow(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Pow, lhs, rhs); }

XlaOp* Iota(XlaBuilder* builder, Shape& shape, int iota_dimension) {
auto builder_ = reinterpret_cast<xla::XlaBuilder*>(builder);
auto& shape_ = reinterpret_cast<xla::Shape&>(shape);
xla::XlaOp res = xla::Iota(builder_, shape_, iota_dimension);
return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
}

XlaOp* ConvertElementType(XlaOp& operand, int new_element_type) {
auto& operand_ = reinterpret_cast<xla::XlaOp&>(operand);
auto new_element_type_ = (xla::PrimitiveType) new_element_type;
Expand Down
2 changes: 2 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ extern "C" {

XlaOp* Pow(XlaOp& lhs, XlaOp& rhs);

XlaOp* Iota(XlaBuilder* builder, Shape& shape, int iota_dimension);

XlaOp* ConvertElementType(XlaOp& operand, int new_element_type);

XlaOp* Neg(XlaOp& operand);
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ interpret xlaBuilder (MkFn params root env) = do
interpretE (MinFiniteValue {dtype}) = minFiniteValue {dtype} xlaBuilder
interpretE (MaxFiniteValue {dtype}) = maxFiniteValue {dtype} xlaBuilder
interpretE (ConvertElementType x) = convertElementType {dtype = F64} !(get x)
interpretE (Iota {dtype} shape dim) = iota xlaBuilder !(mkShape {dtype} shape) dim
interpretE (Reshape from to x) = reshape !(get x) (range $ length from) to
interpretE (Slice starts stops strides x) = slice !(get x) starts stops strides
interpretE (DynamicSlice starts sizes x) =
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ data Expr : Type where
MaxValue : Primitive dtype => Expr
MinFiniteValue : Primitive dtype => Expr
MaxFiniteValue : Primitive dtype => Expr
Iota : Primitive dtype => Shape -> Nat -> Expr
ConvertElementType : Primitive dtype => Nat -> Expr
Reshape : Shape -> Shape -> Nat -> Expr
Slice : List Nat -> List Nat -> List Nat -> Nat -> Expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ export
%foreign (libxla "Pow")
prim__pow : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr

export
%foreign (libxla "Iota")
prim__iota : GCAnyPtr -> GCAnyPtr -> Int -> PrimIO AnyPtr

export
%foreign (libxla "ConvertElementType")
prim__convertElementType : GCAnyPtr -> Int -> PrimIO AnyPtr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ export
pow : HasIO io => XlaOp -> XlaOp -> io XlaOp
pow = binaryOp prim__pow

export
iota : HasIO io => XlaBuilder -> Xla.Shape -> Nat -> io XlaOp
iota (MkXlaBuilder xlaBuilder) (MkShape shape) iota_dimension = do
opPtr <- primIO $ prim__iota xlaBuilder shape (cast iota_dimension)
opPtr <- onCollectAny opPtr XlaOp.delete
pure (MkXlaOp opPtr)

export
convertElementType : (HasIO io, Primitive dtype) => XlaOp -> io XlaOp
convertElementType (MkXlaOp operand) = do
Expand Down
72 changes: 56 additions & 16 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ namespace Squeezable

||| Remove dimensions of length one from a `Tensor` such that it has the desired shape. For example:
|||
||| ```idris
||| ```
||| x : Graph $ Tensor [2, 1, 3, 1] S32
||| x = tensor [[[[4], [5], [6]]],
||| [[[7], [8], [9]]]]
Expand All @@ -202,7 +202,7 @@ namespace Squeezable
||| y = squeeze !x
||| ```
||| is
||| ```idris
||| ```
||| y : Graph $ Tensor [2, 1, 3] S32
||| y = tensor [[[4, 5, 6]],
||| [[7, 8, 9]]]
Expand Down Expand Up @@ -586,16 +586,30 @@ namespace Broadcastable
||| [3] to [5, 3]
Nest : Broadcastable f t -> Broadcastable f (_ :: t)

||| A shape can be extended with any number of leading dimensions.
|||
||| @leading The leading dimensions.
export
broadcastableByLeading : (leading : List Nat) -> Broadcastable shape (leading ++ shape)
broadcastableByLeading [] = Same
broadcastableByLeading (l :: ls) = Nest (broadcastableByLeading ls)

||| A scalar can be broadcast to any shape.
%hint
export
scalarToAnyOk : (to : Shape) -> Broadcastable [] to
scalarToAnyOk to = rewrite sym $ appendNilRightNeutral to in broadcastableByLeading to

||| Broadcast a `Tensor` to a new compatible shape. For example,
|||
||| ```idris
||| ```
||| x : Graph $ Tensor [2, 3] S32
||| x = broadcast !(tensor [4, 5, 6])
||| ```
|||
||| is
|||
||| ```idris
||| ```
||| x : Graph $ Tensor [2, 3] S32
||| x = tensor [[4, 5, 6], [4, 5, 6]]
||| ```
Expand All @@ -608,20 +622,14 @@ broadcast :
Graph $ Tensor to dtype
broadcast $ MkTensor {shape = _} x = addTensor $ Broadcast {dtype} from to x

%hint
export
scalarToAnyOk : (to : Shape) -> Broadcastable [] to
scalarToAnyOk [] = Same
scalarToAnyOk (_ :: xs) = Nest (scalarToAnyOk xs)

||| A `Tensor` where every element has the specified value. For example,
|||
||| ```idris
||| ```
||| fives : Graph $ Tensor [2, 3] S32
||| fives = fill 5
||| ```
||| is
||| ```idris
||| ```
||| fives : Graph $ Tensor [2, 3] S32
||| fives = tensor [[5, 5, 5],
||| [5, 5, 5]]
Expand All @@ -630,11 +638,43 @@ export
fill : PrimitiveRW dtype ty => {shape : _} -> ty -> Graph $ Tensor shape dtype
fill x = broadcast {shapesOK=scalarToAnyOk shape} !(tensor (Scalar x))

||| A constant where values increment from zero along the specified `axis`. For example,
||| ```
||| x : Graph $ Tensor [3, 5] S32
||| x = iota 1
||| ```
||| is the same as
||| ```
||| x : Graph $ Tensor [3, 5] S32
||| x = tensor [[0, 1, 2, 3, 4],
||| [0, 1, 2, 3, 4],
||| [0, 1, 2, 3, 4]]
||| ```
||| and
||| ```
||| x : Graph $ Tensor [3, 5] S32
||| x = iota 0
||| ```
||| is the same as
||| ```
||| x : Graph $ Tensor [3, 5] S32
||| x = tensor [[0, 0, 0, 0, 0],
||| [1, 1, 1, 1, 1],
||| [2, 2, 2, 2, 2]]
||| ```
export
iota : Primitive.Num dtype =>
{shape : _} ->
(axis : Nat) ->
{auto 0 inBounds : InBounds axis shape} ->
Graph $ Tensor shape dtype
iota dimension = addTensor $ Iota shape {dtype} dimension

----------------------------- generic operations ----------------------------

||| Lift a unary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
||| For example,
||| ```idris
||| ```
||| recip : Tensor [] F64 -> Graph $ Tensor [] F64
||| recip x = 1.0 / pure x
||| ```
Expand All @@ -652,7 +692,7 @@ map f $ MkTensor {shape = _} x = do

||| Lift a binary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
||| For example,
||| ```idris
||| ```
||| addRecip : Tensor [] F64 -> Tensor [] F64 -> Graph $ Tensor [] F64
||| addRecip x y = pure x + 1.0 / pure y
||| ```
Expand Down Expand Up @@ -947,7 +987,7 @@ namespace Matrix
||| Matrix multiplication with a matrix or vector. Contraction is along the last axis of the first
||| and the first axis of the last. For example:
|||
||| ```idris
||| ```
||| x : Graph $ Tensor [2, 3] S32
||| x = tensor [[-1, -2, -3],
||| [ 0, 1, 2]]
Expand All @@ -961,7 +1001,7 @@ namespace Matrix
|||
||| is
|||
||| ```idris
||| ```
||| z : Graph $ Tensor [2, 1] S32
||| z = tensor [-19, 10]
||| ```
Expand Down
1 change: 1 addition & 0 deletions test.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ modules =

Utils.Cases,
Utils.Comparison,
Utils.Proof,
Utils.TestComparison,

Main,
Expand Down
55 changes: 55 additions & 0 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import Tensor
import Utils
import Utils.Comparison
import Utils.Cases
import Utils.Proof

partial
tensorThenEval : Property
Expand Down Expand Up @@ -114,6 +115,58 @@ boundedNonFinite = fixedProperty $ do
unsafeEval {dtype=F64} (Types.min @{NonFinite}) === -inf
unsafeEval {dtype=F64} (Types.max @{NonFinite}) === inf

partial
iota : Property
iota = property $ do
init <- forAll shapes
mid <- forAll dims
tail <- forAll shapes

let broadcastTail : Primitive dtype =>
{n : _} ->
(tail : Shape) ->
Tensor [n] dtype ->
Graph $ Tensor (n :: tail) dtype
broadcastTail [] x = pure x
broadcastTail (d :: ds) x = do
x <- broadcastTail ds x
broadcast !(expand 1 x)

let rangeFull = do
rangeV <- tensor {dtype = U64} $ cast (Vect.range mid)
rangeVTail <- broadcastTail tail rangeV
broadcast {shapesOK = broadcastableByLeading init} rangeVTail
inBounds = appendNonEmptyLengthInBounds init mid tail
actual : Graph (Tensor (init ++ mid :: tail) U64) = iota {inBounds} (length init)

actual ===# rangeFull

let actual : Graph (Tensor (init ++ mid :: tail) F64) = iota {inBounds} (length init)

actual ===# (do castDtype !rangeFull)

partial
iotaExamples : Property
iotaExamples = fixedProperty $ do
iota 0 ===# tensor {dtype = S32} [0, 1, 2, 3]
iota 1 ===# tensor {dtype = S32} [[0], [0], [0], [0]]

iota 1 ===# tensor {dtype = S32} [[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]

iota 0 ===# tensor {dtype = S32} [[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2]]

iota 1 ===# tensor {dtype = F64} [[0.0, 1.0, 2.0, 3.0, 4.0],
[0.0, 1.0, 2.0, 3.0, 4.0],
[0.0, 1.0, 2.0, 3.0, 4.0]]

iota 0 ===# tensor {dtype = F64} [[0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0]]

partial
show : Property
show = fixedProperty $ do
Expand Down Expand Up @@ -351,6 +404,8 @@ group = MkGroup "Tensor" $ [
("eval . tensor", tensorThenEval)
, ("can read/write finite numeric bounds to/from XLA", canConvertAtXlaNumericBounds)
, ("bounded non-finite", boundedNonFinite)
, ("iota", iota)
, ("iota examples", iotaExamples)
, ("show", show)
, ("cast", cast)
, ("identity", identity)
Expand Down
28 changes: 28 additions & 0 deletions test/Utils/Proof.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{--
Copyright 2024 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module Utils.Proof

import Util

import Data.List

export
appendNonEmptyLengthInBounds : (xs : List a) ->
(y : a) ->
(ys : List a) ->
InBounds (length xs) (xs ++ y :: ys)
appendNonEmptyLengthInBounds [] _ _ = InFirst
appendNonEmptyLengthInBounds (x :: xs) y ys = InLater $ appendNonEmptyLengthInBounds xs y ys

0 comments on commit 63a512e

Please sign in to comment.