From 7616863e5a5df4f4dc5ec42a20ddcb1a7f195ea3 Mon Sep 17 00:00:00 2001
From: Numbers0689
Date: Thu, 6 Mar 2025 16:00:49 +0530
Subject: [PATCH] Implement and fix nn GradFns

Fix GradFn_softmax, which wrote a numel x numel Jacobian into an
input-shaped buffer, to index each batch row correctly and to store
only the vector-Jacobian product.

Implement GradFn_crossentropy and attach it to the loss node so that
gradients propagate back to y_pred.
---
 src/nn.c | 54 ++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 46 insertions(+), 8 deletions(-)

diff --git a/src/nn.c b/src/nn.c
index 3eecd31..b842dfb 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -39,13 +39,23 @@ Tensor nn_relu(Tensor self) {
 /* nn.softmax */
 static Tensor GradFn_softmax(Tensor self, int i) {
     Tensor input = self.node->inputs[i];
-    Tensor res = Tensor_new(input.shape, false);
-    for(int j = 0; j < input.data->numel; j++) {
-        float softmax_j = self.data->flex[j];
-        for(int k = 0; k < input.data->numel; k++) {
-            float softmax_k = self.data->flex[k];
-            float delta_jk = (j == k) ? 1.0f : 0.0f;
-            res.data->flex[j * input.data->numel + k] = softmax_j * (delta_jk - softmax_k);
+    int num_classes = input.shape[TensorShape_dim(input.shape) - 1];
+    int batch_size = input.data->numel / num_classes;
+
+    Tensor res = Tensor_zeros(input.shape, false);
+    // dL/dx_j = s_j * (g_j - sum_k s_k * g_k), computed per batch row
+    for (int batch = 0; batch < batch_size; batch++) {
+        float grad_sum = 0.0f;
+
+        for (int k = 0; k < num_classes; k++) {
+            grad_sum += self.data->flex[batch * num_classes + k] * self.node->grad.data->flex[batch * num_classes + k];
+        }
+
+        for (int j = 0; j < num_classes; j++) {
+            float softmax_j = self.data->flex[batch * num_classes + j];
+            float grad_j = self.node->grad.data->flex[batch * num_classes + j];
+
+            res.data->flex[batch * num_classes + j] = softmax_j * (grad_j - grad_sum);
         }
     }
     return res;
@@ -88,6 +98,27 @@ Tensor nn_softmax(Tensor self) {
     return res;
 }
 
+static Tensor GradFn_crossentropy(Tensor self, int i) {
+    Tensor y_true = self.node->inputs[0];
+    Tensor y_pred = self.node->inputs[1];
+
+    Tensor grad = Tensor_zeros(y_pred.shape, false);
+
+    if (i == 1) {
+        // dL/dy_pred = -y_true / y_pred, elementwise
+        int n_samples = y_pred.shape[0];
+        int n_classes = y_pred.shape[1];
+
+        for (int s = 0; s < n_samples; s++) {
+            for (int c = 0; c < n_classes; c++) {
+                int idx = s * n_classes + c;
+                grad.data->flex[idx] = -y_true.data->flex[idx] / y_pred.data->flex[idx];
+            }
+        }
+    }
+    return grad;
+}
+
 /* nn.cross_entropy */
 Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
     // y_true: [None, n_classes]
@@ -101,7 +132,7 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
     assert(n_classes == y_pred.shape[1]);
 
     bool requires_grad = !cten_is_eval() && (y_true.node != NULL || y_pred.node != NULL);
-    Tensor res = Tensor_new((TensorShape){n_samples}, requires_grad);
+    Tensor res = Tensor_new((TensorShape){n_samples, 1}, requires_grad);
     for(int i = 0; i < n_samples; i++) {
         float loss = 0;
         for(int j = 0; j < n_classes; j++) {
@@ -109,5 +140,12 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
         }
         res.data->flex[i] = -loss;
     }
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_crossentropy;
+        res.node->inputs[0] = y_true;
+        res.node->inputs[1] = y_pred;
+        res.node->n_inputs = 2;
+    }
     return Tensor_mean(res);
 }
\ No newline at end of file
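
Notes for reviewers (supplementary; not part of the applied diff).

On GradFn_softmax: a short derivation of the quantity the new code
stores, to check the grad_sum expression against the math. With
s = softmax(x) and upstream gradient g = dL/ds, per batch row:

    % Softmax Jacobian: s_j = exp(x_j) / sum_k exp(x_k)
    \[
        \frac{\partial s_j}{\partial x_k} = s_j (\delta_{jk} - s_k)
    \]
    % Vector-Jacobian product with g_k = \partial L / \partial s_k:
    \[
        \frac{\partial L}{\partial x_j}
            = \sum_k g_k \, s_k (\delta_{kj} - s_j)
            = s_j \Bigl( g_j - \sum_k s_k g_k \Bigr)
    \]
    % The inner sum is grad_sum; the right-hand side is
    % res.data->flex[batch * num_classes + j].

This is why the full num_classes x num_classes Jacobian buffer can be
dropped: per row, one pass accumulates grad_sum and a second pass
applies it, and res keeps the input's shape.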
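
On GradFn_crossentropy: the per-sample loss and its derivative, and
where the 1/n_samples factor comes from. The GradFn stores only the
local derivative; this note assumes (the diff does not show it) that
cten's backward pass multiplies each GradFn result by the upstream
gradient, so Tensor_mean's own backward contributes the 1/n_samples
scaling rather than GradFn_crossentropy itself.

    % Per-sample cross-entropy over labels y_{sc} and predictions p_{sc}:
    \[
        L_s = -\sum_c y_{sc} \log p_{sc}
        \qquad\Longrightarrow\qquad
        \frac{\partial L_s}{\partial p_{sc}} = -\frac{y_{sc}}{p_{sc}}
    \]
    % i.e. grad.data->flex[idx] = -y_true[idx] / y_pred[idx].

Composed with the softmax VJP above, the two stages collapse to
p - y per row for labels that sum to 1, which is why fused
softmax-cross-entropy backward passes avoid the -y/p division (it is
ill-conditioned as p approaches 0).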
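
A minimal smoke test that could accompany this patch. It is a sketch
under stated assumptions, not part of the diff: cten's backward entry
point is not visible in this patch, so Tensor_backward() below is a
hypothetical stand-in for whatever the library actually exposes, and
the "cten.h" header name is likewise assumed.

    /* Sketch only. Assumed, not confirmed by this patch:
     *  - "cten.h" is the public header;
     *  - Tensor_backward() stands in for the real backward entry point.
     * Tensor_new/Tensor_zeros and the nn_* signatures match the diff. */
    #include "cten.h"
    #include <stdio.h>

    int main(void) {
        // One sample, three classes.
        Tensor logits = Tensor_new((TensorShape){1, 3}, true);
        logits.data->flex[0] = 1.0f;
        logits.data->flex[1] = 2.0f;
        logits.data->flex[2] = 0.5f;

        // One-hot target: class 1 is correct.
        Tensor y_true = Tensor_zeros((TensorShape){1, 3}, false);
        y_true.data->flex[1] = 1.0f;

        Tensor probs = nn_softmax(logits);
        Tensor loss = nn_crossentropy(y_true, probs);
        Tensor_backward(loss);  // hypothetical name, see note above

        // Composing the two GradFns should give, per the notes above,
        // d loss / d logits = (softmax(logits) - y_true) / n_samples.
        for (int j = 0; j < 3; j++) {
            printf("dL/dlogits[%d] = %f\n", j, logits.node->grad.data->flex[j]);
        }
        return 0;
    }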