Skip to content

Commit

Permalink
new lba models -- porting from the other PR (#58)
Browse files Browse the repository at this point in the history
* new lba models -- porting from the other PR

* v3

* debug 1

* run black

* change default of displace t (#59)

* incorporate displace t correctly into data_generators (#60)

* bump version (#61)

* add shrink spot simple with restricted range on r parameters (#62)

* add shrink spot simple with restricted range on r parameters

* bump version

* add model to theta processor

* fix behavior where random_state=None in call to simulator() reuses random_states if called in quick succession (#64)

* model configs now include choice options explicitly (#65)

* resolved PR for merging

---------

Co-authored-by: Alexander Fengler <[email protected]>
  • Loading branch information
krishnbera and AlexanderFengler authored Jan 7, 2025
1 parent e63e7ef commit 826234d
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 41 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ requires = ["setuptools", "wheel", "Cython>=0.29.23", "numpy >= 1.20"]

[project]
name= "ssm-simulators"
version= "0.8.3"
version= "0.9.0"

authors= [{name = "Alexander Fenger", email = "[email protected]"}]
description= "SSMS is a package collecting simulators and training data generators for a bunch of generative models of interest in the cognitive science / neuroscience and approximate bayesian computation communities"
readme = "README.md"
Expand Down
115 changes: 100 additions & 15 deletions src/cssm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3859,7 +3859,7 @@ def lba_vanilla(np.ndarray[float, ndim = 2] v,
np.ndarray[float, ndim = 2] z,
np.ndarray[float, ndim = 1] deadline,
np.ndarray[float, ndim = 2] sd, # noise sigma
np.ndarray[float, ndim = 1] ndt, # non-decision time
np.ndarray[float, ndim = 1] t, # non-decision time
int nact = 3,
int n_samples = 2000,
int n_trials = 1,
Expand All @@ -3881,7 +3881,7 @@ def lba_vanilla(np.ndarray[float, ndim = 2] v,
Maximum allowed decision time.
sd : np.ndarray[float, ndim=1]
Standard deviation of the drift rate distribution.
ndt : np.ndarray[float, ndim=1]
t : np.ndarray[float, ndim=1]
Non-decision time.
nact : int, optional
Number of accumulators (default is 3).
Expand All @@ -3907,7 +3907,7 @@ def lba_vanilla(np.ndarray[float, ndim = 2] v,
cdef float[:, :] v_view = v
cdef float[:, :] a_view = a
cdef float[:, :] z_view = z
cdef float[:] ndt_view = ndt
cdef float[:] t_view = t
cdef float[:] deadline_view = deadline
cdef float[:, :] sd_view = sd

Expand All @@ -3929,7 +3929,7 @@ def lba_vanilla(np.ndarray[float, ndim = 2] v,
x_t = ([a_view[k]]*nact - zs)/vs

choices_view[n, k, 0] = np.argmin(x_t) # store choices for sample n
rts_view[n, k, 0] = np.min(x_t) + ndt_view[k] # store reaction time for sample n
rts_view[n, k, 0] = np.min(x_t) + t_view[k] # store reaction time for sample n

# If the rt exceeds the deadline, set rt to -999
if rts_view[n, k, 0] >= deadline_view[k]:
Expand All @@ -3945,7 +3945,7 @@ def lba_vanilla(np.ndarray[float, ndim = 2] v,
'z': z,
'deadline': deadline,
'sd': sd,
'ndt': ndt,
't': t,
'n_samples': n_samples,
'simulator' : 'lba_vanilla',
'possible_choices': list(np.arange(0, nact, 1)),
Expand All @@ -3961,7 +3961,7 @@ def lba_angle(np.ndarray[float, ndim = 2] v,
np.ndarray[float, ndim = 2] theta,
np.ndarray[float, ndim = 1] deadline,
np.ndarray[float, ndim = 2] sd, # noise sigma
np.ndarray[float, ndim = 1] ndt, # non-decision time
np.ndarray[float, ndim = 1] t, # non-decision time
int nact = 3,
int n_samples = 2000,
int n_trials = 1,
Expand All @@ -3985,7 +3985,7 @@ def lba_angle(np.ndarray[float, ndim = 2] v,
Maximum allowed decision time.
sd : np.ndarray[float, ndim=1]
Standard deviation of the drift rate distribution.
ndt : np.ndarray[float, ndim=1]
t : np.ndarray[float, ndim=1]
Non-decision time.
nact : int, optional
Number of accumulators (default is 3).
Expand All @@ -4010,7 +4010,7 @@ def lba_angle(np.ndarray[float, ndim = 2] v,
cdef float[:, :] a_view = a
cdef float[:, :] z_view = z
cdef float[:, :] theta_view = theta
cdef float[:] ndt_view = ndt
cdef float[:] t_view = t

cdef float[:] deadline_view = deadline
cdef float[:, :] sd_view = sd
Expand All @@ -4031,7 +4031,7 @@ def lba_angle(np.ndarray[float, ndim = 2] v,
x_t = ([a_view[k]]*nact - zs)/(vs + np.tan(theta_view[k, 0]))

choices_view[n, k, 0] = np.argmin(x_t) # store choices for sample n
rts_view[n, k, 0] = np.min(x_t) + ndt_view[k] # store reaction time for sample n
rts_view[n, k, 0] = np.min(x_t) + t_view[k] # store reaction time for sample n

# If the rt exceeds the deadline, set rt to -999
if rts_view[n, k, 0] >= deadline_view[k]:
Expand All @@ -4050,21 +4050,105 @@ def lba_angle(np.ndarray[float, ndim = 2] v,
'theta': theta,
'deadline': deadline,
'sd': sd,
't': t,
'n_samples': n_samples,
'simulator' : 'lba_angle',
'possible_choices': list(np.arange(0, nact, 1)),
'max_t': max_t,
}}


# Simulate (rt, choice) tuples from LBA piece-wise model -----------------------------
def rlwm_lba_pw_v1(np.ndarray[float, ndim = 2] v_RL,
np.ndarray[float, ndim = 2] v_WM,
np.ndarray[float, ndim = 2] a,
np.ndarray[float, ndim = 2] z,
np.ndarray[float, ndim = 2] t_WM,
np.ndarray[float, ndim = 1] deadline,
np.ndarray[float, ndim = 2] sd, # std dev
np.ndarray[float, ndim = 1] t, # ndt is supposed to be 0 by default because of parameter identifiability issues
int nact = 3,
int n_samples = 2000,
int n_trials = 1,
float max_t = 20,
**kwargs
):

# Param views
cdef float[:, :] v_RL_view = v_RL
cdef float[:, :] v_WM_view = v_WM
cdef float[:, :] a_view = a
cdef float[:, :] z_view = z
cdef float[:, :] t_WM_view = t_WM
cdef float[:] t_view = t

cdef float[:] deadline_view = deadline
cdef float[:, :] sd_view = sd

cdef np.ndarray[float, ndim = 1] zs
cdef np.ndarray[double, ndim = 2] x_t_RL
cdef np.ndarray[double, ndim = 2] x_t_WM
cdef np.ndarray[double, ndim = 1] vs_RL
cdef np.ndarray[double, ndim = 1] vs_WM

rts = np.zeros((n_samples, n_trials, 1), dtype = DTYPE)
cdef float[:, :, :] rts_view = rts

choices = np.zeros((n_samples, n_trials, 1), dtype = np.intc)
cdef int[:, :, :] choices_view = choices

cdef Py_ssize_t n, k, i

for k in range(n_trials):

for n in range(n_samples):
zs = np.random.uniform(0, z_view[k], nact).astype(DTYPE)

vs_RL = np.abs(np.random.normal(v_RL_view[k], sd_view[k])) # np.abs() to avoid negative vs
vs_WM = np.abs(np.random.normal(v_WM_view[k], sd_view[k])) # np.abs() to avoid negative vs

x_t_RL = ([a_view[k]]*nact - zs)/vs_RL
# x_t_WM = ([a_view[k]]*nact - zs)/vs_WM

if np.min(x_t_RL) < t_WM_view[k]:
x_t = x_t_RL
else:
x_t = t_WM_view[k] + ( [a_view[k]]*nact - zs - ([t_WM_view[k]]*nact)*vs_RL ) / ( vs_RL + vs_WM )

choices_view[n, k, 0] = np.argmin(x_t) # store choices for sample n
rts_view[n, k, 0] = np.min(x_t) + t_view[k] # store reaction time for sample n

# If the rt exceeds the deadline, set rt to -999
if rts_view[n, k, 0] >= deadline_view[k]:
rts_view[n, k, 0] = -999


v_dict = {}
for i in range(nact):
v_dict['v_RL_' + str(i)] = v_RL[:, i]
v_dict['v_WM_' + str(i)] = v_WM[:, i]

return {'rts': rts, 'choices': choices, 'metadata': {**v_dict,
'a': a,
'z': z,
't_WM': t_WM,
't': t,
'deadline': deadline,
'sd': sd,
'n_samples': n_samples,
'simulator' : 'rlwm_lba_pw_v1',
'possible_choices': list(np.arange(0, nact, 1)),
'max_t': max_t,
}}

# Simulate (rt, choice) tuples from: RLWM LBA Race Model without ndt -----------------------------
def rlwm_lba_race(np.ndarray[float, ndim = 2] v_RL, # RL drift parameters (np.array expect: one column of floats)
np.ndarray[float, ndim = 2] v_WM, # WM drift parameters (np.array expect: one column of floats)
np.ndarray[float, ndim = 2] a, # criterion height
np.ndarray[float, ndim = 2] z, # initial bias parameters (np.array expect: one column of floats)
np.ndarray[float, ndim = 1] deadline,
np.ndarray[float, ndim = 2] sd, # noise sigma
np.ndarray[float, ndim = 1] ndt, # non-decision time
np.ndarray[float, ndim = 1] t, # non-decision time
int nact = 3,
int n_samples = 2000,
int n_trials = 1,
Expand All @@ -4088,7 +4172,7 @@ def rlwm_lba_race(np.ndarray[float, ndim = 2] v_RL, # RL drift parameters (np.ar
Maximum allowed decision time.
sd : np.ndarray[float, ndim=1]
Standard deviation of the drift rate distribution.
ndt : np.ndarray[float, ndim=1]
t : np.ndarray[float, ndim=1]
Non-decision time.
nact : int, optional
Number of accumulators (default is 3).
Expand All @@ -4113,7 +4197,7 @@ def rlwm_lba_race(np.ndarray[float, ndim = 2] v_RL, # RL drift parameters (np.ar
cdef float[:, :] v_WM_view = v_WM
cdef float[:, :] a_view = a
cdef float[:, :] z_view = z
cdef float[:] ndt_view = ndt
cdef float[:] t_view = t

cdef float[:] deadline_view = deadline
cdef float[:, :] sd_view = sd
Expand Down Expand Up @@ -4143,10 +4227,10 @@ def rlwm_lba_race(np.ndarray[float, ndim = 2] v_RL, # RL drift parameters (np.ar
x_t_WM = ([a_view[k]]*nact - zs)/vs_WM

if np.min(x_t_RL) <= np.min(x_t_WM):
rts_view[n, k, 0] = np.min(x_t_RL) + ndt_view[k] # store reaction time for sample n
rts_view[n, k, 0] = np.min(x_t_RL) + t_view[k] # store reaction time for sample n
choices_view[n, k, 0] = np.argmin(x_t_RL) # store choices for sample n
else:
rts_view[n, k, 0] = np.min(x_t_WM) + ndt_view[k] # store reaction time for sample n
rts_view[n, k, 0] = np.min(x_t_WM) + t_view[k] # store reaction time for sample n
choices_view[n, k, 0] = np.argmin(x_t_WM) # store choices for sample n

# If the rt exceeds the deadline, set rt to -999
Expand All @@ -4162,9 +4246,10 @@ def rlwm_lba_race(np.ndarray[float, ndim = 2] v_RL, # RL drift parameters (np.ar
return {'rts': rts, 'choices': choices, 'metadata': {**v_dict,
'a': a,
'z': z,
't': 0,
'deadline': deadline,
'sd': sd,
'ndt': ndt,
't': t,
'n_samples': n_samples,
'simulator' : 'rlwm_lba_race',
'possible_choices': list(np.arange(0, nact, 1)),
Expand Down
3 changes: 2 additions & 1 deletion ssms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import config
from . import support_utils

__version__ = "0.8.3" # importlib.metadata.version(__package__ or __name__)

__version__ = "0.9.0" # importlib.metadata.version(__package__ or __name__)

__all__ = ["basic_simulators", "dataset_generators", "config", "support_utils"]
23 changes: 18 additions & 5 deletions ssms/basic_simulators/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,16 +539,29 @@ def check_if_z_gt_a(z: np.ndarray, a: np.ndarray) -> None:
if np.any(z >= a):
raise ValueError("Starting point z >= a for at least one trial")

if model in ["lba_3_v1", "lba_angle_3_v1", "rlwm_lba_race_v1"]:
if model in ["lba_3_v1", "lba_angle_3_v1"]:
if model in [
"lba_3_vs_constraint",
"lba_angle_3_vs_constraint",
"lba_angle_3",
"dev_rlwm_lba_race_v1",
"dev_rlwm_lba_race_v2",
"dev_rlwm_lba_pw_v1",
]:
if model in ["lba_3_vs_constraint", "lba_angle_3_vs_constraint"]:
check_lba_drifts_sum(theta["v"])
check_if_z_gt_a(theta["z"], theta["a"])
elif model in ["rlwm_lba_race_v1"]:
elif model in ["dev_rlwm_lba_race_v1"]:
check_lba_drifts_sum(theta["v_RL"])
check_lba_drifts_sum(theta["v_WM"])
check_if_z_gt_a(theta["z"], theta["a"])
elif model in ["lba3", "lba2"]:
check_if_z_gt_a(theta["z"], theta["a"])
elif model in [
"lba3",
"lba2",
"lba_angle_3",
"dev_rlwm_lba_pw_v1",
"dev_rlwm_lba_race_v2",
]:
check_if_z_gt_a(theta["z"], theta["a"])


def make_noise_vec(
Expand Down
28 changes: 20 additions & 8 deletions ssms/basic_simulators/theta_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def process_theta(
theta["v"] = np.column_stack([theta["v0"], theta["v1"]])
theta["z"] = np.expand_dims(theta["A"], axis=1)
theta["a"] = np.expand_dims(theta["b"], axis=1)
theta["ndt"] = np.zeros(n_trials).astype(np.float32)
theta["t"] = np.zeros(n_trials).astype(np.float32)

del theta["A"]
del theta["b"]
Expand All @@ -147,25 +147,25 @@ def process_theta(

theta["z"] = np.expand_dims(theta["A"], axis=1)
theta["a"] = np.expand_dims(theta["b"], axis=1)
theta["ndt"] = np.zeros(n_trials).astype(np.float32)
theta["t"] = np.zeros(n_trials).astype(np.float32)

del theta["A"]
del theta["b"]

if model == "lba_3_v1":
if model == "lba_3_vs_constraint":
theta["v"] = np.column_stack([theta["v0"], theta["v1"], theta["v2"]])
theta["a"] = np.expand_dims(theta["a"], axis=1)
theta["z"] = np.expand_dims(theta["z"], axis=1)
theta["ndt"] = np.zeros(n_trials).astype(np.float32)
theta["t"] = np.zeros(n_trials).astype(np.float32)

if model == "lba_angle_3_v1":
if model in ["lba_angle_3_vs_constraint", "lba_angle_3"]:
theta["v"] = np.column_stack([theta["v0"], theta["v1"], theta["v2"]])
theta["a"] = np.expand_dims(theta["a"], axis=1)
theta["z"] = np.expand_dims(theta["z"], axis=1)
theta["theta"] = np.expand_dims(theta["theta"], axis=1)
theta["ndt"] = np.zeros(n_trials).astype(np.float32)
theta["t"] = np.zeros(n_trials).astype(np.float32)

if model == "rlwm_lba_race_v1":
if model in ["dev_rlwm_lba_race_v1", "dev_rlwm_lba_race_v2"]:
theta["v_RL"] = np.column_stack(
[theta["v_RL_0"], theta["v_RL_1"], theta["v_RL_2"]]
)
Expand All @@ -174,7 +174,19 @@ def process_theta(
)
theta["a"] = np.expand_dims(theta["a"], axis=1)
theta["z"] = np.expand_dims(theta["z"], axis=1)
theta["ndt"] = np.zeros(n_trials).astype(np.float32)
theta["t"] = np.zeros(n_trials).astype(np.float32)

if model == "dev_rlwm_lba_pw_v1":
theta["v_RL"] = np.column_stack(
[theta["v_RL_0"], theta["v_RL_1"], theta["v_RL_2"]]
)
theta["v_WM"] = np.column_stack(
[theta["v_WM_0"], theta["v_WM_1"], theta["v_WM_2"]]
)
theta["a"] = np.expand_dims(theta["a"], axis=1)
theta["z"] = np.expand_dims(theta["z"], axis=1)
theta["t_WM"] = np.expand_dims(theta["t_WM"], axis=1)
theta["t"] = np.zeros(n_trials).astype(np.float32)

# 2 Choice
if model == "race_2":
Expand Down
Loading

0 comments on commit 826234d

Please sign in to comment.