From ed1b9aeece5cc1a6577c47ff411e6c322b2a86b9 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 8 Mar 2021 22:09:09 -0800 Subject: [PATCH] Add image.transform. (#4770) FEATURE --- tfjs-backend-cpu/src/kernels/Transform.ts | 251 ++++++++++++++++++ tfjs-backend-cpu/src/register_all_kernels.ts | 2 + tfjs-backend-webgl/src/kernels/Transform.ts | 48 ++++ .../src/register_all_kernels.ts | 2 + tfjs-backend-webgl/src/transform_gpu.ts | 165 ++++++++++++ tfjs-core/src/kernel_names.ts | 9 + tfjs-core/src/ops/image/transform.ts | 93 +++++++ tfjs-core/src/ops/image/transform_test.ts | 107 ++++++++ tfjs-core/src/ops/ops.ts | 4 +- tfjs-core/src/tests.ts | 1 + tfjs-node/src/run_tests.ts | 4 +- 11 files changed, 684 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-cpu/src/kernels/Transform.ts create mode 100644 tfjs-backend-webgl/src/kernels/Transform.ts create mode 100644 tfjs-backend-webgl/src/transform_gpu.ts create mode 100644 tfjs-core/src/ops/image/transform.ts create mode 100644 tfjs-core/src/ops/image/transform_test.ts diff --git a/tfjs-backend-cpu/src/kernels/Transform.ts b/tfjs-backend-cpu/src/kernels/Transform.ts new file mode 100644 index 00000000000..78e1ce0da43 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Transform.ts @@ -0,0 +1,251 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * CPU kernel for `image.transform`.
 *
 * For every output pixel the 8-element projective transform row is applied to
 * the output coordinate to obtain an input coordinate, the coordinate is
 * remapped according to `fillMode`, and the input image is sampled with the
 * requested interpolation. Output pixels whose projection denominator is
 * exactly zero keep `fillValue`.
 *
 * @param args.inputs `image` of shape `[batch, height, width, channels]` and
 *     `transforms` holding either one row of 8 values (shared by all batch
 *     entries) or one row per batch entry.
 * @param args.attrs `interpolation`, `fillMode`, `fillValue` and the optional
 *     `outputShape` (`[height, width]`); when `outputShape` is null the
 *     output has the same height and width as the input.
 * @param args.backend The CPU backend that owns the tensor data.
 * @returns A TensorInfo of shape `[batch, outHeight, outWidth, channels]`
 *     with the dtype of `image`.
 * @throws Error if `interpolation` is neither 'nearest' nor 'bilinear'.
 */
export function transform(args: {
  inputs: TransformInputs,
  attrs: TransformAttrs,
  backend: MathBackendCPU
}): TensorInfo {
  const {inputs, attrs, backend} = args;
  const {image, transforms} = inputs;
  const {interpolation, fillMode, fillValue, outputShape} = attrs;

  const [batch, imageHeight, imageWidth, numChannels] = image.shape;
  const [outHeight, outWidth] =
      outputShape != null ? outputShape : [imageHeight, imageWidth];
  const outShape = [batch, outHeight, outWidth, numChannels];

  // The input is read with the input strides and the output is written with
  // the output strides; they differ whenever `outputShape` is not the input
  // height/width.
  const inStrides = util.computeStrides(image.shape);
  const inBatchStride = inStrides[0];
  const inRowStride = inStrides[1];
  const inColStride = inStrides[2];

  const outStrides = util.computeStrides(outShape);
  const outBatchStride = outStrides[0];
  const outRowStride = outStrides[1];
  const outColStride = outStrides[2];

  const outVals = util.getTypedArrayFromDType(
      image.dtype as NumericDataType, util.sizeFromShape(outShape));

  // Pixels that are never written below (zero projection) keep the fill value.
  outVals.fill(fillValue);

  const imageVals = backend.data.get(image.dataId).values as TypedArray;
  const transformVals =
      backend.data.get(transforms.dataId).values as TypedArray;

  // Ref TF implementation:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/image/image_ops.h
  for (let b = 0; b < batch; ++b) {
    // A single transform row is shared by every image in the batch.
    const transformRow = transforms.shape[0] === 1 ?
        transformVals :
        transformVals.subarray(b * 8, b * 8 + 8);

    for (let outY = 0; outY < outHeight; ++outY) {
      for (let outX = 0; outX < outWidth; ++outX) {
        const projection =
            transformRow[6] * outX + transformRow[7] * outY + 1;

        if (projection === 0) {
          // Return the fill value for infinite coordinates,
          // which are outside the input image
          continue;
        }

        // The mapped input coordinate does not depend on the channel, so it
        // is computed once per output pixel.
        const inX =
            (transformRow[0] * outX + transformRow[1] * outY +
             transformRow[2]) /
            projection;
        const inY =
            (transformRow[3] * outX + transformRow[4] * outY +
             transformRow[5]) /
            projection;

        const x = mapCoord(inX, imageWidth, fillMode);
        const y = mapCoord(inY, imageHeight, fillMode);

        const outOffset =
            b * outBatchStride + outY * outRowStride + outX * outColStride;

        for (let channel = 0; channel < numChannels; ++channel) {
          let val: number;

          switch (interpolation) {
            case 'nearest':
              val = nearestInterpolation(
                  imageVals, imageHeight, imageWidth, inBatchStride,
                  inRowStride, inColStride, b, y, x, channel, fillValue);
              break;
            case 'bilinear':
              val = bilinearInterpolation(
                  imageVals, imageHeight, imageWidth, inBatchStride,
                  inRowStride, inColStride, b, y, x, channel, fillValue);
              break;
            default:
              throw new Error(
                  `Error in Transform: Expect 'nearest' or ` +
                  `'bilinear', but got ${interpolation}`);
          }

          outVals[outOffset + channel] = val;
        }
      }
    }
  }

  // Single exit point, reached only after every batch entry was processed.
  const dataId = backend.write(outVals, outShape, image.dtype);
  return {dataId, shape: outShape, dtype: image.dtype};
}

/** Registration entry binding the `Transform` kernel to the CPU backend. */
export const transformConfig: KernelConfig = {
  kernelName: Transform,
  backendName: 'cpu',
  kernelFunc: transform as {} as KernelFunc
};
/**
 * Remaps an input coordinate that may fall outside `[0, len - 1]` using the
 * given fill mode. Any mode other than 'reflect', 'wrap' or 'nearest' is
 * treated as 'constant'.
 */
function mapCoord(
    outCoord: number, len: number,
    mode: 'constant'|'reflect'|'wrap'|'nearest') {
  if (mode === 'reflect') {
    return mapCoordReflect(outCoord, len);
  }
  if (mode === 'wrap') {
    return mapCoordWrap(outCoord, len);
  }
  if (mode === 'nearest') {
    return mapCoordNearest(outCoord, len);
  }
  return mapCoordConstant(outCoord, len);
}

/** Reflect mode: extends [abcd] to [dcba|abcd|dcba]. */
function mapCoordReflect(outCoord: number, len: number): number {
  const last = len - 1;
  let mapped = outCoord;

  if (len <= 1) {
    // A dimension of at most one element has a single valid coordinate.
    if (outCoord < 0 || outCoord > last) {
      mapped = 0;
    }
  } else if (outCoord < 0) {
    const period = 2 * len;
    if (mapped < period) {
      mapped = period * Math.trunc(-mapped / period) + mapped;
    }
    if (mapped < -len) {
      mapped = mapped + period;
    } else {
      mapped = -mapped - 1;
    }
  } else if (outCoord > last) {
    const period = 2 * len;
    mapped -= period * Math.trunc(mapped / period);
    if (mapped >= len) {
      mapped = period - mapped - 1;
    }
  }

  // The clamp keeps fractional results (e.g. 3.5 for len = 4) from being
  // rounded to an out-of-range index by nearest interpolation.
  return util.clamp(0, mapped, last);
}

/** Wrap mode: extends [abcd] to [abcd|abcd|abcd]. */
function mapCoordWrap(outCoord: number, len: number): number {
  const last = len - 1;
  let mapped = outCoord;

  if (len <= 1) {
    if (outCoord < 0 || outCoord > last) {
      mapped = 0;
    }
  } else if (outCoord < 0) {
    mapped += len * (Math.trunc(-mapped / last) + 1);
  } else if (outCoord > last) {
    mapped -= len * Math.trunc(mapped / last);
  }

  // The clamp keeps fractional results (e.g. -0.5 -> 3.5 for len = 4) from
  // being rounded to an out-of-range index by nearest interpolation.
  return util.clamp(0, mapped, last);
}

/** Constant mode: the coordinate is passed through; `len` is not used. */
function mapCoordConstant(outCoord: number, len: number): number {
  return outCoord;
}

/** Nearest mode: the coordinate is clamped to `[0, len - 1]`. */
function mapCoordNearest(outCoord: number, len: number): number {
  return util.clamp(0, outCoord, len - 1);
}

/**
 * Reads one value at (batch, y, x, channel) from the flat image buffer, or
 * returns `fillValue` when (y, x) is outside the image bounds.
 */
function readWithFillValue(
    imageVals: TypedArray, imageHeight: number, imageWidth: number,
    batchStride: number, rowStride: number, colStride: number, batch: number,
    y: number, x: number, channel: number, fillValue: number): number {
  const insideRows = 0 <= y && y < imageHeight;
  const insideCols = 0 <= x && x < imageWidth;
  if (!(insideRows && insideCols)) {
    return fillValue;
  }
  return imageVals
      [batch * batchStride + y * rowStride + x * colStride + channel];
}

/** Samples the image at the coordinate rounded with `Math.round`. */
function nearestInterpolation(
    imageVals: TypedArray, imageHeight: number, imageWidth: number,
    batchStride: number, rowStride: number, colStride: number, batch: number,
    y: number, x: number, channel: number, fillValue: number): number {
  return readWithFillValue(
      imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride,
      batch, Math.round(y), Math.round(x), channel, fillValue);
}

/**
 * Bilinearly blends the four pixels surrounding (y, x). Neighbours outside
 * the image contribute `fillValue`.
 */
function bilinearInterpolation(
    imageVals: TypedArray, imageHeight: number, imageWidth: number,
    batchStride: number, rowStride: number, colStride: number, batch: number,
    y: number, x: number, channel: number, fillValue: number) {
  const top = Math.floor(y);
  const left = Math.floor(x);
  const bottom = top + 1;
  const right = left + 1;

  const sample = (row: number, col: number) => readWithFillValue(
      imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride,
      batch, row, col, channel, fillValue);

  // Horizontal weights; the cell is one pixel wide so no normalisation.
  const leftWeight = right - x;
  const rightWeight = x - left;

  // Interpolate along x on the top row, then on the bottom row.
  const topValue =
      leftWeight * sample(top, left) + rightWeight * sample(top, right);
  const bottomValue =
      leftWeight * sample(bottom, left) + rightWeight * sample(bottom, right);

  // Interpolate along y between the two rows.
  return (bottom - y) * topValue + (y - top) * bottomValue;
}
/**
 * WebGL kernel for `image.transform`.
 *
 * Derives the output shape from the image shape and the optional
 * `outputShape` attribute, builds a `TransformProgram` for the requested
 * interpolation / fill settings and runs it on `[image, transforms]`,
 * producing a 'float32' result.
 */
export function transform(args: {
  inputs: TransformInputs,
  backend: MathBackendWebGL,
  attrs: TransformAttrs
}): TensorInfo {
  const {inputs, backend, attrs} = args;
  const {image, transforms} = inputs;
  const {interpolation, fillMode, fillValue, outputShape} = attrs;

  const [batch, imageHeight, imageWidth, numChannels] = image.shape;

  // Without an explicit output shape the output keeps the input height/width.
  let outHeight = imageHeight;
  let outWidth = imageWidth;
  if (outputShape != null) {
    [outHeight, outWidth] = outputShape;
  }

  const outShape: [number, number, number, number] =
      [batch, outHeight, outWidth, numChannels];

  return backend.runWebGLProgram(
      new TransformProgram(
          imageHeight, imageWidth, interpolation, fillMode, fillValue,
          outShape),
      [image, transforms], 'float32');
}

/** Registration entry binding the `Transform` kernel to the WebGL backend. */
export const transformConfig: KernelConfig = {
  kernelName: Transform,
  backendName: 'webgl',
  kernelFunc: transform as {} as KernelFunc
};
/**
 * GPGPU program that generates the fragment-shader source for
 * `image.transform`.
 *
 * The interpolation mode, fill mode, fill value and input image size are
 * baked into the shader text as literals:
 *   - interpolation id: 1 for 'nearest', 2 otherwise;
 *   - fill mode id: 1 'constant', 2 'reflect', 3 'wrap', 4 'nearest'; any
 *     other value falls back to 1.
 */
export class TransformProgram implements GPGPUProgram {
  variableNames = ['Image', 'Transforms'];
  outputShape: number[];
  userCode: string;

  constructor(
      imageHeight: number, imageWidth: number,
      interpolation: 'nearest'|'bilinear',
      fillMode: 'constant'|'reflect'|'wrap'|'nearest', fillValue: number,
      outShape: [number, number, number, number]) {
    this.outputShape = outShape;

    const interpolationModeId = interpolation === 'nearest' ? 1 : 2;

    // 'constant' and unrecognised modes both map to id 1.
    const fillModeId: number = fillMode === 'reflect' ?
        2 :
        fillMode === 'wrap' ? 3 : fillMode === 'nearest' ? 4 : 1;

    this.userCode = `
            float mapCoord(float outCoord, float len) {
              float inCoord = outCoord;
              if(${fillModeId} == 2) {
                if (inCoord < 0.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz2 = 2.0 * len;
                    if (inCoord < sz2) {
                      inCoord = sz2 * float(int(float(-inCoord / sz2))) +
                      inCoord;
                    }
                    inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
                  }
                } else if (inCoord > len - 1.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz2 = 2.0 * len;
                    inCoord -= sz2 * float(int(float(inCoord / sz2)));
                    if (inCoord >= len) {
                      inCoord = sz2 - inCoord - 1.0;
                    }
                  }
                }
                return clamp(inCoord, 0.0, len - 1.0);
              } else if (${fillModeId} == 3) {
                if (inCoord < 0.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz = len - 1.0;
                    inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);
                  }
                } else if (inCoord > len - 1.0) {
                  if (len <= 1.0) {
                    inCoord = 0.0;
                  } else {
                    float sz = len - 1.0;
                    inCoord -= len * float(int(float(inCoord / sz)));
                  }
                }
                return clamp(inCoord, 0.0, len - 1.0);
              } else if (${fillModeId} == 4) {
                return clamp(outCoord, 0.0, len - 1.0);
              } else {
                return outCoord;
              }
            }

            float readWithFillValue(int batch, int coordY, int coordX,
              int channel) {
              float outputValue;
              if (0 <= coordY && coordY < ${
        imageHeight} && 0 <= coordX && coordX < ${imageWidth}) {
                  outputValue = getImage(batch, coordY, coordX, channel);
              } else {
                outputValue = float(${fillValue});
              }
              return outputValue;
            }

            void main() {
              ivec4 coords = getOutputCoords();
              float outputValue;
              int batch = coords[0];
              int x = coords[2];
              int y = coords[1];
              int channel = coords[3];
              float xf = float(x);
              float yf = float(y);
              float a1 = getTransforms(batch, 0);
              float a2 = getTransforms(batch, 1);
              float a3 = getTransforms(batch, 2);
              float b1 = getTransforms(batch, 3);
              float b2 = getTransforms(batch, 4);
              float b3 = getTransforms(batch, 5);
              float c1 = getTransforms(batch, 6);
              float c2 = getTransforms(batch, 7);
              float projection = c1 * xf + c2 * yf + 1.0;
              if (projection == 0.0) {
                outputValue = float(${fillValue});
              } else {
                float inX = (a1 * xf + a2 * yf + a3) / projection;
                float inY = (b1 * xf + b2 * yf + b3) / projection;
                float mapX = mapCoord(inX, float(${imageWidth}));
                float mapY = mapCoord(inY, float(${imageHeight}));

                if (${interpolationModeId} == 1) {
                  int coordY = int(round(mapY));
                  int coordX = int(round(mapX));
                  outputValue = readWithFillValue(batch, coordY, coordX,
                    channel);
                } else {
                  float yFloor = floor(mapY);
                  float xFloor = floor(mapX);
                  float yCeil = yFloor + 1.0;
                  float xCeil = xFloor + 1.0;
                  float valueYFloor = (xCeil - mapX) *
                  readWithFillValue(batch, int(yFloor), int(xFloor), channel) +
                  (mapX - xFloor) *
                  readWithFillValue(batch, int(yFloor), int(xCeil), channel);
                  float valueYCeil = (xCeil - mapX) *
                  readWithFillValue(batch, int(yCeil), int(xFloor), channel) +
                  (mapX - xFloor) *
                  readWithFillValue(batch, int(yCeil), int(xCeil), channel);
                  outputValue = (yCeil - mapY) * valueYFloor +
                  (mapY - yFloor) * valueYCeil;
                }
              }
              setOutput(outputValue);
            }
        `;
  }
}
/**
 * Applies the given transform(s) to the image(s).
 *
 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
 * @param transforms Projective transform matrix/matrices. A tensor of size
 *     1 x 8 or batch x 8. If one row of transforms is
 *     [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the output point (x, y)
 *     to a transformed input point
 *     (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
 *     where k = c0 x + c1 y + 1. The transforms are inverted compared to the
 *     transform mapping input points to output points.
 * @param interpolation Interpolation mode.
 *     Supported values: 'nearest', 'bilinear'. Default to 'nearest'.
 * @param fillMode Points outside the boundaries of the input are filled
 *     according to the given mode, one of 'constant', 'reflect', 'wrap',
 *     'nearest'. Default to 'constant'.
 *     'reflect': (d c b a | a b c d | d c b a ) The input is extended by
 *     reflecting about the edge of the last pixel.
 *     'constant': (k k k k | a b c d | k k k k) The input is extended by
 *     filling all values beyond the edge with the same constant value k.
 *     'wrap': (a b c d | a b c d | a b c d) The input is extended by
 *     wrapping around to the opposite edge.
 *     'nearest': (a a a a | a b c d | d d d d) The input is extended by
 *     the nearest pixel.
 * @param fillValue A float represents the value to be filled outside the
 *     boundaries when fillMode is 'constant'. Default to 0.
 * @param outputShape Output dimension after the transform, [height, width].
 *     If undefined, output is the same size as input image.
 *
 * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
 */
function transform_(
    image: Tensor4D|TensorLike, transforms: Tensor2D|TensorLike,
    interpolation: 'nearest'|'bilinear' = 'nearest',
    fillMode: 'constant'|'reflect'|'wrap'|'nearest' = 'constant', fillValue = 0,
    outputShape?: [number, number]): Tensor4D {
  const $image = convertToTensor(image, 'image', 'transform', 'float32');
  const $transforms =
      convertToTensor(transforms, 'transforms', 'transform', 'float32');

  // Note the space after the comma: the two message halves are concatenated.
  util.assert(
      $image.rank === 4,
      () => 'Error in transform: image must be rank 4, ' +
          `but got rank ${$image.rank}.`);

  // Either one transform per image or a single transform shared by the batch.
  util.assert(
      $transforms.rank === 2 &&
          ($transforms.shape[0] === $image.shape[0] ||
           $transforms.shape[0] === 1) &&
          $transforms.shape[1] === 8,
      () => `Error in transform: Input transform should be batch x 8 or 1 x 8`);

  util.assert(
      outputShape == null || outputShape.length === 2,
      () =>
          'Error in transform: outputShape must be [height, width] or null, ' +
          `but got ${outputShape}.`);

  const inputs: TransformInputs = {image: $image, transforms: $transforms};
  const attrs:
      TransformAttrs = {interpolation, fillMode, fillValue, outputShape};

  return ENGINE.runKernel(
      Transform, inputs as {} as NamedTensorMap, attrs as {} as NamedAttrMap);
}

export const transform = op({transform_});
describeWithFlags('image.transform', ALL_ENVS, () => {
  // Fixtures shared by the fill-mode / interpolation cases. They are
  // factories so each tensor is created inside the running spec.
  const makeImage = () => tf.tensor4d(
      [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0], [1, 4, 4, 1]);
  const makeAffine = () => tf.tensor2d([0, 0.5, 1, -1, 2, 3, 0, 0], [1, 8]);

  it('extreme projective transform.', async () => {
    const images = tf.tensor4d(
        [1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1], [1, 4, 4, 1]);
    const transform = tf.tensor2d([1, 0, 0, 0, 1, 0, -1, 0], [1, 8]);

    const actual = await tf.image.transform(images, transform).toInt().data();

    expectArraysClose(
        [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], actual);
  });

  it('static output shape.', async () => {
    const images = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
    const transform = tf.randomUniform([1, 8], -1, 1);

    const result = tf.image.transform(
        images, transform as tf.Tensor2D, 'nearest', 'constant', 0, [3, 5]);

    // Only the shape is checked; the transform values are random.
    expectArraysClose(result.shape, [1, 3, 5, 1]);
  });

  it('fill=constant, interpolation=nearest.', async () => {
    const images = makeImage();
    const transform = makeAffine();

    const actual = await tf.image.transform(images, transform).data();

    expectArraysClose(
        [1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], actual);
  });

  it('fill=constant, interpolation=bilinear.', async () => {
    const images = makeImage();
    const transform = makeAffine();

    const actual =
        await tf.image.transform(images, transform, 'bilinear').data();

    expectArraysClose(
        [1, 0, 1, 1, 0, 0, 0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0], actual);
  });

  it('fill=reflect, interpolation=bilinear.', async () => {
    const images = makeImage();
    const transform = makeAffine();

    const actual =
        await tf.image.transform(images, transform, 'bilinear', 'reflect')
            .data();

    expectArraysClose(
        [1, 0, 1, 1, 0.5, 0.5, 0.5, 0.5, 1, 0, 1, 0, 0, 0.5, 0.5, 0], actual);
  });

  it('fill=wrap, interpolation=bilinear.', async () => {
    const images = makeImage();
    const transform = makeAffine();

    const actual =
        await tf.image.transform(images, transform, 'bilinear', 'wrap').data();

    expectArraysClose(
        [1, 0, 1, 1, 0.5, 1, 0.5, 0.5, 1, 1, 0, 1, 0.5, 0.5, 0.5, 0.5],
        actual);
  });

  it('fill=nearest, interpolation=bilinear.', async () => {
    const images = makeImage();
    const transform = makeAffine();

    const actual =
        await tf.image.transform(images, transform, 'bilinear', 'nearest')
            .data();

    expectArraysClose(
        [1, 0, 1, 1, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0], actual);
  });
});
a/tfjs-node/src/run_tests.ts +++ b/tfjs-node/src/run_tests.ts @@ -91,7 +91,9 @@ const IGNORE_LIST: string[] = [ // tslint:disable-next-line:max-line-length 'pool test-tensorflow {} avg x=[2,2,3] f=[1,1] s=2 p=1 fractional outputs default rounding', // not available in tf yet. - 'denseBincount' + 'denseBincount', + // only available in tf addon. + 'image.transform' ]; if (process.platform === 'win32') {