diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 69e6a04e7b1..85b14a3550a 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -474,16 +474,16 @@ export class WebGPUBackend extends KernelBackend { public runWebGPUProgram( program: webgpu_program.WebGPUProgram, inputs: TensorInfo[], - outputDtype: DataType, programUniforms?: number[]): TensorInfo { + outputDtype: DataType, + programUniforms?: Uint32Array|Int32Array|Float32Array): TensorInfo { const output = this.makeTensorInfo(program.outputShape, outputDtype); let uniformDataLength; let uniforms: GPUBindingResource; if (program.uniforms) { // TODO: handle padding of program-specific uniforms - const uniformData = new Int32Array(programUniforms); - uniformDataLength = uniformData.byteLength; - uniforms = this.makeUniforms(uniformData); + uniformDataLength = programUniforms.byteLength; + uniforms = this.makeUniforms(programUniforms); } const inputsData = inputs.map((input: TensorInfo, i: number) => { @@ -594,7 +594,8 @@ export class WebGPUBackend extends KernelBackend { return timeElapsedNanos / 1000000; } - private makeUniforms(data: Uint32Array|Int32Array): GPUBindingResource { + private makeUniforms(data: Uint32Array|Int32Array| + Float32Array): GPUBindingResource { const dimensionsBuffer = this.acquireBuffer( data.byteLength, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM); this.queue.writeBuffer(dimensionsBuffer, 0, data); diff --git a/tfjs-backend-webgpu/src/kernels/ArgMax.ts b/tfjs-backend-webgpu/src/kernels/ArgMax.ts index 91c35db94b2..daeb23f64f8 100644 --- a/tfjs-backend-webgpu/src/kernels/ArgMax.ts +++ b/tfjs-backend-webgpu/src/kernels/ArgMax.ts @@ -40,7 +40,8 @@ export function argMax( backend_util.assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length); const program = new ArgMinMaxProgram($x.shape, axes[0], 'max'); - const out = backend.runWebGPUProgram(program, [$x], 'int32', [axes[0]]); + const uniformData = new Int32Array([axes[0]]); + const out = backend.runWebGPUProgram(program, [$x], 'int32', uniformData); intermediateTensorInfos.forEach(t => backend.disposeData(t.dataId)); return out; } diff --git a/tfjs-backend-webgpu/src/kernels/ArgMin.ts b/tfjs-backend-webgpu/src/kernels/ArgMin.ts index d5f4eb05887..44810c26dd4 100644 --- a/tfjs-backend-webgpu/src/kernels/ArgMin.ts +++ b/tfjs-backend-webgpu/src/kernels/ArgMin.ts @@ -40,7 +40,8 @@ export function argMin( backend_util.assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length); const program = new ArgMinMaxProgram($x.shape, axes[0], 'min'); - const out = backend.runWebGPUProgram(program, [$x], 'int32', [axes[0]]); + const uniformData = new Int32Array([axes[0]]); + const out = backend.runWebGPUProgram(program, [$x], 'int32', uniformData); intermediateTensorInfos.forEach(t => backend.disposeData(t.dataId)); return out; } diff --git a/tfjs-backend-webgpu/src/kernels/AvgPool.ts b/tfjs-backend-webgpu/src/kernels/AvgPool.ts index 04d15e03bc5..ae6d63e536c 100644 --- a/tfjs-backend-webgpu/src/kernels/AvgPool.ts +++ b/tfjs-backend-webgpu/src/kernels/AvgPool.ts @@ -52,8 +52,8 @@ export function avgPool( convInfo.effectiveFilterWidth, convInfo.effectiveFilterHeight // Filter dims. ]; - - return backend.runWebGPUProgram(program, [x], x.dtype, dimensions); + const uniformData = new Int32Array(dimensions); + return backend.runWebGPUProgram(program, [x], x.dtype, uniformData); } export const avgPoolConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/Conv2D.ts b/tfjs-backend-webgpu/src/kernels/Conv2D.ts index 7a4e7b54969..0038fe01cc3 100644 --- a/tfjs-backend-webgpu/src/kernels/Conv2D.ts +++ b/tfjs-backend-webgpu/src/kernels/Conv2D.ts @@ -70,8 +70,9 @@ export function conv2d( convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, convInfo.dilationWidth ]; + const uniformData = new Int32Array(dimensions); - return backend.runWebGPUProgram(program, [x, filter], x.dtype, dimensions); + return backend.runWebGPUProgram(program, [x, filter], x.dtype, uniformData); } export const conv2DConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts b/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts index 6508645e062..ffd6942413e 100644 --- a/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts +++ b/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts @@ -46,7 +46,8 @@ export function depthwiseConv2dNative(args: { convInfo.dilationHeight, convInfo.dilationWidth, convInfo.inHeight, convInfo.inWidth ]; - return backend.runWebGPUProgram(program, [x, filter], x.dtype, dimensions); + const uniformData = new Int32Array(dimensions); + return backend.runWebGPUProgram(program, [x, filter], x.dtype, uniformData); } export const depthwiseConv2dNativeConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/Fill.ts b/tfjs-backend-webgpu/src/kernels/Fill.ts index 507f7f8c4cf..2f749b1fbb3 100644 --- a/tfjs-backend-webgpu/src/kernels/Fill.ts +++ b/tfjs-backend-webgpu/src/kernels/Fill.ts @@ -34,8 +34,9 @@ export function fill(args: {backend: WebGPUBackend, attrs: FillAttrs}): values.fill(value as string); return backend.makeTensorInfo(shape, dtype, values); } else { - const program = new FillProgram(shape, value as number); - return backend.runWebGPUProgram(program, [], dtype); + const program = new FillProgram(shape); + const uniformData = new Float32Array([value as number]); + return backend.runWebGPUProgram(program, [], dtype, uniformData); } } diff --git a/tfjs-backend-webgpu/src/kernels/FusedConv2D.ts b/tfjs-backend-webgpu/src/kernels/FusedConv2D.ts index 760baa7b0be..5af22ea60b9 100644 --- a/tfjs-backend-webgpu/src/kernels/FusedConv2D.ts +++ b/tfjs-backend-webgpu/src/kernels/FusedConv2D.ts @@ -96,7 +96,7 @@ export function fusedConv2d(args: { convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, convInfo.dilationWidth ]; - + const uniformData = new Int32Array(dimensions); const inputVar: TensorInfo[] = [x, filter]; if (hasBias) { inputVar.push(bias); @@ -104,7 +104,7 @@ export function fusedConv2d(args: { if (hasPreluActivationWeights) { inputVar.push(preluActivationWeights); } - return backend.runWebGPUProgram(program, inputVar, x.dtype, dimensions); + return backend.runWebGPUProgram(program, inputVar, x.dtype, uniformData); } export const fusedConv2DConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts b/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts index 9caa02e2a34..ac32b4e3bc3 100644 --- a/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts +++ b/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts @@ -66,8 +66,9 @@ export function fusedDepthwiseConv2D(args: { convInfo.dilationHeight, convInfo.dilationWidth, convInfo.inHeight, convInfo.inWidth ]; + const uniformData = new Int32Array(dimensions); const result = - backend.runWebGPUProgram(program, programInputs, 'float32', dimensions); + backend.runWebGPUProgram(program, programInputs, 'float32', uniformData); return result; } diff --git a/tfjs-backend-webgpu/src/kernels/MaxPool.ts b/tfjs-backend-webgpu/src/kernels/MaxPool.ts index 59b19b50f93..1ef0706f07d 100644 --- a/tfjs-backend-webgpu/src/kernels/MaxPool.ts +++ b/tfjs-backend-webgpu/src/kernels/MaxPool.ts @@ -50,8 +50,8 @@ export function maxPool( convInfo.effectiveFilterWidth, convInfo.effectiveFilterHeight // Filter dims. ]; - - return backend.runWebGPUProgram(program, [x], x.dtype, dimensions); + const uniformData = new Int32Array(dimensions); + return backend.runWebGPUProgram(program, [x], x.dtype, uniformData); } export const maxPoolConfig: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/PadV2.ts b/tfjs-backend-webgpu/src/kernels/PadV2.ts index 7ad80aae1ce..765c2345375 100644 --- a/tfjs-backend-webgpu/src/kernels/PadV2.ts +++ b/tfjs-backend-webgpu/src/kernels/PadV2.ts @@ -26,9 +26,9 @@ export const padV2 = const {inputs, backend, attrs} = args; const {x} = inputs; const {paddings, constantValue} = attrs; - - const program = new PadProgram(x.shape, paddings, constantValue); - return backend.runWebGPUProgram(program, [x], x.dtype); + const uniformData = new Float32Array([constantValue]); + const program = new PadProgram(x.shape, paddings); + return backend.runWebGPUProgram(program, [x], x.dtype, uniformData); }; export const padV2Config: KernelConfig = { diff --git a/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts b/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts index 04876ca4c94..668bb1448bf 100644 --- a/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts @@ -24,19 +24,18 @@ export class FillProgram implements WebGPUProgram { shaderKey: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; + uniforms = 'float value;'; workPerThread = 4; workGroupSize: [number, number, number] = [16, 1, 1]; - value: number; - constructor(shape: number[], value: number) { + constructor(shape: number[]) { this.outputShape = shape; this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize, [this.workPerThread, 1, 1]); - this.value = value; - this.shaderKey = `fill_${value}`; + this.shaderKey = 'fill'; } getUserCode(): string { @@ -47,7 +46,7 @@ export class FillProgram implements WebGPUProgram { for (int i = 0; i < ${this.workPerThread}; i++) { int flatIndex = index * ${this.workPerThread} + i; if (flatIndex < ${size}) { - setOutput(flatIndex,${this.value}); + setOutput(flatIndex, value); } } } diff --git a/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts b/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts index aca12bf98a1..a6e0883deda 100644 --- a/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts @@ -28,15 +28,13 @@ export class PadProgram implements WebGPUProgram { dispatchLayout: {x: number[]}; dispatch: [number, number, number]; variableNames = ['x']; + uniforms = 'float constantValue;'; workPerThread = 8; workGroupSize: [number, number, number] = [16, 1, 1]; xShape: number[]; paddings: Array<[number, number]>; - constantValue: number; - constructor( - xShape: number[], paddings: Array<[number, number]>, - constantValue: number) { + constructor(xShape: number[], paddings: Array<[number, number]>) { this.outputShape = paddings.map( (p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */); this.dispatchLayout = flatDispatchLayout(this.outputShape); @@ -46,8 +44,7 @@ export class PadProgram implements WebGPUProgram { this.xShape = xShape; this.paddings = paddings; - this.constantValue = constantValue; - this.shaderKey = `pad_${paddings}_${constantValue}`; + this.shaderKey = `pad_${paddings}`; } getUserCode(): string { @@ -82,7 +79,7 @@ export class PadProgram implements WebGPUProgram { ${type} outC = getCoordsFromFlatIndex(flatIndex); if (${leftPadCondition} || ${rightPadCondition}) { - setOutput(flatIndex, ${this.constantValue}); + setOutput(flatIndex, constantValue); } else { ${type} coords = outC - start; setOutput(flatIndex, getX(${unpackedCoords})); diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 78d48dd060a..3568f4bbfb0 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -411,6 +411,14 @@ const TEST_FILTERS: TestFilter[] = [ 'dilation2d' // 'dilation2d' not yet implemented. ] }, + { + include: 'fill', + excludes: [ + 'string', // String is not yet implemented. + '5D', // Rank 5 is not yet supported. + 'rotateWithOffset', // 'RotateWithOffset' not registered. + ] + }, { include: 'Reduction: max', excludes: [