Skip to content

Commit

Permalink
add inputs the classifer
Browse files Browse the repository at this point in the history
  • Loading branch information
memimo committed Jun 26, 2015
1 parent fc762a7 commit 427a4b3
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions code/logistic_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

import theano
import theano.tensor as T
from theano.gof import graph


class LogisticRegression(object):
Expand Down Expand Up @@ -110,6 +109,9 @@ def __init__(self, input, n_in, n_out):
# parameters of the model
self.params = [self.W, self.b]

# keep track of model input
self.input = input

def negative_log_likelihood(self, y):
"""Return the mean of the negative log-likelihood of the prediction
of this model under a given target distribution.
Expand Down Expand Up @@ -447,17 +449,11 @@ def predict():

# load the saved model
classifier = cPickle.load(open('best_model.pkl'))
y_pred = classifier.y_pred

# find the input to theano graph
inputs = graph.inputs([y_pred])
# select only x
inputs = [item for item in inputs if item.name == 'x']

# compile a predictor function
predict_model = theano.function(
inputs=inputs,
outputs=y_pred)
inputs=[classifier.input],
outputs=classifier.y_pred)

# We can test it on some examples from test test
dataset='mnist.pkl.gz'
Expand Down

0 comments on commit 427a4b3

Please sign in to comment.