From 704eb031fd6877f2155c0b9ee4233ecf2e24a539 Mon Sep 17 00:00:00 2001
From: splis
Date: Sat, 3 Feb 2024 22:23:13 -0500
Subject: [PATCH] output layer based on 1 convolutions

---
 js/brainchop/mainMeshNetFunctions.js | 94 +++++++++++++++++++++++++++-
 1 file changed, 91 insertions(+), 3 deletions(-)

diff --git a/js/brainchop/mainMeshNetFunctions.js b/js/brainchop/mainMeshNetFunctions.js
index ac37951..eee8e93 100644
--- a/js/brainchop/mainMeshNetFunctions.js
+++ b/js/brainchop/mainMeshNetFunctions.js
@@ -4010,8 +4010,48 @@ accumulateArrBufSizes = (bufferSizesArr) => {
  * @return {tf.Tensor}
  *
  */
 
+function processTensorInChunks(input, filter, biases, sliceSize) {
+    const inChannels = input.shape[4];
+    const numSlices = Math.ceil(inChannels / sliceSize);
+
+    let outputTensor = null;
+
+    for (let i = 0; i < numSlices; i++) {
+        const startChannel = i * sliceSize;
+        const endChannel = Math.min((i + 1) * sliceSize, inChannels);
+
+        // Only proceed if there are channels to process
+        if (startChannel < inChannels) {
+            const inputSlice = input.slice([0, 0, 0, 0, startChannel], [-1, -1, -1, -1, endChannel - startChannel]);
+            const filterSlice = filter.slice([0, 0, 0, startChannel, 0], [-1, -1, -1, endChannel - startChannel, -1]);
+
+            // Perform the convolution for the current slice
+            const resultSlice = tf.conv3d(inputSlice, filterSlice, [1, 1, 1], 'valid', 'NDHWC', [1, 1, 1]);
+
+            if (outputTensor === null) {
+                outputTensor = resultSlice;
+            } else {
+                const updatedOutputTensor = outputTensor.add(resultSlice);
+                outputTensor.dispose();
+                outputTensor = updatedOutputTensor;
+            }
+
+            // Dispose of the intermediate tensors
+            inputSlice.dispose();
+            filterSlice.dispose();
+        }
+    }
+
+    // Add the biases to the accumulated convolutions
+    const biasedOutputTensor = outputTensor.add(biases.reshape([1, 1, 1, 1, -1]));
 
-function processTensorInChunks(inputTensor, vector, chunkSize) {
+    // Clean up the last reference to outputTensor
+    outputTensor.dispose();
+
+    return biasedOutputTensor;
+}
+
+function processTensorInChunks1(inputTensor, vector, chunkSize) {
     const rank = inputTensor.rank;
     const lastDimension = inputTensor.shape[rank - 1];
@@ -4177,12 +4217,13 @@ class SequentialConvLayer {
                 // -- e.g. filterWeights.shape [ 1, 1, 1, 5, 1 ]
                 const filterBiases = biases.slice([chIdx], [1]);
                 //-- e.g. filterBiases.shape [1] -> Tensor [-0.7850812]
-                const outA = processTensorInChunks(tf.squeeze(inputTensor), tf.squeeze(filterWeights), self.chunkSize).add(filterBiases);
+                //const outA = processTensorInChunks(tf.squeeze(inputTensor), tf.squeeze(filterWeights), self.chunkSize).add(filterBiases);
+                const outA = tf.squeeze(processTensorInChunks(inputTensor, filterWeights, filterBiases, self.chunkSize));
                 const greater = tf.greater(outA, outB);
                 const newoutB = tf.where(greater, outA, outB);
                 const newoutC = tf.where(greater, tf.fill(outC.shape, chIdx), outC);
                 // Dispose the old tensors before reassigning
-                tf.dispose([outB, outC]);
+                tf.dispose([outB, outC, filterWeights, filterBiases]);
                 return [newoutC, newoutB];
             });
 
@@ -4228,6 +4269,53 @@ class SequentialConvLayer {
  *
  */
 
+function convByInputSlicing_oneconv(input, filter, biases, stride, pad, dilationRate, sliceSize) {
+    const batchSize = input.shape[0];
+    const depth = input.shape[1];
+    const height = input.shape[2];
+    const width = input.shape[3];
+    const inChannels = input.shape[4];
+    const outChannels = filter.shape[4];
+
+    // Create an empty array to hold the output channels
+    let outputChannels = null;
+
+    const numSlices = Math.ceil(inChannels / sliceSize);
+    const biasesSlice = biases;
+    let outputChannel = null;
+
+    for (let i = 0; i < numSlices; i++) {
+        const startChannel = i * sliceSize;
+        const endChannel = Math.min((i + 1) * sliceSize, inChannels);
+
+        // Only proceed if there are channels to process
+        if (startChannel < inChannels) {
+            const resultSlice = tf.tidy(() => {
+                const inputSlice = input.slice([0, 0, 0, 0, startChannel], [-1, -1, -1, -1, endChannel - startChannel]);
+                const filterSlice = filter.slice([0, 0, 0, startChannel, 0], [-1, -1, -1, endChannel - startChannel, 1]);
+                // Perform the convolution for the current slice and output channel
+                return tf.conv3d(inputSlice, filterSlice, stride, pad, 'NDHWC', dilationRate);
+            });
+
+            if (outputChannel === null) {
+                outputChannel = resultSlice;
+            } else {
+                const updatedOutputChannel = outputChannel.add(resultSlice);
+                outputChannel.dispose();
+                resultSlice.dispose();
+                outputChannel = updatedOutputChannel;
+            }
+        }
+    }
+
+    // Add the biases to the accumulated convolutions for this channel
+    const biasedOutputChannel = outputChannel.add(biasesSlice);
+    outputChannel.dispose();
+    biasesSlice.dispose();
+
+    return biasedOutputChannel;
+}
+
 function convByOutputChannelAndInputSlicing(input, filter, biases, stride, pad, dilationRate, sliceSize) {
     const batchSize = input.shape[0];
     const depth = input.shape[1];
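For reference (not part of the patch): the chunked accumulation in processTensorInChunks() works because a convolution is linear in its input channels, so convolving each channel group separately and summing the partial results reproduces the full convolution while keeping the largest live intermediate tensor small. The standalone sketch below illustrates that equivalence with tf.conv3d; it assumes a Node.js environment with @tensorflow/tfjs installed, and chunkedConv3d plus the tensor shapes are illustrative choices, not identifiers from the patch.

// Minimal standalone sketch (assumes Node.js with @tensorflow/tfjs installed).
// Checks that summing per-channel-group conv3d results matches the full conv3d,
// which is the property the chunked accumulation above depends on.
const tf = require('@tensorflow/tfjs');

// Illustrative helper, not a function from the patch.
function chunkedConv3d(input, filter, sliceSize) {
    const inChannels = input.shape[4];
    let out = null;
    for (let start = 0; start < inChannels; start += sliceSize) {
        const size = Math.min(sliceSize, inChannels - start);
        // Partial convolution over one group of input channels
        const partial = tf.tidy(() => {
            const inputSlice = input.slice([0, 0, 0, 0, start], [-1, -1, -1, -1, size]);
            const filterSlice = filter.slice([0, 0, 0, start, 0], [-1, -1, -1, size, -1]);
            return tf.conv3d(inputSlice, filterSlice, 1, 'valid');
        });
        if (out === null) {
            out = partial;
        } else {
            const next = out.add(partial);
            out.dispose();
            partial.dispose();
            out = next;
        }
    }
    return out;
}

const input = tf.randomNormal([1, 8, 8, 8, 6]);    // NDHWC volume with 6 input channels
const filter = tf.randomNormal([1, 1, 1, 6, 3]);   // 1x1x1 kernel, 6 -> 3 channels

const full = tf.conv3d(input, filter, 1, 'valid');
const chunked = chunkedConv3d(input, filter, 2);

// Should print a value near zero (floating-point error only)
full.sub(chunked).abs().max().print();

The same accumulation pattern, followed by the bias add and explicit dispose() calls, is what the patched SequentialConvLayer path applies once per output channel.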