Implement GradFn_crossentropy and fix GradFn_softmax #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
54 changes: 46 additions & 8 deletions src/nn.c
@@ -39,13 +39,23 @@ Tensor nn_relu(Tensor self) {
/* nn.softmax */
static Tensor GradFn_softmax(Tensor self, int i) {
Tensor input = self.node->inputs[i];
Tensor res = Tensor_new(input.shape, false);
for(int j = 0; j < input.data->numel; j++) {
float softmax_j = self.data->flex[j];
for(int k = 0; k < input.data->numel; k++) {
float softmax_k = self.data->flex[k];
float delta_jk = (j == k) ? 1.0f : 0.0f;
res.data->flex[j * input.data->numel + k] = softmax_j * (delta_jk - softmax_k);
int num_classes = input.shape[TensorShape_dim(input.shape) - 1];
int batch_size = input.data->numel / num_classes;

Tensor res = Tensor_zeros(input.shape, false);

for (int batch = 0; batch < batch_size; batch++) {
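// Per-row vector-Jacobian product: dL/dx_j = s_j * (dL/ds_j - sum_k s_k * dL/ds_k)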
float grad_sum = 0.0f;

for (int k = 0; k < num_classes; k++) {
grad_sum += self.data->flex[batch * num_classes + k] * self.node->grad.data->flex[batch * num_classes + k];
}

for (int j = 0; j < num_classes; j++) {
float softmax_j = self.data->flex[batch * num_classes + j];
float grad_j = self.node->grad.data->flex[batch * num_classes + j];

res.data->flex[batch * num_classes + j] = softmax_j * (grad_j - grad_sum);
}
}
return res;
@@ -88,6 +98,27 @@ Tensor nn_softmax(Tensor self) {
return res;
}

Tensor GradFn_crossentropy(Tensor self, int i) {
Tensor y_true = self.node->inputs[0];
Tensor y_pred = self.node->inputs[1];

Tensor grad = Tensor_zeros(y_pred.shape, false);

if (i == 1) {
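// y_true (input 0) is a constant target, so only y_pred receives a gradient.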
// f'(y_true, y_pred) = -y_true / y_pred
int n_samples = y_pred.shape[0];
int n_classes = y_pred.shape[1];

for (int s = 0; s < n_samples; s++) {
for (int c = 0; c < n_classes; c++) {
int idx = s * n_classes + c;
grad.data->flex[idx] = -y_true.data->flex[idx] / y_pred.data->flex[idx];
}
}
}
return grad;
}

/* nn.cross_entropy */
Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
// y_true: [None, n_classes]
@@ -101,7 +132,7 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
assert(n_classes == y_pred.shape[1]);

bool requires_grad = !cten_is_eval() && (y_true.node != NULL || y_pred.node != NULL);
Tensor res = Tensor_new((TensorShape){n_samples}, requires_grad);
Tensor res = Tensor_new((TensorShape){n_samples, 1}, requires_grad);
for(int i = 0; i < n_samples; i++) {
float loss = 0;
for(int j = 0; j < n_classes; j++) {
@@ -110,5 +141,12 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
}
res.data->flex[i] = -loss;
}

if (requires_grad) {
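// Attach the backward function and its inputs so autograd can reach GradFn_crossentropy.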
res.node->grad_fn = GradFn_crossentropy;
res.node->inputs[0] = y_true;
res.node->inputs[1] = y_pred;
res.node->n_inputs = 2;
}
return Tensor_mean(res);
}
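
Reviewer note (not part of the diff): the previous GradFn_softmax built the full Jacobian entries softmax_j * (delta_jk - softmax_k), but it wrote numel * numel values into a tensor allocated with the input's shape and never multiplied by the upstream gradient. The new version instead computes the vector-Jacobian product row by row, dL/dx_j = s_j * (dL/ds_j - sum_k s_k * dL/ds_k), which keeps the result the same shape as the input. GradFn_crossentropy returns -y_true / y_pred for the prediction input; chained through the softmax VJP with one-hot labels this works out to softmax(x) - y_true on the logits (up to the 1/n_samples factor from Tensor_mean).

The snippet below is a standalone sanity check of the VJP formula against a central finite difference. It deliberately avoids the cten API, so every name in it (softmax, softmax_vjp, loss) is illustrative plain C rather than library code.

```c
#include <math.h>
#include <stdio.h>

#define N 4

/* softmax of x into s (single row) */
static void softmax(const float* x, float* s) {
    float mx = x[0];
    for (int i = 1; i < N; i++) if (x[i] > mx) mx = x[i];
    float sum = 0.0f;
    for (int i = 0; i < N; i++) { s[i] = expf(x[i] - mx); sum += s[i]; }
    for (int i = 0; i < N; i++) s[i] /= sum;
}

/* analytic VJP: dx_j = s_j * (g_j - sum_k s_k * g_k), the same formula as GradFn_softmax */
static void softmax_vjp(const float* s, const float* g, float* dx) {
    float dot = 0.0f;
    for (int k = 0; k < N; k++) dot += s[k] * g[k];
    for (int j = 0; j < N; j++) dx[j] = s[j] * (g[j] - dot);
}

/* scalar loss L(x) = sum_j g_j * softmax(x)_j, used for the finite-difference check */
static float loss(const float* x, const float* g) {
    float s[N], L = 0.0f;
    softmax(x, s);
    for (int j = 0; j < N; j++) L += g[j] * s[j];
    return L;
}

int main(void) {
    float x[N] = {0.2f, -1.3f, 0.7f, 2.0f};
    float g[N] = {0.5f, -0.1f, 0.3f, -0.7f};   /* pretend upstream gradient dL/ds */
    float s[N], dx[N];

    softmax(x, s);
    softmax_vjp(s, g, dx);

    const float eps = 1e-3f;
    for (int j = 0; j < N; j++) {
        float xp[N], xm[N];
        for (int k = 0; k < N; k++) { xp[k] = x[k]; xm[k] = x[k]; }
        xp[j] += eps; xm[j] -= eps;
        float num = (loss(xp, g) - loss(xm, g)) / (2.0f * eps);
        printf("j=%d analytic=% .6f numeric=% .6f\n", j, dx[j], num);
    }
    return 0;
}
```

Compiled with something like `cc check_vjp.c -lm` (the file name is arbitrary), the analytic and numeric columns should agree to roughly 1e-4, which is the kind of check worth repeating on the library's own backward pass.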