From a2cfda73199cf6b90f2e47cb31e03f71a1f8a701 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Wed, 30 Jan 2019 21:46:14 +0100 Subject: [PATCH 1/6] tf.linalg.bandPart --- src/ops/linalg_ops.ts | 92 ++++++++++++++++++++++++++++- src/ops/linalg_ops_test.ts | 118 +++++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 1 deletion(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 7a15b94e16..5336fd6b16 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -22,13 +22,17 @@ import {ENV} from '../environment'; import {dispose} from '../globals'; import {Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; import {assert} from '../util'; import {eye, squeeze, stack, unstack} from './array_ops'; +import {sub} from './binary_ops'; import {split} from './concat_split'; +import {logicalAnd, where} from './logical_ops'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; -import {tensor2d} from './tensor_ops'; +import {range, scalar, tensor2d, zeros} from './tensor_ops'; /** * Gram-Schmidt orthogonalization. @@ -260,5 +264,91 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { }) as [Tensor2D, Tensor2D]; } +/** + * Copies a tensor of matrices, setting everything outside a central band + * in each matrix to zero. + * + * ```js + * >>> const a = tf.tensor2d([[11, 12, 13, 14], + * ... [21, 22, 23, 24], + * ... [31, 32, 33, 34], + * ... [41, 42, 43, 44]]); + * >>> tf.linalg.bandPart(a,0,2).print(); + * [[11, 12, 13, 0], + * [ 0, 22, 23, 24], + * [ 0, 0, 33, 34], + * [ 0, 0, 0, 44]] + * + * >>> tf.linalg.bandPart(a,1,-1).print(); + * [[11, 12, 13, 14], + * [21, 22, 23, 24], + * [ 0, 32, 33, 34], + * [ 0, 0, 43, 44]] + * ``` + * + * @param a Tensor of matrices from which the band part is extracted. + * @param numLower The number of subdiagonal lines to be copied. + * If set to `-1`, all entries below the diagonal are + * copied. + * @param numUpper The number of superdiagonal lines to be copied. + * If set to `-1`, all entries above the diagonal are + * copied. + */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function bandPart_( + a: T|TensorLike, numLower: number, numUpper: number +): T +{ + if( numLower%1 !== 0 ){ + throw new Error(`bandPart(): numLower=${numLower} not an integer.`); + } + if( numUpper%1 !== 0 ){ + throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`); + } + + return ENV.engine.tidy( () => { + const $a = convertToTensor(a,'a','bandPart'); + a = undefined; + + if( $a.rank < 2 ) { + throw new Error(`bandPart(): a.rank = ${$a.rank} < 2.`); + } + + const shape = $a.shape, + [M,N] = $a.shape.slice(-2); + + if( !(numLower <= M) ) { + throw new Error(`bandPart() check failed: numLower <= #rows.` ); + } + if( !(numUpper <= N) ) { + throw new Error(`bandPart() check failed: numUpper <= #columns.`); + } + + if( numLower < 0 ) { numLower = M; } + if( numUpper < 0 ) { numUpper = N; } + + const i = range(0,M, 1, 'int32').reshape([-1,1]), + j = range(0,N, 1, 'int32'); + + const inBand = logicalAnd( + sub(i,j).lessEqual( scalar(numLower,'int32') ), + sub(j,i).lessEqual( scalar(numUpper,'int32') ) + ); + + const zero = zeros([M,N], $a.dtype); + + return stack( + unstack( $a.reshape([-1,M,N]) ).map( + mat => where(inBand, mat, zero) + ) + ).reshape(shape) as T; + }); +} + export const gramSchmidt = op({gramSchmidt_}); +export const bandPart = op({bandPart_}); export const qr = op({qr_}); diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index bfbb5ef62b..a69af58cd3 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -241,3 +241,121 @@ describeWithFlags('qr', ALL_ENVS, () => { expect(() => tf.linalg.qr(x2)).toThrowError(/rank >= 2.*got rank 1/); }); }); + +describeWithFlags('bandPart', ALL_ENVS, () => { + const la = tf.linalg; + + // FIXME: shouldn't 1*x be lossless? + // It's even in the IEEE spec somewhere... + // Yet this fails on Travis with `expectArraysEqual`... + const expectArraysEqual = expectArraysClose; + + it('works for 3x4 example', () => { + const a = tf.tensor2d([ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12] + ]); + expectArraysEqual( + la.bandPart(a,0,0), + tf.tensor2d([[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,0,1), + tf.tensor2d([[1, 2, 0, 0], + [0, 6, 7, 0], + [0, 0,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,0,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + } + + expectArraysEqual( + la.bandPart(a,1,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [0,10,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,1,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [0,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,1,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + } + + for( const numLower of [2,3,-1,-2]) + { + expectArraysEqual( + la.bandPart(a,numLower,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [9,10,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [9,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,numLower,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + } + } + }); +}); From 6e57e5cafba1db8be9e26965ad2a975c64ee8c91 Mon Sep 17 00:00:00 2001 From: Dirk T Date: Thu, 31 Jan 2019 07:35:53 +0100 Subject: [PATCH 2/6] Remove workaround (previous GPU precision issues) --- src/ops/linalg_ops_test.ts | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index a69af58cd3..4980bae843 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -245,11 +245,6 @@ describeWithFlags('qr', ALL_ENVS, () => { describeWithFlags('bandPart', ALL_ENVS, () => { const la = tf.linalg; - // FIXME: shouldn't 1*x be lossless? - // It's even in the IEEE spec somewhere... - // Yet this fails on Travis with `expectArraysEqual`... - const expectArraysEqual = expectArraysClose; - it('works for 3x4 example', () => { const a = tf.tensor2d([ [1, 2, 3, 4], From 0264f4dfac03d436dec436285063b760929518af Mon Sep 17 00:00:00 2001 From: Dirk T Date: Thu, 31 Jan 2019 10:09:20 +0100 Subject: [PATCH 3/6] Update linalg_ops_test.ts --- src/ops/linalg_ops_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 4980bae843..75749c47ba 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {describeWithFlags} from '../jasmine_util'; import {Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util'; +import {ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; From 8a9620a60f24084d00e0837c8a2e444d9be98e3a Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Thu, 31 Jan 2019 19:35:52 +0100 Subject: [PATCH 4/6] Made bandPart test float16-aware --- src/ops/linalg_ops_test.ts | 39 +++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 75749c47ba..9ef2874a87 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -245,38 +245,47 @@ describeWithFlags('qr', ALL_ENVS, () => { describeWithFlags('bandPart', ALL_ENVS, () => { const la = tf.linalg; + const expectArrayEq = (() => { + switch( tf.ENV.backend.floatPrecision() ) + { + default: return expectArraysClose; + case 32: + case 64: return expectArraysEqual; + } + })(); + it('works for 3x4 example', () => { const a = tf.tensor2d([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12] ]); - expectArraysEqual( + expectArrayEq( la.bandPart(a,0,0), tf.tensor2d([[1, 0, 0, 0], [0, 6, 0, 0], [0, 0,11, 0]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,0,1), tf.tensor2d([[1, 2, 0, 0], [0, 6, 7, 0], [0, 0,11,12]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,0,2), tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,0,2), tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysEqual( + expectArrayEq( la.bandPart(a,0,numUpper), tf.tensor2d([[1, 2, 3, 4], [0, 6, 7, 8], @@ -284,32 +293,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => { ); } - expectArraysEqual( + expectArrayEq( la.bandPart(a,1,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], [0,10,11, 0]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,1,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], [0,10,11,12]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [0,10,11,12]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [0,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysEqual( + expectArrayEq( la.bandPart(a,1,numUpper), tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], @@ -319,32 +328,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => { for( const numLower of [2,3,-1,-2]) { - expectArraysEqual( + expectArrayEq( la.bandPart(a,numLower,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], [9,10,11, 0]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,numLower,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], [9,10,11,12]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) ); - expectArraysEqual( + expectArrayEq( la.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysEqual( + expectArrayEq( la.bandPart(a,numLower,numUpper), tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], From 9fcbfd7b92d06682cfdc4808e0803b7b24cf4963 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sun, 3 Feb 2019 10:45:14 +0100 Subject: [PATCH 5/6] Test precision now depending on tested ENVS. --- src/ops/linalg_ops_test.ts | 174 ++++++++++++++++++------------------- 1 file changed, 85 insertions(+), 89 deletions(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 9ef2874a87..9da088202f 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {describeWithFlags} from '../jasmine_util'; import {Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; +import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; @@ -242,124 +242,120 @@ describeWithFlags('qr', ALL_ENVS, () => { }); }); -describeWithFlags('bandPart', ALL_ENVS, () => { - const la = tf.linalg; +for( const ENV of [CPU_ENVS, WEBGL_ENVS] ) +{ + describeWithFlags('bandPart', ALL_ENVS, () => { + const la = tf.linalg; - const expectArrayEq = (() => { - switch( tf.ENV.backend.floatPrecision() ) - { - default: return expectArraysClose; - case 32: - case 64: return expectArraysEqual; - } - })(); + const expectArrayEq = Object.is(ENV, CPU_ENVS) + ? expectArraysEqual + : expectArraysClose; - it('works for 3x4 example', () => { - const a = tf.tensor2d([ - [1, 2, 3, 4], - [5, 6, 7, 8], - [9,10,11,12] - ]); - expectArrayEq( - la.bandPart(a,0,0), - tf.tensor2d([[1, 0, 0, 0], - [0, 6, 0, 0], - [0, 0,11, 0]]) - ); - expectArrayEq( - la.bandPart(a,0,1), - tf.tensor2d([[1, 2, 0, 0], - [0, 6, 7, 0], - [0, 0,11,12]]) - ); - expectArrayEq( - la.bandPart(a,0,2), - tf.tensor2d([[1, 2, 3, 0], - [0, 6, 7, 8], - [0, 0,11,12]]) - ); - expectArrayEq( - la.bandPart(a,0,2), - tf.tensor2d([[1, 2, 3, 0], - [0, 6, 7, 8], - [0, 0,11,12]]) - ); - for( const numUpper of [3,4,-1,-2] ) { + it('works for 3x4 example', () => { + const a = tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]); + expectArrayEq( + la.bandPart(a,0,0), + tf.tensor2d([[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0,11, 0]]) + ); + expectArrayEq( + la.bandPart(a,0,1), + tf.tensor2d([[1, 2, 0, 0], + [0, 6, 7, 0], + [0, 0,11,12]]) + ); expectArrayEq( - la.bandPart(a,0,numUpper), - tf.tensor2d([[1, 2, 3, 4], + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); - } - - expectArrayEq( - la.bandPart(a,1,0), - tf.tensor2d([[1, 0, 0, 0], - [5, 6, 0, 0], - [0,10,11, 0]]) - ); - expectArrayEq( - la.bandPart(a,1,1), - tf.tensor2d([[1, 2, 0, 0], - [5, 6, 7, 0], - [0,10,11,12]]) - ); - expectArrayEq( - la.bandPart(a,1,2), - tf.tensor2d([[1, 2, 3, 0], - [5, 6, 7, 8], - [0,10,11,12]]) - ); - expectArrayEq( - la.bandPart(a,1,2), - tf.tensor2d([[1, 2, 3, 0], - [5, 6, 7, 8], - [0,10,11,12]]) - ); - for( const numUpper of [3,4,-1,-2] ) { expectArrayEq( - la.bandPart(a,1,numUpper), - tf.tensor2d([[1, 2, 3, 4], - [5, 6, 7, 8], - [0,10,11,12]]) + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) ); - } + for( const numUpper of [3,4,-1,-2] ) { + expectArrayEq( + la.bandPart(a,0,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + } - for( const numLower of [2,3,-1,-2]) - { expectArrayEq( - la.bandPart(a,numLower,0), + la.bandPart(a,1,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], - [9,10,11, 0]]) + [0,10,11, 0]]) ); expectArrayEq( - la.bandPart(a,numLower,1), + la.bandPart(a,1,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], - [9,10,11,12]]) + [0,10,11,12]]) ); expectArrayEq( - la.bandPart(a,numLower,2), + la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], - [9,10,11,12]]) + [0,10,11,12]]) ); expectArrayEq( - la.bandPart(a,numLower,2), + la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], - [9,10,11,12]]) + [0,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { expectArrayEq( - la.bandPart(a,numLower,numUpper), + la.bandPart(a,1,numUpper), tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + } + + for( const numLower of [2,3,-1,-2]) + { + expectArrayEq( + la.bandPart(a,numLower,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [9,10,11, 0]]) + ); + expectArrayEq( + la.bandPart(a,numLower,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [9,10,11,12]]) + ); + expectArrayEq( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) ); + expectArrayEq( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArrayEq( + la.bandPart(a,numLower,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + } } - } + }); }); -}); +} From 8c798e604ebf208e79e53a7b995f93d3b1440bbe Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sun, 3 Feb 2019 11:12:10 +0100 Subject: [PATCH 6/6] Fixed ENV used with describeWithFlags. --- src/ops/linalg_ops_test.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 9da088202f..f6e731d7c0 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -244,12 +244,12 @@ describeWithFlags('qr', ALL_ENVS, () => { for( const ENV of [CPU_ENVS, WEBGL_ENVS] ) { - describeWithFlags('bandPart', ALL_ENVS, () => { - const la = tf.linalg; + const expectArrayEq = Object.is(ENV, CPU_ENVS) + ? expectArraysEqual + : expectArraysClose; - const expectArrayEq = Object.is(ENV, CPU_ENVS) - ? expectArraysEqual - : expectArraysClose; + describeWithFlags('bandPart', ENV, () => { + const la = tf.linalg; it('works for 3x4 example', () => { const a = tf.tensor2d([[1, 2, 3, 4],