From fe4051b9118498a0f81874b8c8f6dadbe6838d64 Mon Sep 17 00:00:00 2001 From: Marvin T Date: Wed, 8 Jul 2015 17:25:45 -0700 Subject: [PATCH 1/3] all queues set to updated on load recursive on_load resetting More informative error messages --- pyoperant/queues.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/pyoperant/queues.py b/pyoperant/queues.py index ff8ca593..e397b883 100644 --- a/pyoperant/queues.py +++ b/pyoperant/queues.py @@ -62,6 +62,7 @@ class AdaptiveBase(object): """ def __init__(self, **kwargs): self.updated = True # for first trial, no update needed + self.update_error_str = "queue hasn't been updated since last trial" def __iter__(self): return self @@ -73,12 +74,20 @@ def update(self, correct, no_resp): def next(self): if not self.updated: #hasn't been updated since last trial - raise Exception("queue hasn't been updated since last trial") + raise Exception(self.update_error_str) self.updated = False def no_response(self): pass + def on_load(self): + try: + super(AdaptiveBase, self).on_load() + except AttributeError: + pass + self.updated = True + self.no_response() + class PersistentBase(object): """ A mixin that allows for the creation of an obj through a load command that @@ -95,10 +104,17 @@ def load(cls, filename, *args, **kwargs): try: with open(filename, 'rb') as handle: ab = pickle.load(handle) + ab.on_load() return ab except IOError: return cls(*args, filename=filename, **kwargs) + def on_load(self): + try: + super(PersistentBase, self).on_load() + except AttributeError: + pass + def save(self): with open(self.filename, 'wb') as handle: pickle.dump(self, handle) @@ -186,6 +202,7 @@ def __init__(self, stims, rate_constant=.05, **kwargs): self.low_idx = 0 self.high_idx = len(self.stims) - 1 self.trial = {} + self.update_error_str = "double staircase queue %s hasn't been updated since last trial" % (self.stims[0]) def update(self, correct, no_resp): super(DoubleStaircase, self).update(correct, no_resp) @@ -233,6 +250,7 @@ def __init__(self, stims, rate_constant=.05, probe_rate=.1, **kwargs): self.stims = stims self.probe_rate = probe_rate self.last_probe = False + self.update_error_str = "reinforced double staircase queue %s hasn't been updated since last trial" % (self.stims[0]) def update(self, correct, no_resp): super(DoubleStaircaseReinforced, self).update(correct, no_resp) @@ -265,6 +283,10 @@ def no_response(self): super(DoubleStaircaseReinforced, self).no_response() self.last_probe = False + def on_load(self): + super(DoubleStaircaseReinforced, self).on_load() + self.dblstaircase.on_load() + class MixedAdaptiveQueue(PersistentBase, AdaptiveBase): """ @@ -285,6 +307,7 @@ def __init__(self, sub_queues, probabilities=None, **kwargs): self.sub_queues = sub_queues self.probabilities = probabilities self.sub_queue_idx = -1 + self.update_error_str = "MixedAdaptiveQueue hasn't been updated since last trial" self.save() def update(self, correct, no_resp): @@ -305,5 +328,14 @@ def next(self): #TODO: support variable probabilities for each sub_queue raise NotImplementedError + def on_load(self): + super(MixedAdaptiveQueue, self).on_load() + for sub_queue in self.sub_queues: + try: + sub_queue.on_load() + except AttributeError: + pass + + From 3d190ab2fa58630b080ae3c4c8da596f924a712d Mon Sep 17 00:00:00 2001 From: Marvin T Date: Fri, 10 Jul 2015 13:39:06 -0700 Subject: [PATCH 2/3] fix sampling logic sample_log option for double_staircase --- pyoperant/queues.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pyoperant/queues.py b/pyoperant/queues.py index e397b883..08d181ed 100644 --- a/pyoperant/queues.py +++ b/pyoperant/queues.py @@ -244,11 +244,12 @@ class DoubleStaircaseReinforced(AdaptiveBase): rate_constant: the step size is the rate_constant*(high_idx-low_idx) probe_rate: proportion of trials that are between [0, low_idx] or [high_idx, length(stims)] """ - def __init__(self, stims, rate_constant=.05, probe_rate=.1, **kwargs): + def __init__(self, stims, rate_constant=.05, probe_rate=.1, sample_log=False, **kwargs): super(DoubleStaircaseReinforced, self).__init__(**kwargs) self.dblstaircase = DoubleStaircase(stims, rate_constant) self.stims = stims self.probe_rate = probe_rate + self.sample_log = sample_log self.last_probe = False self.update_error_str = "reinforced double staircase queue %s hasn't been updated since last trial" % (self.stims[0]) @@ -273,10 +274,16 @@ def next(self): else: self.last_probe = False if random.random() < .5: # probe left - val = int((1 - rand_from_log_shape_dist()) * self.dblstaircase.low_idx) + if self.sample_log: + val = int((1 - rand_from_log_shape_dist()) * self.dblstaircase.low_idx) + else: + val = random.randrange(self.dblstaircase.low_idx) return {'class': 'L', 'stim_name': self.stims[val]} else: # probe right - val = self.dblstaircase.high_idx - int(rand_from_log_shape_dist() * (len(self.stims) - 1 - self.dblstaircase.high_idx)) + if self.sample_log: + val = self.dblstaircase.high_idx + int(rand_from_log_shape_dist() * (len(self.stims) - self.dblstaircase.high_idx)) + else: + val = self.dblstaircase.high_idx + random.randrange(len(self.stims) - self.dblstaircase.high_idx) return {'class': 'R', 'stim_name': self.stims[val]} def no_response(self): From f1c2fc8795f9dacc4e4f856a9491cec5ba86298a Mon Sep 17 00:00:00 2001 From: Marvin T Date: Mon, 27 Jul 2015 12:19:02 -0700 Subject: [PATCH 3/3] no_response_correction_trials parameter --- pyoperant/behavior/two_alt_choice.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyoperant/behavior/two_alt_choice.py b/pyoperant/behavior/two_alt_choice.py index 6aca676b..457c6c38 100755 --- a/pyoperant/behavior/two_alt_choice.py +++ b/pyoperant/behavior/two_alt_choice.py @@ -104,6 +104,9 @@ def __init__(self, *args, **kwargs): if 'session_schedule' not in self.parameters: self.parameters['session_schedule'] = self.parameters['light_schedule'] + if 'no_response_correction_trials' not in self.parameters: + self.parameters['no_response_correction_trials'] = False + def make_data_csv(self): """ Create the csv file to save trial data @@ -350,7 +353,7 @@ def trial_post(self): self.do_correction = False elif self.this_trial.response == 'none': if self.this_trial.type_ == 'normal': - self.do_correction = False + self.do_correction = self.parameters['no_response_correction_trials'] else: self.do_correction = False else: