Skip to content

Commit

Permalink
Merge pull request #96 from MarvinT/double_staircase
Browse files Browse the repository at this point in the history
Double staircase
  • Loading branch information
MarvinT committed Jul 28, 2015
2 parents 09b1ae7 + f1c2fc8 commit a46e96e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
5 changes: 4 additions & 1 deletion pyoperant/behavior/two_alt_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def __init__(self, *args, **kwargs):
if 'session_schedule' not in self.parameters:
self.parameters['session_schedule'] = self.parameters['light_schedule']

if 'no_response_correction_trials' not in self.parameters:
self.parameters['no_response_correction_trials'] = False

def make_data_csv(self):
""" Create the csv file to save trial data
Expand Down Expand Up @@ -352,7 +355,7 @@ def trial_post(self):
self.do_correction = False
elif self.this_trial.response == 'none':
if self.this_trial.type_ == 'normal':
self.do_correction = False
self.do_correction = self.parameters['no_response_correction_trials']
else:
self.do_correction = False
else:
Expand Down
47 changes: 43 additions & 4 deletions pyoperant/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class AdaptiveBase(object):
"""
def __init__(self, **kwargs):
self.updated = True # for first trial, no update needed
self.update_error_str = "queue hasn't been updated since last trial"

def __iter__(self):
return self
Expand All @@ -73,12 +74,20 @@ def update(self, correct, no_resp):

def next(self):
if not self.updated: #hasn't been updated since last trial
raise Exception("queue hasn't been updated since last trial")
raise Exception(self.update_error_str)
self.updated = False

def no_response(self):
pass

def on_load(self):
try:
super(AdaptiveBase, self).on_load()
except AttributeError:
pass
self.updated = True
self.no_response()

class PersistentBase(object):
"""
A mixin that allows for the creation of an obj through a load command that
Expand All @@ -95,10 +104,17 @@ def load(cls, filename, *args, **kwargs):
try:
with open(filename, 'rb') as handle:
ab = pickle.load(handle)
ab.on_load()
return ab
except IOError:
return cls(*args, filename=filename, **kwargs)

def on_load(self):
try:
super(PersistentBase, self).on_load()
except AttributeError:
pass

def save(self):
with open(self.filename, 'wb') as handle:
pickle.dump(self, handle)
Expand Down Expand Up @@ -186,6 +202,7 @@ def __init__(self, stims, rate_constant=.05, **kwargs):
self.low_idx = 0
self.high_idx = len(self.stims) - 1
self.trial = {}
self.update_error_str = "double staircase queue %s hasn't been updated since last trial" % (self.stims[0])

def update(self, correct, no_resp):
super(DoubleStaircase, self).update(correct, no_resp)
Expand Down Expand Up @@ -227,12 +244,14 @@ class DoubleStaircaseReinforced(AdaptiveBase):
rate_constant: the step size is the rate_constant*(high_idx-low_idx)
probe_rate: proportion of trials that are between [0, low_idx] or [high_idx, length(stims)]
"""
def __init__(self, stims, rate_constant=.05, probe_rate=.1, **kwargs):
def __init__(self, stims, rate_constant=.05, probe_rate=.1, sample_log=False, **kwargs):
super(DoubleStaircaseReinforced, self).__init__(**kwargs)
self.dblstaircase = DoubleStaircase(stims, rate_constant)
self.stims = stims
self.probe_rate = probe_rate
self.sample_log = sample_log
self.last_probe = False
self.update_error_str = "reinforced double staircase queue %s hasn't been updated since last trial" % (self.stims[0])

def update(self, correct, no_resp):
super(DoubleStaircaseReinforced, self).update(correct, no_resp)
Expand All @@ -255,16 +274,26 @@ def next(self):
else:
self.last_probe = False
if random.random() < .5: # probe left
val = int((1 - rand_from_log_shape_dist()) * self.dblstaircase.low_idx)
if self.sample_log:
val = int((1 - rand_from_log_shape_dist()) * self.dblstaircase.low_idx)
else:
val = random.randrange(self.dblstaircase.low_idx)
return {'class': 'L', 'stim_name': self.stims[val]}
else: # probe right
val = self.dblstaircase.high_idx - int(rand_from_log_shape_dist() * (len(self.stims) - 1 - self.dblstaircase.high_idx))
if self.sample_log:
val = self.dblstaircase.high_idx + int(rand_from_log_shape_dist() * (len(self.stims) - self.dblstaircase.high_idx))
else:
val = self.dblstaircase.high_idx + random.randrange(len(self.stims) - self.dblstaircase.high_idx)
return {'class': 'R', 'stim_name': self.stims[val]}

def no_response(self):
super(DoubleStaircaseReinforced, self).no_response()
self.last_probe = False

def on_load(self):
super(DoubleStaircaseReinforced, self).on_load()
self.dblstaircase.on_load()


class MixedAdaptiveQueue(PersistentBase, AdaptiveBase):
"""
Expand All @@ -285,6 +314,7 @@ def __init__(self, sub_queues, probabilities=None, **kwargs):
self.sub_queues = sub_queues
self.probabilities = probabilities
self.sub_queue_idx = -1
self.update_error_str = "MixedAdaptiveQueue hasn't been updated since last trial"
self.save()

def update(self, correct, no_resp):
Expand All @@ -305,5 +335,14 @@ def next(self):
#TODO: support variable probabilities for each sub_queue
raise NotImplementedError

def on_load(self):
super(MixedAdaptiveQueue, self).on_load()
for sub_queue in self.sub_queues:
try:
sub_queue.on_load()
except AttributeError:
pass




0 comments on commit a46e96e

Please sign in to comment.