Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
refactor templates
Browse files Browse the repository at this point in the history
  • Loading branch information
agrimagsrl committed Aug 3, 2020
1 parent e5b062d commit 06e75d4
Show file tree
Hide file tree
Showing 20 changed files with 219 additions and 311 deletions.
8 changes: 4 additions & 4 deletions micromlgen/decisiontree.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from sklearn.tree import DecisionTreeClassifier
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def is_decisiontree(clf):
"""Test if classifier can be ported"""
return isinstance(clf, DecisionTreeClassifier)
return check_type(clf, 'DecisionTreeClassifier')


def port_decisiontree(clf, **kwargs):
"""Port sklearn's DecisionTreeClassifier"""
kwargs['classname'] = kwargs['classname'] or 'DecisionTree'
return jinja('decisiontree/decisiontree.jinja', {
'left': clf.tree_.children_left,
'right': clf.tree_.children_right,
'features': clf.tree_.feature,
'thresholds': clf.tree_.threshold,
'classes': clf.tree_.value,
'i': 0
}, {
'classname': 'DecisionTree'
}, **kwargs)
10 changes: 5 additions & 5 deletions micromlgen/gaussiannb.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from sklearn.naive_bayes import GaussianNB
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def is_gaussiannb(clf):
"""Test if classifier can be ported"""
return isinstance(clf, GaussianNB)
return check_type(clf, 'GaussianNB')


def port_gaussiannb(clf, **kwargs):
"""Port sklearn's DecisionTreeClassifier"""
kwargs['classname'] = kwargs['classname'] or 'GaussianNB'
"""Port sklearn's GaussianNB"""
return jinja('gaussiannb/gaussiannb.jinja', {
'sigma': clf.sigma_,
'theta': clf.theta_,
'prior': clf.class_prior_,
'classes': clf.classes_,
'n_classes': len(clf.classes_)
}, {
'classname': 'GaussianNB'
}, **kwargs)
10 changes: 5 additions & 5 deletions micromlgen/logisticregression.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from sklearn.linear_model import LogisticRegression
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def is_logisticregression(clf):
"""Test if classifier can be ported"""
return isinstance(clf, LogisticRegression)
return check_type(clf, 'LogisticRegression')


def port_logisticregression(clf, **kwargs):
"""Port sklearn's DecisionTreeClassifier"""
kwargs['classname'] = kwargs['classname'] or 'LogisticRegression'
"""Port sklearn's LogisticRegressionClassifier"""
return jinja('logisticregression/logisticregression.jinja', {
'weights': clf.coef_,
'intercept': clf.intercept_,
'classes': clf.classes_,
'n_classes': len(clf.classes_)
}, {
'classname': 'LogisticRegression'
}, **kwargs)
30 changes: 9 additions & 21 deletions micromlgen/micromlgen.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
from sklearn.decomposition import PCA
from sklearn.svm import SVC, LinearSVC, OneClassSVM

try:
from skbayes.rvm_ard_models import RVC
except ImportError:
from micromlgen.patches import RVC
try:
from sefr import SEFR
except ImportError:
from micromlgen.patches import SEFR

from micromlgen import platforms
from micromlgen.svm import port_svm
from micromlgen.rvm import port_rvm
from micromlgen.sefr import port_sefr
from micromlgen.svm import is_svm, port_svm
from micromlgen.rvm import is_rvm, port_rvm
from micromlgen.sefr import is_sefr, port_sefr
from micromlgen.decisiontree import is_decisiontree, port_decisiontree
from micromlgen.randomforest import is_randomforest, port_randomforest
from micromlgen.logisticregression import is_logisticregression, port_logisticregression
from micromlgen.gaussiannb import is_gaussiannb, port_gaussiannb
from micromlgen.pca import port_pca
from micromlgen.pca import is_pca, port_pca


def port(
Expand All @@ -29,13 +17,11 @@ def port(
precision=None):
"""Port a classifier to plain C++"""
assert platform in platforms.ALL, 'Unknown platform %s. Use one of %s' % (platform, ', '.join(platforms.ALL))
if isinstance(clf, (SVC, LinearSVC, OneClassSVM)):
if is_svm(clf):
return port_svm(**locals())
elif isinstance(clf, RVC):
elif is_rvm(clf):
return port_rvm(**locals())
elif isinstance(clf, PCA):
return port_pca(pca=clf, **locals())
elif isinstance(clf, SEFR):
elif is_sefr(clf):
return port_sefr(**locals())
elif is_decisiontree(clf):
return port_decisiontree(**locals())
Expand All @@ -45,4 +31,6 @@ def port(
return port_logisticregression(**locals())
elif is_gaussiannb(clf):
return port_gaussiannb(**locals())
elif is_pca(clf):
return port_pca(**locals())
raise TypeError('clf MUST be one of SVC, LinearSVC, OneClassSVC, RVC, DecisionTree, RandomForest, LogisticRegression, GaussianNB, SEFR, PCA')
8 changes: 0 additions & 8 deletions micromlgen/patches.py

This file was deleted.

21 changes: 13 additions & 8 deletions micromlgen/pca.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def port_pca(pca, classname=None, **kwargs):
def is_pca(clf):
"""Test if classifier can be ported"""
return check_type(clf, 'PCA')


def port_pca(clf, **kwargs):
"""Port a PCA"""
template_data = {
return jinja('pca/pca.jinja', {
'arrays': {
'components': pca.components_,
'mean': pca.mean_
'components': clf.components_,
'mean': clf.mean_
},
'classname': classname if classname is not None else 'PCA'
}
return jinja('pca/pca.jinja', template_data)
}, {
'classname': 'PCA'
}, **kwargs)
10 changes: 5 additions & 5 deletions micromlgen/randomforest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from sklearn.ensemble import RandomForestClassifier
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def is_randomforest(clf):
"""Test if classifier can be ported"""
return isinstance(clf, RandomForestClassifier)
return check_type(clf, 'RandomForestClassifier')


def port_randomforest(clf, **kwargs):
"""Port sklearn's RandomForestClassifier"""
kwargs['classname'] = kwargs['classname'] or 'RandomForest'
return jinja('randomforest/randomforest.jinja', {
'n_classes': clf.n_classes_,
'trees': [{
Expand All @@ -19,4 +17,6 @@ def port_randomforest(clf, **kwargs):
'thresholds': clf.tree_.threshold,
'classes': clf.tree_.value,
} for clf in clf.estimators_]
}, **locals())
}, {
'classname': 'RandomForest'
}, **kwargs)
22 changes: 13 additions & 9 deletions micromlgen/rvm.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def port_rvm(clf, classname, **kwargs):
def is_rvm(clf):
"""Test if classifier can be ported"""
return check_type(clf, 'RVC')


