Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
add GaussianNB
Browse files Browse the repository at this point in the history
  • Loading branch information
agrimagsrl committed Aug 2, 2020
1 parent 7e5128f commit 60e7cde
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion micromlgen/templates/dot.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ float dot(float *x, {{ signature }}) {
float dot = 0.0;

for (uint16_t i = 0; i < {{ dimension }}; i++) {
const float wi = {{ wi }}
const float wi = {{ wi }};
dot += {{ expr }};
}

Expand Down
40 changes: 40 additions & 0 deletions micromlgen/templates/gaussiannb/gaussiannb.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#pragma once

namespace Eloquent {
namespace ML {
namespace Port {

class {{ classname }} {
public:

/**
* Predict class for features vector
*/
int predict(float *x) {
float votes[{{ classes|length }}] = { 0.0f };

{% include 'gaussiannb/vote.jinja' %}
{% include 'vote.jinja' %}
}

{% include 'classmap.jinja' %}

protected:

/**
* Compute gaussian value
*/
float gauss(float *x, float *theta, float *sigma) {
float gauss = 0.0f;

for (uint16_t i = 0; i < {{ theta[0]|length }}; i++) {
gauss += log(sigma[i]);
gauss += pow(x[i] - theta[i], 2) / sigma[i];
}

return gauss;
}
};
}
}
}
8 changes: 8 additions & 0 deletions micromlgen/templates/gaussiannb/vote.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
float theta[{{ theta[0]|length }}] = { 0 };
float sigma[{{ sigma[0]|length }}] = { 0 };

{% for i, (t, s) in f.enumerate(f.zip(theta, sigma)) %}
{% for j, tj in f.enumerate(t) %}theta[{{ j }}] = {{ f.round(tj) }}; {% endfor %}
{% for j, sj in f.enumerate(s) %}sigma[{{ j }}] = {{ f.round(sj) }}; {% endfor %}
votes[{{ i }}] = {{ f.round(prior[i]) }} - gauss(x, theta, sigma);
{% endfor %}
2 changes: 1 addition & 1 deletion micromlgen/templates/vote.jinja
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// return argmax of votes
uint8_t classIdx = 0;
uint8_t maxVotes = votes[0];
float maxVotes = votes[0];

for (uint8_t i = 1; i < {{ n_classes }}; i++) {
if (votes[i] > maxVotes) {
Expand Down

0 comments on commit 60e7cde

Please sign in to comment.