Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predict function for logistic regression. #92

Merged
merged 4 commits into from
Jul 2, 2015
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions code/logistic_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def __init__(self, input, n_in, n_out):
# parameters of the model
self.params = [self.W, self.b]

# keep track of model input
self.input = input

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check if this is needed for the other models? That way, all pickled files will have the needed information.

def negative_log_likelihood(self, y):
"""Return the mean of the negative log-likelihood of the prediction
of this model under a given target distribution.
Expand Down Expand Up @@ -415,6 +418,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
)
)

# save the best model
with open('best_model.pkl', 'w') as f:
cPickle.dump(classifier, f)

if patience <= iter:
done_looping = True
break
Expand All @@ -433,5 +440,31 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
os.path.split(__file__)[1] +
' ran for %.1fs' % ((end_time - start_time)))


def predict():
"""
An example of how to load a trained model and use it
to predict labels.
"""

# load the saved model
classifier = cPickle.load(open('best_model.pkl'))

# compile a predictor function
predict_model = theano.function(
inputs=[classifier.input],
outputs=classifier.y_pred)

# We can test it on some examples from test test
dataset='mnist.pkl.gz'
datasets = load_data(dataset)
test_set_x, test_set_y = datasets[2]
test_set_x = test_set_x.get_value()

predicted_values = predict_model(test_set_x[:10])
print ("Predicted values for the first 10 examples in test set:")
print predicted_values


if __name__ == '__main__':
sgd_optimization_mnist()
13 changes: 13 additions & 0 deletions doc/logreg.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ approximately 1.936 epochs/sec and it took 75 epochs to reach a test
error of 7.489%. On the GPU the code does almost 10.0 epochs/sec. For this
instance we used a batch size of 600.


Prediction Using a Trained Model
+++++++++++++++++++++++++++++++

``sgd_optimization_mnist`` serialize and pickle the model each time new
lowest validation error is reached. We can reload this model and predict
labels of new data. ``predict`` function shows an example of how
this could be done.

.. literalinclude:: ../code/logistic_sgd.py
:pyobject: predict


.. rubric:: Footnotes

.. [#f1] For smaller datasets and simpler models, more sophisticated descent
Expand Down