def port_rvm(clf, **kwargs):
"""Port a RVM classifier"""
assert classname is None or len(classname) > 0, 'Invalid class name'
template_data = {
**kwargs,
return jinja('rvm/rvm.jinja', {
'n_classes': len(clf.intercept_),
'kernel': {
'type': clf.kernel,
'gamma': clf.gamma,
'coef0': clf.coef0,
'degree': clf.degree
},
'sizes': {
'features': len(clf.relevant_vectors_[0]),
'features': clf.relevant_vectors_[0].shape[1],
},
'arrays': {
'vectors': clf.relevant_vectors_,
Expand All @@ -23,6 +27,6 @@ def port_rvm(clf, classname, **kwargs):
'mean': clf._x_mean,
'std': clf._x_std
},
'classname': classname if classname is not None else 'RVM',
}
return jinja('rvm/rvm.jinja', template_data)
}, {
'classname': 'RVC'
}, **kwargs)
15 changes: 10 additions & 5 deletions micromlgen/sefr.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from micromlgen.utils import jinja
from micromlgen.utils import jinja, check_type


def is_sefr(clf):
"""Test if classifier can be ported"""
return check_type(clf, 'SEFR')


def port_sefr(clf, classname=None, **kwargs):
"""Port SEFR classifier"""
kwargs.update({
return jinja('sefr/sefr.jinja', {
'weights': clf.weights,
'bias': clf.bias,
'dimension': len(clf.weights),
'classname': classname or 'SEFR'
})
return jinja('sefr/sefr.jinja', kwargs)
}, {
'classname': 'SEFR'
}, **kwargs)
23 changes: 11 additions & 12 deletions micromlgen/svm.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from sklearn.svm import OneClassSVM
from micromlgen.utils import jinja, check_type

from micromlgen.utils import jinja

def is_svm(clf):
"""Test if classifier can be ported"""
return check_type(clf, 'SVC', 'LinearSVC', 'OneClassSVM')

def port_svm(clf, classname=None, **kwargs):

def port_svm(clf, **kwargs):
"""Port a SVC / LinearSVC classifier"""
assert isinstance(clf.gamma, float), 'You probably didn\'t set an explicit value for gamma: 0.001 is a good default'
assert classname is None or len(classname) > 0, 'Invalid class name'
if classname is None:
classname = 'OneClassSVM' if isinstance(clf, OneClassSVM) else 'SVM'
support_v = clf.support_vectors_
n_classes = len(clf.n_support_)
template_data = {
**kwargs,
return jinja('svm/svm.jinja', {
'kernel': {
'type': clf.kernel,
'gamma': clf.gamma,
Expand All @@ -30,7 +29,7 @@ def port_svm(clf, classname=None, **kwargs):
'supports': support_v,
'intercepts': clf.intercept_,
'coefs': clf.dual_coef_
},
'classname': classname
}
return jinja('svm/svm.jinja', template_data)
}
}, {
'classname': 'OneClassSVM' if check_type(clf, 'OneClassSVM') else 'SVM'
}, **kwargs)
26 changes: 26 additions & 0 deletions micromlgen/templates/_skeleton.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

namespace Eloquent {
namespace ML {
namespace Port {
class {{ classname }} {
public:

/**
* Predict class for features vector
*/
int predict(float *x) {
{% block predict %}{% endblock %}
}

{% include 'classmap.jinja' %}

{% block public %}{% endblock %}

protected:

{% block protected %}{% endblock %}
};
}
}
}
24 changes: 4 additions & 20 deletions micromlgen/templates/decisiontree/decisiontree.jinja
Original file line number Diff line number Diff line change
@@ -1,21 +1,5 @@
#pragma once
{% extends '_skeleton.jinja' %}

namespace Eloquent {
namespace ML {
namespace Port {

class {{ classname }} {
public:

/**
* Predict class for features vector
*/
int predict(float *x) {
{% include 'decisiontree/tree.jinja' %}
}

{% include 'classmap.jinja' %}
};
}
}
}
{% block predict %}
{% include 'decisiontree/tree.jinja' %}
{% endblock %}
Loading

0 comments on commit 06e75d4

Please sign in to comment.