Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
add XGBClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
agrimagsrl committed Oct 20, 2020
1 parent 66a8fbb commit a9a59c4
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 4 deletions.
5 changes: 5 additions & 0 deletions MANIFEST
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ micromlgen/__init__.py
micromlgen/convolution.py
micromlgen/decisiontree.py
micromlgen/gaussiannb.py
micromlgen/linear_regression.py
micromlgen/logisticregression.py
micromlgen/micromlgen.py
micromlgen/pca.py
Expand All @@ -16,6 +17,7 @@ micromlgen/sefr.py
micromlgen/svm.py
micromlgen/utils.py
micromlgen/wifiindoorpositioning.py
micromlgen/xgboost.py
micromlgen/templates/_skeleton.jinja
micromlgen/templates/classmap.jinja
micromlgen/templates/dot.jinja
Expand All @@ -27,6 +29,7 @@ micromlgen/templates/decisiontree/decisiontree.jinja
micromlgen/templates/decisiontree/tree.jinja
micromlgen/templates/gaussiannb/gaussiannb.jinja
micromlgen/templates/gaussiannb/vote.jinja
micromlgen/templates/linearregression/linearregression.jinja
micromlgen/templates/logisticregression/logisticregression.jinja
micromlgen/templates/logisticregression/vote.arduino.jinja
micromlgen/templates/logisticregression/vote.attiny.jinja
Expand All @@ -49,3 +52,5 @@ micromlgen/templates/svm/kernel/arduino.jinja
micromlgen/templates/svm/kernel/attiny.jinja
micromlgen/templates/svm/kernel/kernel.jinja
micromlgen/templates/wifiindoorpositioning/wifiindoorpositioning.jinja
micromlgen/templates/xgboost/tree.jinja
micromlgen/templates/xgboost/xgboost.jinja
Binary file added micromlgen/__init__.pyc
Binary file not shown.
4 changes: 3 additions & 1 deletion micromlgen/micromlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from micromlgen.pca import is_pca, port_pca
from micromlgen.principalfft import is_principalfft, port_principalfft
from micromlgen.linear_regression import is_linear_regression, port_linear_regression

from micromlgen.xgboost import is_xgboost, port_xgboost

def port(
clf,
Expand Down Expand Up @@ -40,4 +40,6 @@ def port(
return port_principalfft(**locals(), **kwargs)
elif is_linear_regression(clf):
return port_linear_regression(**locals(), **kwargs)
elif is_xgboost(clf):
return port_xgboost(**locals(), **kwargs)
raise TypeError('clf MUST be one of %s' % ', '.join(platforms.ALLOWED_CLASSIFIERS))
3 changes: 2 additions & 1 deletion micromlgen/platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
'GaussianNB',
'LogisticRegression',
'PCA',
'PrincipalFFT'
'PrincipalFFT',
'LinearRegression'
]
14 changes: 14 additions & 0 deletions micromlgen/templates/xgboost/tree.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{% if tree['left'][i] != tree['right'][i] %}
if (x[{{ tree['features'][i] }}] <= {{ tree['thresholds'][i] }}) {
{% with i = tree['left'][i] %}
{% include 'xgboost/tree.jinja' %}
{% endwith %}
}
else {
{% with i = tree['right'][i] %}
{% include 'xgboost/tree.jinja' %}
{% endwith %}
}
{% else %}
votes[{{ class_idx }}] += {{ tree['thresholds'][i] }};
{% endif %}
14 changes: 14 additions & 0 deletions micromlgen/templates/xgboost/xgboost.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{% extends '_skeleton.jinja' %}

{% block predict %}
float votes[{{ n_classes }}] = { 0.0f };

{% for k, tree in f.enumerate(trees) %}
{% with i = 0, class_idx = k % n_classes %}
// tree #{{ k + 1 }}
{% include 'xgboost/tree.jinja' %}
{% endwith %}
{% endfor %}

{% include 'vote.jinja' %}
{% endblock %}
42 changes: 42 additions & 0 deletions micromlgen/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from micromlgen.utils import jinja, check_type
from tempfile import NamedTemporaryFile
import json


def format_tree(tree):
"""
Format xgboost tree like a sklearn DecisionTree
:param tree:
:return:
"""
split_indices = tree['split_indices']
split_conditions = tree['split_conditions']
left_children = tree['left_children']
right_children = tree['right_children']
return {
'left': left_children,
'right': right_children,
'features': split_indices,
'thresholds': split_conditions
}


def is_xgboost(clf):
"""Test if classifier can be ported"""
return check_type(clf, 'XGBClassifier')


def port_xgboost(clf, **kwargs):
"""Port a XGBoost classifier"""
with NamedTemporaryFile('w+', suffix='.json', encoding='utf-8') as tmp:
clf.save_model(tmp.name)
tmp.seek(0)
decoded = json.load(tmp)
trees = [format_tree(tree) for tree in decoded['learner']['gradient_booster']['model']['trees']]
print(trees)
return jinja('xgboost/xgboost.jinja', {
'n_classes': int(decoded['learner']['learner_model_param']['num_class']),
'trees': trees,
}, {
'classname': 'XGBClassifier'
}, **kwargs)
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
setup(
name = 'micromlgen',
packages = ['micromlgen'],
version = '1.1.10',
version = '1.1.11',
license='MIT',
description = 'Generate C code for microcontrollers from Python\'s sklearn classifiers',
author = 'Simone Salerno',
author_email = '[email protected]',
url = 'https://github.com/eloquentarduino/micromlgen',
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_1110.tar.gz',
download_url = 'https://github.com/eloquentarduino/micromlgen/archive/v_1111.tar.gz',
keywords = [
'ML',
'microcontrollers',
Expand Down

0 comments on commit a9a59c4

Please sign in to comment.