diff --git a/tfjs-backend-webgl/src/kernels/RotateWithOffset.ts b/tfjs-backend-webgl/src/kernels/RotateWithOffset.ts index cc50e954900..5383f13d704 100644 --- a/tfjs-backend-webgl/src/kernels/RotateWithOffset.ts +++ b/tfjs-backend-webgl/src/kernels/RotateWithOffset.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {KernelConfig, Tensor4D} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, Tensor4D} from '@tensorflow/tfjs-core'; import {RotateWithOffset, RotateWithOffsetAttrs, RotateWithOffsetInputs} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; @@ -29,9 +29,13 @@ export const rotateWithOffsetConfig: KernelConfig = { const {radians, fillValue, center} = attrs as {} as RotateWithOffsetAttrs; const webglBackend = backend as MathBackendWebGL; - const program = new RotateProgram( - (image as Tensor4D).shape, radians, fillValue, center); - const output = webglBackend.runWebGLProgram(program, [image], image.dtype); + const program = new RotateProgram((image as Tensor4D).shape, fillValue); + const [centerX, centerY] = + backend_util.getImageCenter(center, image.shape[1], image.shape[2]); + const customSetup = program.getCustomSetupFunc( + centerX, centerY, Math.sin(radians), Math.cos(radians)); + const output = webglBackend.runWebGLProgram( + program, [image], image.dtype, customSetup); return output; } }; diff --git a/tfjs-backend-webgl/src/rotate_gpu.ts b/tfjs-backend-webgl/src/rotate_gpu.ts index 62357e0eacc..4a35ca8abec 100644 --- a/tfjs-backend-webgl/src/rotate_gpu.ts +++ b/tfjs-backend-webgl/src/rotate_gpu.ts @@ -15,30 +15,21 @@ * ============================================================================= */ -import {backend_util} from '@tensorflow/tfjs-core'; - +import {GPGPUContext} from './gpgpu_context'; import {GPGPUProgram} from './gpgpu_math'; export class RotateProgram implements GPGPUProgram { variableNames = ['Image']; outputShape: number[] = []; userCode: string; - + paramsLoc: WebGLUniformLocation; constructor( - imageShape: [number, number, number, number], radians: number, - fillValue: number|[number, number, number], - center: number|[number, number]) { + imageShape: [number, number, number, number], + fillValue: number|[number, number, number]) { const imageHeight = imageShape[1]; const imageWidth = imageShape[2]; - const sinFactor = Math.sin(radians).toFixed(3); - const cosFactor = Math.cos(radians).toFixed(3); this.outputShape = imageShape; - const [centerX, centerY] = - backend_util.getImageCenter(center, imageHeight, imageWidth); - const centerXString = centerX.toFixed(3); - const centerYString = centerY.toFixed(3); - let fillSnippet = ''; if (typeof fillValue === 'number') { fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`; @@ -49,16 +40,17 @@ export class RotateProgram implements GPGPUProgram { } this.userCode = ` + uniform vec4 params; void main() { ivec4 coords = getOutputCoords(); int x = coords[2]; int y = coords[1]; - float coordXFloat = (float(x) - ${centerXString}) * ${ - cosFactor} - (float(y) - ${centerYString}) * ${sinFactor}; - float coordYFloat = (float(x) - ${centerXString}) * ${ - sinFactor} + (float(y) - ${centerYString}) * ${cosFactor}; - int coordX = int(round(coordXFloat + ${centerXString})); - int coordY = int(round(coordYFloat + ${centerYString})); + float coordXFloat = (float(x) - params[0]) * params[3] - + (float(y) - params[1]) * params[2]; + float coordYFloat = (float(x) - params[0]) * params[2] + + (float(y) - params[1]) * params[3]; + int coordX = int(round(coordXFloat + params[0])); + int coordY = int(round(coordYFloat + params[1])); ${fillSnippet} if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${ imageHeight}) { @@ -68,4 +60,16 @@ export class RotateProgram implements GPGPUProgram { } `; } + + getCustomSetupFunc( + centerX: number, centerY: number, sinFactor: number, cosFactor: number) { + return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => { + if (this.paramsLoc == null) { + this.paramsLoc = + gpgpu.getUniformLocationNoThrow(webGLProgram, 'params'); + } + gpgpu.gl.uniform4f( + this.paramsLoc, centerX, centerY, sinFactor, cosFactor); + }; + } }