modernized sample_prop stuff
Ian Goodfellow committed Nov 2, 2013
1 parent 4584c7f commit 004236a
Showing 2 changed files with 93 additions and 13 deletions.
104 changes: 92 additions & 12 deletions sample_prop/basic.py
@@ -6,6 +6,8 @@
import theano.tensor as T
from pylearn2.costs.cost import Cost
from theano.printing import Print
from pylearn2.space import CompositeSpace
from collections import OrderedDict

class SimpleModel(Model):

@@ -59,6 +61,11 @@ def prob_of(Y,Z):
class SamplingCost(Cost):
supervised = True

def get_data_specs(self, model):
space = CompositeSpace([model.get_input_space(), model.get_output_space()])
sources = (model.get_input_source(), model.get_target_source())
return (space, sources)
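# These data_specs tell the training algorithm that the cost consumes an
# (inputs, targets) pair, drawn from the model's input and output spaces.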

def __init__(self, weight_decay_1=0., weight_decay_2=0.):
self.__dict__.update(locals())
del self.self
@@ -218,7 +225,10 @@ def get_monitoring_channels(self, model, X, Y):

class SimpleModel2(Model):

def __init__(self, nvis, num_hid, num_hid_2, num_class):
def __init__(self, nvis, num_hid, num_hid_2, num_class,
h0_max_col_norm=None,
h1_max_col_norm=None,
y_max_col_norm=None):
self.__dict__.update(locals())
del self.self

@@ -227,15 +237,64 @@ def __init__(self, nvis, num_hid, num_hid_2, num_class):
self.theano_rng = MRG_RandomStreams(2012 + 10 + 16)
rng = np.random.RandomState([16,10,2012])

self.W = sharedX(rng.uniform(-.05,.05,(nvis, num_hid)))
self.W = sharedX(rng.uniform(-.05,.05,(nvis, num_hid)), 'h0_W')
self.hb = sharedX(np.zeros((num_hid,)) - 1.)
self.V = sharedX(rng.uniform(-.05,.05,(num_hid, num_hid_2)))
self.V = sharedX(rng.uniform(-.05,.05,(num_hid, num_hid_2)), 'h1_W')
self.gb = sharedX(np.zeros((num_hid_2,)) - 1.)
self.V2 = sharedX(rng.uniform(-.05,.05,(num_hid_2, num_class)))
self.V2 = sharedX(rng.uniform(-.05,.05,(num_hid_2, num_class)), 'y_W')
self.cb = sharedX(np.zeros((num_class,)))

self._params = [self.W, self.hb, self.V, self.V2, self.gb, self.cb ]

def censor_updates(self, updates):
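# Max column norm constraint: after an update is applied to a weight matrix,
# rescale it so that no column's L2 norm exceeds the corresponding
# max_col_norm (if one was given); the 1e-7 term avoids division by zero.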

def constrain(W, max_col_norm):
if max_col_norm is not None:
if W in updates:
updated_W = updates[W]
col_norms = T.sqrt(T.sum(T.sqr(updated_W), axis=0))
desired_norms = T.clip(col_norms, 0, max_col_norm)
updates[W] = updated_W * (desired_norms / (1e-7 + col_norms))

constrain(self.W, self.h0_max_col_norm)
constrain(self.V, self.h1_max_col_norm)
constrain(self.V2, self.y_max_col_norm)

def get_monitoring_channels(self, data, ** kwargs):

rval = OrderedDict()

def add_col_norms(name, mat):
norms = T.sqrt(T.sqr(mat).sum(axis=0))
rval[name+"_col_norm_max"] = norms.max()
rval[name+"_col_norm_mean"] = norms.mean()
rval[name+"_col_norm_min"] = norms.min()

add_col_norms('y', self.V2)
add_col_norms('h1', self.V)
add_col_norms('h0', self.W)

X, Y = data

eH, H, eG, G, Z = self.emit(X)

def add_certainty(name, mat):
uncertainty = 1. - T.maximum(mat, 1-mat)
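# uncertainty is 0 where an activation saturates at 0 or 1 and peaks at 0.5
# for an activation of 0.5; the channels below reduce it first over examples
# (the *_x part of the name) and then over units.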
rval[name + '.uncertainty.min_x.min'] = uncertainty.min(axis=0).min()
rval[name + '.uncertainty.min_x.mean'] = uncertainty.min(axis=0).mean()
rval[name + '.uncertainty.min_x.max'] = uncertainty.min(axis=0).max()
rval[name + '.uncertainty.mean_x.min'] = uncertainty.mean(axis=0).min()
rval[name + '.uncertainty.mean_x.mean'] = uncertainty.mean(axis=0).mean()
rval[name + '.uncertainty.mean_x.max'] = uncertainty.mean(axis=0).max()
rval[name + '.uncertainty.max_x.min'] = uncertainty.max(axis=0).min()
rval[name + '.uncertainty.max_x.mean'] = uncertainty.max(axis=0).mean()
rval[name + '.uncertainty.max_x.max'] = uncertainty.max(axis=0).max()

add_certainty('h0', eH)
add_certainty('h1', eG)

return rval

def get_weights(self):
return self.W.get_value()

@@ -281,12 +340,28 @@ def add_polyak_channels(self, params, d):

for n in d:
ds = d[n]
name = n+'_polyak_acc'
self.monitor.add_channel(name, (X, Y), polyak_acc, ds)
name = n+'_y_misclass_polyak'
self.monitor.add_channel(name, (X, Y), 1 - polyak_acc, ds)

def get_monitoring_data_specs(self):
"""
Return the (space, source) data_specs for self.get_monitoring_channels.
In this case, we want the inputs and targets.
"""
space = CompositeSpace((self.get_input_space(),
self.get_output_space()))
source = (self.get_input_source(), self.get_target_source())
return (space, source)
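# With these specs, the monitor passes `data` to get_monitoring_channels as
# an (inputs, targets) pair, which is what the `X, Y = data` unpacking above
# expects.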

class SamplingCost3(Cost):
supervised = True

def get_data_specs(self, model):
space = CompositeSpace([model.get_input_space(), model.get_output_space()])
sources = (model.get_input_source(), model.get_target_source())
return (space, sources)

def __init__(self, weight_decay_1=0., weight_decay_2=0.,
weight_decay_3=0.):
self.__dict__.update(locals())
@@ -295,14 +370,16 @@ def __init__(self, weight_decay_1=0., weight_decay_2=0.,
def batch_loss(self, Y, Z):
return - log_prob_of(Y, Z)

def __call__(self, model, X, Y):
def expr(self, model, data, **kwargs):
X, Y = data
assert type(model) is SimpleModel2 # yes, I did not mean to use isinstance
eH, H, eG, G, Z = model.emit(X)
return self.batch_loss(Y, Z).mean()
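# expr() is the scalar training objective: the mean over the batch of
# -log_prob_of(Y, Z), i.e. the average negative log-probability that the
# model.emit() forward pass assigns to the targets.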

def get_gradients(self, model, X, Y):
def get_gradients(self, model, data, **kwargs):

obj = self(model, X, Y)
X, Y = data
obj = self.expr(model, data, **kwargs)
exp_H, H, exp_G, G, Z = model.emit(X)
batch_loss = self.batch_loss(Y, Z)
mn = sharedX(0.)
@@ -312,7 +389,7 @@ def get_gradients(self, model, X, Y):
# though. so we use earlier batch's samples
batch_loss = batch_loss - mn
alpha = .01
updates = { mn : alpha * batch_loss.mean() + (1.-alpha)*mn }
updates = OrderedDict([(mn, alpha * batch_loss.mean() + (1.-alpha)*mn)])
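# mn is an exponential moving average of the batch loss, updated with rate
# alpha; subtracting it from batch_loss centers the loss, and since mn is
# updated only after being subtracted, it reflects earlier batches' samples.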

rval = {}

Expand All @@ -338,7 +415,10 @@ def get_gradients(self, model, X, Y):

return rval, updates

def get_monitoring_channels(self, model, X, Y):
def get_monitoring_channels(self, model, data, **kwargs):
X, Y = data
_, __, ___, ____, Z = model.emit(X)

return { 'acc' : T.cast(T.eq(T.argmax(Z,axis=1),T.argmax(Y,axis=1)).mean(), 'float32') }
return OrderedDict([
('y_misclass', T.cast(T.neq(T.argmax(Z,axis=1),T.argmax(Y,axis=1)).mean(), 'float32'))])

2 changes: 1 addition & 1 deletion sample_prop/sgd_mnist_9.yaml
@@ -1,4 +1,4 @@
!obj:pylearn2.scripts.train.Train {
!obj:pylearn2.train.Train {
dataset: &train !obj:pylearn2.datasets.mnist.MNIST {
which_set: "train",
binarize: 1,
