From 60e7cde74c530993345cdf4601e4024ffd6007bc Mon Sep 17 00:00:00 2001 From: Agrimagsrl Date: Fri, 31 Jul 2020 21:47:33 +0200 Subject: [PATCH] add GaussianNB --- micromlgen/templates/dot.jinja | 2 +- .../templates/gaussiannb/gaussiannb.jinja | 40 +++++++++++++++++++ micromlgen/templates/gaussiannb/vote.jinja | 8 ++++ micromlgen/templates/vote.jinja | 2 +- 4 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 micromlgen/templates/gaussiannb/gaussiannb.jinja create mode 100644 micromlgen/templates/gaussiannb/vote.jinja diff --git a/micromlgen/templates/dot.jinja b/micromlgen/templates/dot.jinja index 45ab771..8a82485 100644 --- a/micromlgen/templates/dot.jinja +++ b/micromlgen/templates/dot.jinja @@ -16,7 +16,7 @@ float dot(float *x, {{ signature }}) { float dot = 0.0; for (uint16_t i = 0; i < {{ dimension }}; i++) { - const float wi = {{ wi }} + const float wi = {{ wi }}; dot += {{ expr }}; } diff --git a/micromlgen/templates/gaussiannb/gaussiannb.jinja b/micromlgen/templates/gaussiannb/gaussiannb.jinja new file mode 100644 index 0000000..8210ed5 --- /dev/null +++ b/micromlgen/templates/gaussiannb/gaussiannb.jinja @@ -0,0 +1,40 @@ +#pragma once + +namespace Eloquent { + namespace ML { + namespace Port { + + class {{ classname }} { + public: + + /** + * Predict class for features vector + */ + int predict(float *x) { + float votes[{{ classes|length }}] = { 0.0f }; + + {% include 'gaussiannb/vote.jinja' %} + {% include 'vote.jinja' %} + } + + {% include 'classmap.jinja' %} + + protected: + + /** + * Compute gaussian value + */ + float gauss(float *x, float *theta, float *sigma) { + float gauss = 0.0f; + + for (uint16_t i = 0; i < {{ theta[0]|length }}; i++) { + gauss += log(sigma[i]); + gauss += pow(x[i] - theta[i], 2) / sigma[i]; + } + + return gauss; + } + }; + } + } +} \ No newline at end of file diff --git a/micromlgen/templates/gaussiannb/vote.jinja b/micromlgen/templates/gaussiannb/vote.jinja new file mode 100644 index 0000000..6a4832d --- /dev/null +++ b/micromlgen/templates/gaussiannb/vote.jinja @@ -0,0 +1,8 @@ +float theta[{{ theta[0]|length }}] = { 0 }; +float sigma[{{ sigma[0]|length }}] = { 0 }; + +{% for i, (t, s) in f.enumerate(f.zip(theta, sigma)) %} + {% for j, tj in f.enumerate(t) %}theta[{{ j }}] = {{ f.round(tj) }}; {% endfor %} + {% for j, sj in f.enumerate(s) %}sigma[{{ j }}] = {{ f.round(sj) }}; {% endfor %} + votes[{{ i }}] = {{ f.round(prior[i]) }} - gauss(x, theta, sigma); +{% endfor %} \ No newline at end of file diff --git a/micromlgen/templates/vote.jinja b/micromlgen/templates/vote.jinja index b696718..2371eb4 100644 --- a/micromlgen/templates/vote.jinja +++ b/micromlgen/templates/vote.jinja @@ -1,6 +1,6 @@ // return argmax of votes uint8_t classIdx = 0; -uint8_t maxVotes = votes[0]; +float maxVotes = votes[0]; for (uint8_t i = 1; i < {{ n_classes }}; i++) { if (votes[i] > maxVotes) {