This repository has been archived by the owner on May 25, 2024. It is now read-only.
forked from agrimagsrl/micromlgen
-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e5b062d
commit 06e75d4
Showing
20 changed files
with
219 additions
and
311 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,20 @@ | ||
from sklearn.tree import DecisionTreeClassifier | ||
from micromlgen.utils import jinja | ||
from micromlgen.utils import jinja, check_type | ||
|
||
|
||
def is_decisiontree(clf): | ||
"""Test if classifier can be ported""" | ||
return isinstance(clf, DecisionTreeClassifier) | ||
return check_type(clf, 'DecisionTreeClassifier') | ||
|
||
|
||
def port_decisiontree(clf, **kwargs): | ||
"""Port sklearn's DecisionTreeClassifier""" | ||
kwargs['classname'] = kwargs['classname'] or 'DecisionTree' | ||
return jinja('decisiontree/decisiontree.jinja', { | ||
'left': clf.tree_.children_left, | ||
'right': clf.tree_.children_right, | ||
'features': clf.tree_.feature, | ||
'thresholds': clf.tree_.threshold, | ||
'classes': clf.tree_.value, | ||
'i': 0 | ||
}, { | ||
'classname': 'DecisionTree' | ||
}, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,19 @@ | ||
from sklearn.naive_bayes import GaussianNB | ||
from micromlgen.utils import jinja | ||
from micromlgen.utils import jinja, check_type | ||
|
||
|
||
def is_gaussiannb(clf): | ||
"""Test if classifier can be ported""" | ||
return isinstance(clf, GaussianNB) | ||
return check_type(clf, 'GaussianNB') | ||
|
||
|
||
def port_gaussiannb(clf, **kwargs): | ||
"""Port sklearn's DecisionTreeClassifier""" | ||
kwargs['classname'] = kwargs['classname'] or 'GaussianNB' | ||
"""Port sklearn's GaussianNB""" | ||
return jinja('gaussiannb/gaussiannb.jinja', { | ||
'sigma': clf.sigma_, | ||
'theta': clf.theta_, | ||
'prior': clf.class_prior_, | ||
'classes': clf.classes_, | ||
'n_classes': len(clf.classes_) | ||
}, { | ||
'classname': 'GaussianNB' | ||
}, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
from sklearn.linear_model import LogisticRegression | ||
from micromlgen.utils import jinja | ||
from micromlgen.utils import jinja, check_type | ||
|
||
|
||
def is_logisticregression(clf): | ||
"""Test if classifier can be ported""" | ||
return isinstance(clf, LogisticRegression) | ||
return check_type(clf, 'LogisticRegression') | ||
|
||
|
||
def port_logisticregression(clf, **kwargs): | ||
"""Port sklearn's DecisionTreeClassifier""" | ||
kwargs['classname'] = kwargs['classname'] or 'LogisticRegression' | ||
"""Port sklearn's LogisticRegressionClassifier""" | ||
return jinja('logisticregression/logisticregression.jinja', { | ||
'weights': clf.coef_, | ||
'intercept': clf.intercept_, | ||
'classes': clf.classes_, | ||
'n_classes': len(clf.classes_) | ||
}, { | ||
'classname': 'LogisticRegression' | ||
}, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,18 @@ | ||
from micromlgen.utils import jinja | ||
from micromlgen.utils import jinja, check_type | ||
|
||
|
||
def port_pca(pca, classname=None, **kwargs): | ||
def is_pca(clf): | ||
"""Test if classifier can be ported""" | ||
return check_type(clf, 'PCA') | ||
|
||
|
||
def port_pca(clf, **kwargs): | ||
"""Port a PCA""" | ||
template_data = { | ||
return jinja('pca/pca.jinja', { | ||
'arrays': { | ||
'components': pca.components_, | ||
'mean': pca.mean_ | ||
'components': clf.components_, | ||
'mean': clf.mean_ | ||
}, | ||
'classname': classname if classname is not None else 'PCA' | ||
} | ||
return jinja('pca/pca.jinja', template_data) | ||
}, { | ||
'classname': 'PCA' | ||
}, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,17 @@ | ||
from micromlgen.utils import jinja | ||
from micromlgen.utils import jinja, check_type | ||
|
||
|
||
def is_sefr(clf): | ||
"""Test if classifier can be ported""" | ||
return check_type(clf, 'SEFR') | ||
|
||
|
||
def port_sefr(clf, classname=None, **kwargs): | ||
"""Port SEFR classifier""" | ||
kwargs.update({ | ||
return jinja('sefr/sefr.jinja', { | ||
'weights': clf.weights, | ||
'bias': clf.bias, | ||
'dimension': len(clf.weights), | ||
'classname': classname or 'SEFR' | ||
}) | ||
return jinja('sefr/sefr.jinja', kwargs) | ||
}, { | ||
'classname': 'SEFR' | ||
}, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#pragma once | ||
|
||
namespace Eloquent { | ||
namespace ML { | ||
namespace Port { | ||
class {{ classname }} { | ||
public: | ||
|
||
/** | ||
* Predict class for features vector | ||
*/ | ||
int predict(float *x) { | ||
{% block predict %}{% endblock %} | ||
} | ||
|
||
{% include 'classmap.jinja' %} | ||
|
||
{% block public %}{% endblock %} | ||
|
||
protected: | ||
|
||
{% block protected %}{% endblock %} | ||
}; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,5 @@ | ||
#pragma once | ||
{% extends '_skeleton.jinja' %} | ||
|
||
namespace Eloquent { | ||
namespace ML { | ||
namespace Port { | ||
|
||
class {{ classname }} { | ||
public: | ||
|
||
/** | ||
* Predict class for features vector | ||
*/ | ||
int predict(float *x) { | ||
{% include 'decisiontree/tree.jinja' %} | ||
} | ||
|
||
{% include 'classmap.jinja' %} | ||
}; | ||
} | ||
} | ||
} | ||
{% block predict %} | ||
{% include 'decisiontree/tree.jinja' %} | ||
{% endblock %} |
Oops, something went wrong.