-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
37 lines (29 loc) · 1018 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Train a scikit-learn logistic-regression classifier on the iris dataset,
# persist it with joblib, convert it to a Core ML model, and compare the
# predictions of both models on the same sample.
from sklearn import datasets
from sklearn.linear_model import LogisticRegression

try:
    # joblib is a standalone package; sklearn.externals.joblib was deprecated
    # in scikit-learn 0.21 and removed in 0.23.
    import joblib
except ImportError:  # fall back for old scikit-learn installs that bundled it
    from sklearn.externals import joblib

import coremltools

# Load the iris dataset.
iris = datasets.load_iris()

# Train a logistic regression.
model = LogisticRegression()
model.fit(iris.data, iris.target)

# One sample, ordered like iris.feature_names. It is reused for both models
# so the scikit-learn and Core ML predictions are made on the same input.
sample = [1.0, 2.0, 2.0, 3.0]

# Make a prediction. predict() returns an array of class indices, which is
# used to index target_names and get the predicted class name(s).
print('prediction with scikit model:')
print(iris.target_names[model.predict([sample])])

# Dump the model with joblib for comparison.
joblib.dump(model, 'iris.pkl')

# Export and save the Core ML model: input features are named after
# iris.feature_names and the output is named 'iris class'.
coreml_model = coremltools.converters.sklearn.convert(model, iris.feature_names, 'iris class')
coreml_model.save('iris.mlmodel')

# Load back the model.
loaded_model = coremltools.models.MLModel('iris.mlmodel')

# You can check the model's specifications.
print(loaded_model.get_spec())

# Core ML takes a dict keyed by input feature name. Building it from
# iris.feature_names keeps the keys identical to the names given to convert()
# and the values identical to the sample used for the scikit-learn prediction.
input_data = dict(zip(iris.feature_names, sample))

# NOTE(review): per the coremltools docs, MLModel.predict is only supported
# on macOS; confirm the target platform before relying on this step.
print('prediction with coreml model:')
print(loaded_model.predict(input_data))