Skip to content

Commit

Permalink
run black and ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Feb 2, 2024
1 parent be9405a commit 3086b0a
Show file tree
Hide file tree
Showing 11 changed files with 442 additions and 2,708 deletions.
2,696 changes: 211 additions & 2,485 deletions notebooks/basic_tutorial.ipynb

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions notebooks/basic_tutorial_tmp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,7 @@
}
],
"source": [
"my_dataset_generator = ssms.dataset_generators.data_generator(\n",
" generator_config=my_data_config, model_config=my_model_config\n",
")"
"my_dataset_generator = ssms.dataset_generators.data_generator(generator_config=my_data_config, model_config=my_model_config)"
]
},
{
Expand Down Expand Up @@ -479,9 +477,7 @@
}
],
"source": [
"my_dataset_generator = ssms.dataset_generators.data_generator(\n",
" generator_config=my_data_config, model_config=my_model_config\n",
")"
"my_dataset_generator = ssms.dataset_generators.data_generator(generator_config=my_data_config, model_config=my_model_config)"
]
},
{
Expand Down
116 changes: 62 additions & 54 deletions notebooks/test_deadline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@
"metadata": {},
"outputs": [],
"source": [
"'ddm_boundaryfun_driftfun_deadline'"
"\"ddm_boundaryfun_driftfun_deadline\""
]
},
{
Expand Down Expand Up @@ -573,13 +573,19 @@
"outputs": [],
"source": [
"out_dict = {}\n",
"out_dict['choice_p'] = {}\n",
"out_dict['choice_p_no_omission'] = {}\n",
"out_dict['p_omission'] = {}\n",
"for choice in simulations['metadata']['possible_choices']:\n",
" out_dict['choice_p'][choice] = np.array([(simulations[\"choices\"] == choice).sum() / simulations[\"choices\"].flatten().shape[0]])\n",
" out_dict['choice_p_no_omission'][choice] = np.array([(simulations[\"choices\"][simulations['rts'] != -999] == choice).sum() / simulations[\"choices\"].flatten().shape[0]])\n",
" out_dict['p_omission'][choice] = np.array([(simulations[\"rts\"] == -999).sum() / simulations[\"choices\"].flatten().shape[0]])"
"out_dict[\"choice_p\"] = {}\n",
"out_dict[\"choice_p_no_omission\"] = {}\n",
"out_dict[\"p_omission\"] = {}\n",
"for choice in simulations[\"metadata\"][\"possible_choices\"]:\n",
" out_dict[\"choice_p\"][choice] = np.array(\n",
" [(simulations[\"choices\"] == choice).sum() / simulations[\"choices\"].flatten().shape[0]]\n",
" )\n",
" out_dict[\"choice_p_no_omission\"][choice] = np.array(\n",
" [(simulations[\"choices\"][simulations[\"rts\"] != -999] == choice).sum() / simulations[\"choices\"].flatten().shape[0]]\n",
" )\n",
" out_dict[\"p_omission\"][choice] = np.array(\n",
" [(simulations[\"rts\"] == -999).sum() / simulations[\"choices\"].flatten().shape[0]]\n",
" )"
]
},
{
Expand All @@ -600,7 +606,7 @@
}
],
"source": [
"simulations['rts'] != -999"
"simulations[\"rts\"] != -999"
]
},
{
Expand Down Expand Up @@ -664,7 +670,7 @@
}
],
"source": [
"out['metadata']"
"out[\"metadata\"]"
]
},
{
Expand All @@ -683,17 +689,18 @@
],
"source": [
"from copy import deepcopy\n",
"\n",
"v = 1.0\n",
"a = 2.0\n",
"z = 0.5\n",
"t = 0.0\n",
"theta = 0.7\n",
"deadline = 10\n",
"out = simulator(model=\"angle_deadline\", theta=[v, a, z, t, theta, deadline], n_samples=10000, max_t=20)\n",
"out_log = deepcopy(out) \n",
"out_log = deepcopy(out)\n",
"out_log[\"log_rts\"] = np.ones(out[\"rts\"].shape) * -999\n",
"out_log[\"log_rts\"][out_log['rts'] != -999] = np.log(out_log[\"rts\"][out_log['rts'] != -999])\n",
"del out_log['rts']"
"out_log[\"log_rts\"][out_log[\"rts\"] != -999] = np.log(out_log[\"rts\"][out_log[\"rts\"] != -999])\n",
"del out_log[\"rts\"]"
]
},
{
Expand Down Expand Up @@ -751,9 +758,11 @@
"sample_kde = my_kde.kde_sample(10000)\n",
"sample_kde_shifted = my_kde_shifted.kde_sample(10000)\n",
"sample_kde_shifted_log = my_kde_shifted_log.kde_sample(10000)\n",
"plt.hist(sample_kde[0] * sample_kde[1], bins = 50, density=True, histtype='step', color='blue')\n",
"plt.hist(sample_kde_shifted['rts'] * sample_kde_shifted['choices'], bins = 50, density=True, histtype='step', color='red')\n",
"plt.hist(sample_kde_shifted_log['rts'] * sample_kde_shifted_log['choices'], bins = 50, density=True, histtype='step', color='green')"
"plt.hist(sample_kde[0] * sample_kde[1], bins=50, density=True, histtype=\"step\", color=\"blue\")\n",
"plt.hist(sample_kde_shifted[\"rts\"] * sample_kde_shifted[\"choices\"], bins=50, density=True, histtype=\"step\", color=\"red\")\n",
"plt.hist(\n",
" sample_kde_shifted_log[\"rts\"] * sample_kde_shifted_log[\"choices\"], bins=50, density=True, histtype=\"step\", color=\"green\"\n",
")"
]
},
{
Expand Down Expand Up @@ -855,8 +864,8 @@
}
],
"source": [
"my_kde.kde_eval((data_1['rts'], data_1['choices']))\n",
"my_kde.kde_sample(n_samples = 10000)"
"my_kde.kde_eval((data_1[\"rts\"], data_1[\"choices\"]))\n",
"my_kde.kde_sample(n_samples=10000)"
]
},
{
Expand All @@ -876,7 +885,7 @@
}
],
"source": [
"my_kde_shifted.kde_sample(n_samples=10000)['rts'].shape"
"my_kde_shifted.kde_sample(n_samples=10000)[\"rts\"].shape"
]
},
{
Expand All @@ -895,15 +904,11 @@
"metadata": {},
"outputs": [],
"source": [
"data_1 = {'rts': np.linspace(0.01, 10, 1000),\n",
" 'choices': np.ones(1000)}\n",
"data_m1 = {'rts': np.linspace(0.01, 10, 1000), \n",
" 'choices': (-1) * np.ones(1000)}\n",
"data_1 = {\"rts\": np.linspace(0.01, 10, 1000), \"choices\": np.ones(1000)}\n",
"data_m1 = {\"rts\": np.linspace(0.01, 10, 1000), \"choices\": (-1) * np.ones(1000)}\n",
"\n",
"data_l1 = {'log_rts': np.log(np.linspace(0.01, 10, 1000)),\n",
" 'choices': np.ones(1000)}\n",
"data_lm1 = {'log_rts': np.log(np.linspace(0.01, 10, 1000)),\n",
" 'choices': (-1) * np.ones(1000)}\n",
"data_l1 = {\"log_rts\": np.log(np.linspace(0.01, 10, 1000)), \"choices\": np.ones(1000)}\n",
"data_lm1 = {\"log_rts\": np.log(np.linspace(0.01, 10, 1000)), \"choices\": (-1) * np.ones(1000)}\n",
"\n",
"# data_m1 = (np.linspace(0.01, 10, 1000), np.ones(1000) * (-1))\n",
"\n",
Expand Down Expand Up @@ -935,7 +940,7 @@
"evals_m1 = my_kde_shifted.kde_eval(data_m1)\n",
"\n",
"evals_l1 = my_kde_shifted.kde_eval(data_1)\n",
"print('this is the problem')\n",
"print(\"this is the problem\")\n",
"evals_lm1 = my_kde_shifted.kde_eval(data_lm1)\n",
"\n",
"# evals_1_shifted = my_kde_shifted.kde_eval(data_1_shifted)\n",
Expand Down Expand Up @@ -986,8 +991,9 @@
"# my_kde_shifted.kde_eval(data_1)\n",
"# my_kde_log_shifted.kde_eval(data_1)\n",
"from matplotlib import pyplot as plt\n",
"plt.plot(data_1['rts'], np.exp(my_kde_shifted.kde_eval(data_1)), color=\"blue\", label='')\n",
"plt.plot(data_m1['rts'] * (-1), np.exp(my_kde_shifted.kde_eval(data_m1)), color=\"blue\", label='')"
"\n",
"plt.plot(data_1[\"rts\"], np.exp(my_kde_shifted.kde_eval(data_1)), color=\"blue\", label=\"\")\n",
"plt.plot(data_m1[\"rts\"] * (-1), np.exp(my_kde_shifted.kde_eval(data_m1)), color=\"blue\", label=\"\")"
]
},
{
Expand Down Expand Up @@ -1035,8 +1041,18 @@
}
],
"source": [
"plt.plot(np.exp(data_l1['log_rts']), np.exp(np.squeeze(my_kde_shifted.kde_eval(data_l1)) - np.log(np.exp(data_lm1['log_rts']) - t)), color=\"blue\", label='')\n",
"plt.plot(np.exp(data_lm1['log_rts']) * (-1), np.exp(np.squeeze(my_kde_shifted.kde_eval(data_lm1)) - np.log(np.exp(data_lm1['log_rts']) - t)), color=\"blue\", label='')"
"plt.plot(\n",
" np.exp(data_l1[\"log_rts\"]),\n",
" np.exp(np.squeeze(my_kde_shifted.kde_eval(data_l1)) - np.log(np.exp(data_lm1[\"log_rts\"]) - t)),\n",
" color=\"blue\",\n",
" label=\"\",\n",
")\n",
"plt.plot(\n",
" np.exp(data_lm1[\"log_rts\"]) * (-1),\n",
" np.exp(np.squeeze(my_kde_shifted.kde_eval(data_lm1)) - np.log(np.exp(data_lm1[\"log_rts\"]) - t)),\n",
" color=\"blue\",\n",
" label=\"\",\n",
")"
]
},
{
Expand Down Expand Up @@ -1329,7 +1345,7 @@
}
],
"source": [
"np.log(np.exp(data_lm1['log_rts']) - t)"
"np.log(np.exp(data_lm1[\"log_rts\"]) - t)"
]
},
{
Expand Down Expand Up @@ -1460,7 +1476,7 @@
}
],
"source": [
"np.exp(data_lm1['log_rts'])"
"np.exp(data_lm1[\"log_rts\"])"
]
},
{
Expand All @@ -1469,7 +1485,7 @@
"metadata": {},
"outputs": [],
"source": [
"data_l1 = np.log(data_1)\n"
"data_l1 = np.log(data_1)"
]
},
{
Expand Down Expand Up @@ -1576,8 +1592,8 @@
}
],
"source": [
"out_kde_shifted = my_kde_shifted.kde_sample(n_samples = 10000)\n",
"out_kde = my_kde.kde_sample(n_samples = 10000)"
"out_kde_shifted = my_kde_shifted.kde_sample(n_samples=10000)\n",
"out_kde = my_kde.kde_sample(n_samples=10000)"
]
},
{
Expand Down Expand Up @@ -1660,8 +1676,8 @@
"# plt.plot(data_m1[0] * (-1), np.exp(evals_m1), color=\"blue\")\n",
"# plt.plot(data_1_shifted[0], np.exp(evals_1_shifted), color=\"red\")\n",
"# plt.plot(data_m1_shifted[0] * (-1), np.exp(evals_m1_shifted), color=\"red\")\n",
"plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n",
"plt.hist(out_kde[0] * out_kde[1], bins = 100, histtype=\"step\", density=True)\n",
"plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins=100, histtype=\"step\", density=True)\n",
"plt.hist(out_kde[0] * out_kde[1], bins=100, histtype=\"step\", density=True)\n",
"\n",
"\n",
"plt.hist(out[\"rts\"][out[\"rts\"] != -999] * out[\"choices\"][out[\"rts\"] != -999], bins=40, histtype=\"step\", density=True)"
Expand Down Expand Up @@ -1735,8 +1751,8 @@
}
],
"source": [
"#plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n",
"plt.hist(np.log(out_kde[0][out_kde[1] == 1]), bins = 100, histtype=\"step\", density=True)\n"
"# plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n",
"plt.hist(np.log(out_kde[0][out_kde[1] == 1]), bins=100, histtype=\"step\", density=True)"
]
},
{
Expand Down Expand Up @@ -1771,7 +1787,7 @@
}
],
"source": [
"plt.hist(np.exp(np.log(np.random.uniform(size = 100000))))"
"plt.hist(np.exp(np.log(np.random.uniform(size=100000))))"
]
},
{
Expand Down Expand Up @@ -1842,9 +1858,7 @@
"from time import time\n",
"\n",
"start = time()\n",
"out_traj = simulator(\n",
" model=\"ddm_mic2_multinoise_no_bias\", theta=[1.0, 1.0, 1.0, 1.5, 0.5, 1.0], n_samples=100000, max_t=20\n",
")\n",
"out_traj = simulator(model=\"ddm_mic2_multinoise_no_bias\", theta=[1.0, 1.0, 1.0, 1.5, 0.5, 1.0], n_samples=100000, max_t=20)\n",
"\n",
"end = time()\n",
"\n",
Expand Down Expand Up @@ -1904,15 +1918,9 @@
}
],
"source": [
"plt.hist(\n",
" out_traj[\"rts\"][(out_traj[\"choices\"] == 0) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low low\"\n",
")\n",
"plt.hist(\n",
" out_traj[\"rts\"][(out_traj[\"choices\"] == 1) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low high\"\n",
")\n",
"plt.hist(\n",
" out_traj[\"rts\"][(out_traj[\"choices\"] == 2) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high low\"\n",
")\n",
"plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 0) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low low\")\n",
"plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 1) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low high\")\n",
"plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 2) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high low\")\n",
"plt.hist(\n",
" out_traj[\"rts\"][(out_traj[\"choices\"] == 3) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high high\"\n",
")\n",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires = ["setuptools", "wheel", "Cython>=0.29.23", "numpy >= 1.20"]

[project]
name= "ssm-simulators"
version= "0.6.1"
version= "0.7.0"
authors= [{name = "Alexander Fenger", email = "[email protected]"}]
description= "SSMS is a package collecting simulators and training data generators for a bunch of generative models of interest in the cognitive science / neuroscience and approximate bayesian computation communities"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion ssms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from . import config
from . import support_utils

__version__ = "0.6.1" # importlib.metadata.version(__package__ or __name__)
__version__ = "0.7.0" # importlib.metadata.version(__package__ or __name__)

__all__ = ["basic_simulators", "dataset_generators", "config", "support_utils"]
3 changes: 1 addition & 2 deletions ssms/basic_simulators/boundary_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,5 @@ def conflict_gamma(
"""

return (
scale * gamma.pdf(t, a=alpha_gamma, loc=0, scale=scale_gamma)
+ np.multiply(t, (-np.sin(theta) / np.cos(theta))),
scale * gamma.pdf(t, a=alpha_gamma, loc=0, scale=scale_gamma) + np.multiply(t, (-np.sin(theta) / np.cos(theta))),
)
3 changes: 3 additions & 0 deletions ssms/basic_simulators/drift_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
This module defines a collection of drift functions for the simulators in the package.
"""


def constant(t=np.arange(0, 20, 0.1)):
"""constant drift function
Expand Down Expand Up @@ -49,6 +50,7 @@ def gamma_drift(t=np.arange(0, 20, 0.1), shape=2, scale=0.01, c=1.5):
div_ = np.power(shape - 1, shape - 1) * np.power(scale, shape - 1) * np.exp(-(shape - 1))
return c * np.divide(num_, div_)


def ds_support_analytic(t=np.arange(0, 10, 0.001), init_p=0, fix_point=1, slope=2):
"""Solution to differential equation of the form:
x' = slope*(fix_point - x),
Expand All @@ -75,6 +77,7 @@ def ds_support_analytic(t=np.arange(0, 10, 0.001), init_p=0, fix_point=1, slope=

return (init_p - fix_point) * np.exp(-(slope * t)) + fix_point


def ds_conflict_drift(
t=np.arange(0, 10, 0.001),
init_p_t=0,
Expand Down
Loading

0 comments on commit 3086b0a

Please sign in to comment.