From 14c028288b0e5d791394cf8f45fad0efa0a11216 Mon Sep 17 00:00:00 2001 From: Fengler Date: Mon, 12 Feb 2024 23:09:52 -0500 Subject: [PATCH 1/3] return upper boundary crossing as choice p in cpn training data --- ssms/dataset_generators/lan_mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ssms/dataset_generators/lan_mlp.py b/ssms/dataset_generators/lan_mlp.py index 44efbbc..da24283 100755 --- a/ssms/dataset_generators/lan_mlp.py +++ b/ssms/dataset_generators/lan_mlp.py @@ -270,9 +270,9 @@ def _mlp_get_processed_data_for_theta(self, random_seed_tuple): kde_data = self._make_kde_data(simulations=simulations, theta=theta) if len(simulations["metadata"]["possible_choices"]) == 2: - cpn_labels = np.expand_dims(simulations["choice_p"][0, 0], axis=0) + cpn_labels = np.expand_dims(simulations["choice_p"][0, 1], axis=0) cpn_no_omission_labels = np.expand_dims( - simulations["choice_p_no_omission"][0, 0], axis=0 + simulations["choice_p_no_omission"][0, 1], axis=0 ) else: cpn_labels = simulations["choice_p"] From 922235d5e865aa54e407ce6db9c7c93570fac562 Mon Sep 17 00:00:00 2001 From: Fengler Date: Mon, 12 Feb 2024 23:17:23 -0500 Subject: [PATCH 2/3] align _cpn_get_processed_data_for_theta with _mlp_get_processes_data_for_theta --- ssms/dataset_generators/lan_mlp.py | 73 +++++++----------------------- 1 file changed, 17 insertions(+), 56 deletions(-) diff --git a/ssms/dataset_generators/lan_mlp.py b/ssms/dataset_generators/lan_mlp.py index da24283..892794b 100755 --- a/ssms/dataset_generators/lan_mlp.py +++ b/ssms/dataset_generators/lan_mlp.py @@ -313,65 +313,26 @@ def _cpn_get_processed_data_for_theta(self, random_seed_tuple): theta=theta_dict, random_seed=random_seed_tuple[1] ) - # Compute the choice probabilities - out_dict = {} - out_dict["choice_p"] = np.zeros( - (1, len(simulations["metadata"]["possible_choices"])) - ) - out_dict["choice_p_no_omission"] = np.zeros( - (1, len(simulations["metadata"]["possible_choices"])) - ) - out_dict["omission_p"] = {} - out_dict["theta"] = np.expand_dims(theta, axis=0) - - for k, choice in enumerate(simulations["metadata"]["possible_choices"]): - out_dict["choice_p"][k] = np.array( - [ - (simulations["choices"] == choice).sum() - / simulations["choices"].flatten().shape[0] - ] - ) - out_dict["choice_p_no_omission"][k] = np.array( - [ - (simulations["choices"][simulations["rts"] != -999] == choice).sum() - / simulations["choices"].flatten().shape[0] - ] + if len(simulations["metadata"]["possible_choices"]) == 2: + cpn_labels = np.expand_dims(simulations["choice_p"][0, 1], axis=0) + cpn_no_omission_labels = np.expand_dims( + simulations["choice_p_no_omission"][0, 1], axis=0 ) - - out_dict["omission_p"] = np.expand_dims( - np.array( - [ - (simulations["rts"] == -999).sum() - / simulations["choices"].flatten().shape[0] - ] - ), - axis=0, - ) - out_dict["nogo_p"] = np.expand_dims( - np.array( - [ - (simulations["rts"] == -999) - | ( - simulations["choices"] - != simulations["metadata"]["possible_choices"][0] - ) - ] - ), - axis=0, - ) - out_dict["go_p"] = 1 - out_dict["nogo_p"] + else: + cpn_labels = simulations["choice_p"] + cpn_no_omission_labels = simulations["choice_p_no_omission"] return { - "cpn_data": np.expand_dims(theta, axis=0), - "cpn_labels": out_dict["choice_p"], - "cpn_no_omission_data": np.expand_dims(theta, axis=0), - "cpn_no_omission_labels": out_dict["choice_p_no_omission"], - "opn_data": np.expand_dims(theta, axis=0), - "opn_labels": out_dict["omission_p"], - "gonogo_data": np.expand_dims(theta, axis=0), - "gonogo_labels": out_dict["nogo_p"], - "theta": np.expand_dims(theta, axis=0), - } + "cpn_data": np.expand_dims(theta, axis=0), + "cpn_labels": cpn_labels, + "cpn_no_omission_data": np.expand_dims(theta, axis=0), + "cpn_no_omission_labels": cpn_no_omission_labels, + "opn_data": np.expand_dims(theta, axis=0), + "opn_labels": simulations["omission_p"], + "gonogo_data": np.expand_dims(theta, axis=0), + "gonogo_labels": simulations["nogo_p"], + "theta": np.expand_dims(theta, axis=0), + } def _get_rejected_parameter_setups(self, random_seed_tuple): np.random.seed(random_seed_tuple[0]) From 814c894947ae1241ed416335710a87fd3a2eea76 Mon Sep 17 00:00:00 2001 From: Fengler Date: Mon, 12 Feb 2024 23:17:53 -0500 Subject: [PATCH 3/3] black --- ssms/dataset_generators/lan_mlp.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ssms/dataset_generators/lan_mlp.py b/ssms/dataset_generators/lan_mlp.py index 892794b..8e4aa62 100755 --- a/ssms/dataset_generators/lan_mlp.py +++ b/ssms/dataset_generators/lan_mlp.py @@ -323,16 +323,16 @@ def _cpn_get_processed_data_for_theta(self, random_seed_tuple): cpn_no_omission_labels = simulations["choice_p_no_omission"] return { - "cpn_data": np.expand_dims(theta, axis=0), - "cpn_labels": cpn_labels, - "cpn_no_omission_data": np.expand_dims(theta, axis=0), - "cpn_no_omission_labels": cpn_no_omission_labels, - "opn_data": np.expand_dims(theta, axis=0), - "opn_labels": simulations["omission_p"], - "gonogo_data": np.expand_dims(theta, axis=0), - "gonogo_labels": simulations["nogo_p"], - "theta": np.expand_dims(theta, axis=0), - } + "cpn_data": np.expand_dims(theta, axis=0), + "cpn_labels": cpn_labels, + "cpn_no_omission_data": np.expand_dims(theta, axis=0), + "cpn_no_omission_labels": cpn_no_omission_labels, + "opn_data": np.expand_dims(theta, axis=0), + "opn_labels": simulations["omission_p"], + "gonogo_data": np.expand_dims(theta, axis=0), + "gonogo_labels": simulations["nogo_p"], + "theta": np.expand_dims(theta, axis=0), + } def _get_rejected_parameter_setups(self, random_seed_tuple): np.random.seed(random_seed_tuple[0])