
Commit

output layer based on 1 convolutions
splis committed Feb 4, 2024
1 parent 759b6cf commit 704eb03
Showing 1 changed file with 91 additions and 3 deletions.
94 changes: 91 additions & 3 deletions js/brainchop/mainMeshNetFunctions.js
@@ -4010,8 +4010,48 @@ accumulateArrBufSizes = (bufferSizesArr) => {
* @return {tf.Tensor}
*
*/
function processTensorInChunks(input, filter, biases, sliceSize) {
  const inChannels = input.shape[4];
  const numSlices = Math.ceil(inChannels / sliceSize);

  let outputTensor = null;

  for (let i = 0; i < numSlices; i++) {
    const startChannel = i * sliceSize;
    const endChannel = Math.min((i + 1) * sliceSize, inChannels);

    // Only proceed if there are channels to process
    if (startChannel < inChannels) {
      const inputSlice = input.slice([0, 0, 0, 0, startChannel], [-1, -1, -1, -1, endChannel - startChannel]);
      const filterSlice = filter.slice([0, 0, 0, startChannel, 0], [-1, -1, -1, endChannel - startChannel, -1]);

      // Perform the convolution for the current slice
      const resultSlice = tf.conv3d(inputSlice, filterSlice, [1, 1, 1], 'valid', 'NDHWC', [1, 1, 1]);

      if (outputTensor === null) {
        outputTensor = resultSlice;
      } else {
        const updatedOutputTensor = outputTensor.add(resultSlice);
        outputTensor.dispose();
        // Release the partial result once it has been accumulated
        resultSlice.dispose();
        outputTensor = updatedOutputTensor;
      }

      // Dispose of the intermediate tensors
      inputSlice.dispose();
      filterSlice.dispose();
    }
  }

  // Add the biases to the accumulated convolutions
  const biasedOutputTensor = outputTensor.add(biases.reshape([1, 1, 1, 1, -1]));

  // Clean up the last reference to outputTensor
  outputTensor.dispose();

  return biasedOutputTensor;
}
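
For orientation, here is a minimal usage sketch of the chunked convolution above. It is not part of the commit: the tensor names, shapes, and the slice size of 2 are assumptions chosen to compare the chunked result against a direct tf.conv3d call.

// Hypothetical sanity check (illustration only, not from the commit).
const input = tf.randomNormal([1, 8, 8, 8, 6]);    // NDHWC volume with 6 input channels
const filter = tf.randomNormal([1, 1, 1, 6, 4]);   // 1x1x1 kernel, 4 output channels
const biases = tf.randomNormal([4]);

const chunked = processTensorInChunks(input, filter, biases, 2);          // 2 input channels per slice
const direct = tf.conv3d(input, filter, [1, 1, 1], 'valid').add(biases);  // single full convolution
chunked.sub(direct).abs().max().print();                                  // expected to print a value near 0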

function processTensorInChunks1(inputTensor, vector, chunkSize) {
  const rank = inputTensor.rank;
  const lastDimension = inputTensor.shape[rank - 1];

@@ -4177,12 +4217,13 @@ class SequentialConvLayer {
// -- e.g. filterWeights.shape [ 1, 1, 1, 5, 1 ]
const filterBiases = biases.slice([chIdx], [1]);
//-- e.g. filterBiases.shape [1] -> Tensor [-0.7850812]
//const outA = processTensorInChunks(tf.squeeze(inputTensor), tf.squeeze(filterWeights), self.chunkSize).add(filterBiases);
const outA = tf.squeeze(processTensorInChunks(inputTensor, filterWeights, filterBiases, self.chunkSize));
const greater = tf.greater(outA, outB);
const newoutB = tf.where(greater, outA, outB);
const newoutC = tf.where(greater, tf.fill(outC.shape, chIdx), outC);
// Dispose the old tensors before reassigning
tf.dispose([outB, outC, filterWeights, filterBiases]);
return [newoutC, newoutB];
});
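
To make the bookkeeping in the tf.tidy block above easier to follow, here is a standalone sketch of the running max / argmax pattern it uses. The shapes, the channel count of 3, and the random stand-in for the per-channel convolution output are illustrative assumptions, not code from the commit.

// Illustrative sketch: track a per-voxel running maximum (outB) and the channel
// index that produced it (outC), updating both with tf.where as channels arrive.
let outB = tf.fill([8, 8, 8], -1e9);   // best score so far (large negative sentinel)
let outC = tf.zeros([8, 8, 8]);        // channel index of that best score
for (let chIdx = 0; chIdx < 3; chIdx++) {
  const outA = tf.randomNormal([8, 8, 8]);   // stands in for one channel's convolution output
  const greater = tf.greater(outA, outB);
  const newOutB = tf.where(greater, outA, outB);
  const newOutC = tf.where(greater, tf.fill(outC.shape, chIdx), outC);
  tf.dispose([outA, outB, outC, greater]);
  outB = newOutB;
  outC = newOutC;
}
// outC now holds the argmax channel per voxel; outB holds the maximum value.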

@@ -4228,6 +4269,53 @@ class SequentialConvLayer {
*
*/

function convByInputSlicing_oneconv(input, filter, biases, stride, pad, dilationRate, sliceSize) {
  const batchSize = input.shape[0];
  const depth = input.shape[1];
  const height = input.shape[2];
  const width = input.shape[3];
  const inChannels = input.shape[4];
  const outChannels = filter.shape[4];

  // Create an empty array to hold the output channels
  let outputChannels = null;

  const numSlices = Math.ceil(inChannels / sliceSize);
  const biasesSlice = biases;
  let outputChannel = null;

  for (let i = 0; i < numSlices; i++) {
    const startChannel = i * sliceSize;
    const endChannel = Math.min((i + 1) * sliceSize, inChannels);

    // Only proceed if there are channels to process
    if (startChannel < inChannels) {
      const resultSlice = tf.tidy(() => {
        const inputSlice = input.slice([0, 0, 0, 0, startChannel], [-1, -1, -1, -1, endChannel - startChannel]);
        const filterSlice = filter.slice([0, 0, 0, startChannel, 0], [-1, -1, -1, endChannel - startChannel, 1]);
        // Perform the convolution for the current slice and output channel
        return tf.conv3d(inputSlice, filterSlice, stride, pad, 'NDHWC', dilationRate);
      });

      if (outputChannel === null) {
        outputChannel = resultSlice;
      } else {
        const updatedOutputChannel = outputChannel.add(resultSlice);
        outputChannel.dispose();
        resultSlice.dispose();
        outputChannel = updatedOutputChannel;
      }
    }
  }

  // Add the biases to the accumulated convolutions for this channel
  const biasedOutputChannel = outputChannel.add(biasesSlice);
  outputChannel.dispose();
  biasesSlice.dispose();

  return biasedOutputChannel;
}
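
As a hedged sketch only: one way a caller might use the single-output-channel helper above to build a full output volume is to pre-slice the filter and biases per output channel and concatenate the results. The wrapper name convAllChannels, the loop, and the tf.concat step are assumptions for illustration and do not appear in this commit.

// Hypothetical assembly loop (illustration only).
function convAllChannels(input, filter, biases, sliceSize) {
  const outChannels = filter.shape[4];
  const perChannel = [];
  for (let ch = 0; ch < outChannels; ch++) {
    const filterCh = filter.slice([0, 0, 0, 0, ch], [-1, -1, -1, -1, 1]);  // one output channel
    const biasCh = biases.slice([ch], [1]);                                // matching bias (disposed inside the helper)
    perChannel.push(convByInputSlicing_oneconv(input, filterCh, biasCh, [1, 1, 1], 'valid', [1, 1, 1], sliceSize));
    filterCh.dispose();
  }
  const output = tf.concat(perChannel, 4);   // stack the channels back along the NDHWC channel axis
  tf.dispose(perChannel);
  return output;
}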

function convByOutputChannelAndInputSlicing(input, filter, biases, stride, pad, dilationRate, sliceSize) {
  const batchSize = input.shape[0];
  const depth = input.shape[1];
