diff --git a/questplus/qp.py b/questplus/qp.py index b05b957..6283739 100644 --- a/questplus/qp.py +++ b/questplus/qp.py @@ -62,9 +62,11 @@ def __init__(self, *, stim_selection_options Use this argument to specify options for the stimulus selection - method specified via `stim_selection_method`. Currently, this is - only used to specify the number of `n` stimuli that will yield the - `n` smallest entropies `stim_selection_method=min_n_entropy`. + method specified via `stim_selection_method`. Currently, this can + be used to specify the number of `n` stimuli that will yield the + `n` smallest entropies if `stim_selection_method=min_n_entropy`, + and`max_consecutive_reps`, the number of times the same stimulus + can be presented consecutively. param_estimation_method The method to use when deriving the final parameter estimate. @@ -84,7 +86,12 @@ def __init__(self, *, self.likelihoods = self._gen_likelihoods() self.stim_selection = stim_selection_method - self.stim_selection_options = stim_selection_options + + if (self.stim_selection == 'min_n_entropy' and + stim_selection_options is None): + self.stim_selection_options = dict(n=4, max_consecutive_reps=2) + else: + self.stim_selection_options = stim_selection_options self.param_estimation_method = param_estimation_method @@ -199,7 +206,6 @@ def next_stim(self) -> dict: The stimulus to present next. """ - stim_selection = self.stim_selection new_posterior = self.posterior * self.likelihoods # Probability. @@ -215,31 +221,39 @@ def next_stim(self) -> dict: # Expected entropies for all possible stimulus parameters. EH = (pk * H).sum(dim=list(self.outcome_domain.keys())) - if stim_selection == 'min_entropy': + if self.stim_selection == 'min_entropy': # Get coordinates of stimulus properties that minimize entropy. index = np.unravel_index(EH.argmin(), EH.shape) coords = EH[index].coords stim = {stim_property: stim_val.item() for stim_property, stim_val in coords.items()} self.entropy = EH.min().item() - # FIXME: currently disabled, need to adopt above method for - # finding correct coordinates! - # elif stim_selection == 'min_n_entropy': - # index = np.argsort(EH)[:4] - # while True: - # stim_candidates = self.stim_domain['intensity'][index.values] - # stim = np.random.choice(stim_candidates) - # - # if len(self.stim_history['intensity']) < 2: - # break - # elif (np.isclose(stim, self.stim_history['intensity'][-1]) and - # np.isclose(stim, self.stim_history['intensity'][-2])): - # print('\n ==> shuffling again... <==\n') - # continue - # else: - # break - # - # print(f'options: {self.stim_domain["intensity"][index.values]} -> {stim}') + elif self.stim_selection == 'min_n_entropy': + # Number of stimuli to include (the n stimuli that yield the lowest + # entropies) + n_stim = self.stim_selection_options['n'] + + indices = np.unravel_index(EH.argsort(), EH.shape)[0] + indices = indices[:n_stim] + + while True: + # Randomly pick one index and retrieve its coordinates + # (stimulus parameters). + candidate_index = np.random.choice(indices) + coords = EH[candidate_index].coords + stim = {stim_property: stim_val.item() + for stim_property, stim_val in coords.items()} + + max_reps = self.stim_selection_options['max_consecutive_reps'] + + if len(self.stim_history) < 2: + break + elif all([stim == prev_stim + for prev_stim in self.stim_history[-max_reps:]]): + # Shuffle again. + continue + else: + break else: raise ValueError('Unknown stim_selection supplied.')