diff --git a/projects/deploy_mnist/train_mnist_model.py b/projects/deploy_mnist/train_mnist_model.py index 3262433..4417921 100644 --- a/projects/deploy_mnist/train_mnist_model.py +++ b/projects/deploy_mnist/train_mnist_model.py @@ -1,11 +1,11 @@ -from sklearn.datasets import fetch_mldata +from sklearn.datasets import fetch_openml import numpy as np from sklearn.linear_model import SGDClassifier from sklearn.metrics import accuracy_score from sklearn.externals import joblib np.random.seed(42) -mnist = fetch_mldata("MNIST original") +mnist = fetch_openml("MNIST original", version=1, cache=True) X, y = mnist["data"], mnist["target"] X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]