Skip to content

Commit

Permalink
[webgl] Rotate performance fix (#4784)
Browse files Browse the repository at this point in the history
PERF
* fixed pad unpacked shader failure

* use uniform to improve rotate performance to avoid shader recompilation due to center and radian change

Co-authored-by: Na Li <[email protected]>
  • Loading branch information
pyu10055 and lina128 authored Mar 8, 2021
1 parent 42c7f25 commit 9effdc8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
12 changes: 8 additions & 4 deletions tfjs-backend-webgl/src/kernels/RotateWithOffset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {KernelConfig, Tensor4D} from '@tensorflow/tfjs-core';
import {backend_util, KernelConfig, Tensor4D} from '@tensorflow/tfjs-core';
import {RotateWithOffset, RotateWithOffsetAttrs, RotateWithOffsetInputs} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
Expand All @@ -29,9 +29,13 @@ export const rotateWithOffsetConfig: KernelConfig = {
const {radians, fillValue, center} = attrs as {} as RotateWithOffsetAttrs;
const webglBackend = backend as MathBackendWebGL;

const program = new RotateProgram(
(image as Tensor4D).shape, radians, fillValue, center);
const output = webglBackend.runWebGLProgram(program, [image], image.dtype);
const program = new RotateProgram((image as Tensor4D).shape, fillValue);
const [centerX, centerY] =
backend_util.getImageCenter(center, image.shape[1], image.shape[2]);
const customSetup = program.getCustomSetupFunc(
centerX, centerY, Math.sin(radians), Math.cos(radians));
const output = webglBackend.runWebGLProgram(
program, [image], image.dtype, customSetup);
return output;
}
};
42 changes: 23 additions & 19 deletions tfjs-backend-webgl/src/rotate_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,21 @@
* =============================================================================
*/

import {backend_util} from '@tensorflow/tfjs-core';

import {GPGPUContext} from './gpgpu_context';
import {GPGPUProgram} from './gpgpu_math';

export class RotateProgram implements GPGPUProgram {
variableNames = ['Image'];
outputShape: number[] = [];
userCode: string;

paramsLoc: WebGLUniformLocation;
constructor(
imageShape: [number, number, number, number], radians: number,
fillValue: number|[number, number, number],
center: number|[number, number]) {
imageShape: [number, number, number, number],
fillValue: number|[number, number, number]) {
const imageHeight = imageShape[1];
const imageWidth = imageShape[2];
const sinFactor = Math.sin(radians).toFixed(3);
const cosFactor = Math.cos(radians).toFixed(3);
this.outputShape = imageShape;

const [centerX, centerY] =
backend_util.getImageCenter(center, imageHeight, imageWidth);
const centerXString = centerX.toFixed(3);
const centerYString = centerY.toFixed(3);

let fillSnippet = '';
if (typeof fillValue === 'number') {
fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`;
Expand All @@ -49,16 +40,17 @@ export class RotateProgram implements GPGPUProgram {
}

this.userCode = `
uniform vec4 params;
void main() {
ivec4 coords = getOutputCoords();
int x = coords[2];
int y = coords[1];
float coordXFloat = (float(x) - ${centerXString}) * ${
cosFactor} - (float(y) - ${centerYString}) * ${sinFactor};
float coordYFloat = (float(x) - ${centerXString}) * ${
sinFactor} + (float(y) - ${centerYString}) * ${cosFactor};
int coordX = int(round(coordXFloat + ${centerXString}));
int coordY = int(round(coordYFloat + ${centerYString}));
float coordXFloat = (float(x) - params[0]) * params[3] -
(float(y) - params[1]) * params[2];
float coordYFloat = (float(x) - params[0]) * params[2] +
(float(y) - params[1]) * params[3];
int coordX = int(round(coordXFloat + params[0]));
int coordY = int(round(coordYFloat + params[1]));
${fillSnippet}
if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${
imageHeight}) {
Expand All @@ -68,4 +60,16 @@ export class RotateProgram implements GPGPUProgram {
}
`;
}

getCustomSetupFunc(
centerX: number, centerY: number, sinFactor: number, cosFactor: number) {
return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
if (this.paramsLoc == null) {
this.paramsLoc =
gpgpu.getUniformLocationNoThrow(webGLProgram, 'params');
}
gpgpu.gl.uniform4f(
this.paramsLoc, centerX, centerY, sinFactor, cosFactor);
};
}
}

0 comments on commit 9effdc8

Please sign in to comment.