Skip to content

Commit 3086b0a

Browse files
run black and ruff
1 parent be9405a commit 3086b0a

File tree

11 files changed

+442
-2708
lines changed

11 files changed

+442
-2708
lines changed

notebooks/basic_tutorial.ipynb

Lines changed: 211 additions & 2485 deletions
Large diffs are not rendered by default.

notebooks/basic_tutorial_tmp.ipynb

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,7 @@
344344
}
345345
],
346346
"source": [
347-
"my_dataset_generator = ssms.dataset_generators.data_generator(\n",
348-
" generator_config=my_data_config, model_config=my_model_config\n",
349-
")"
347+
"my_dataset_generator = ssms.dataset_generators.data_generator(generator_config=my_data_config, model_config=my_model_config)"
350348
]
351349
},
352350
{
@@ -479,9 +477,7 @@
479477
}
480478
],
481479
"source": [
482-
"my_dataset_generator = ssms.dataset_generators.data_generator(\n",
483-
" generator_config=my_data_config, model_config=my_model_config\n",
484-
")"
480+
"my_dataset_generator = ssms.dataset_generators.data_generator(generator_config=my_data_config, model_config=my_model_config)"
485481
]
486482
},
487483
{

notebooks/test_deadline.ipynb

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@
506506
"metadata": {},
507507
"outputs": [],
508508
"source": [
509-
"'ddm_boundaryfun_driftfun_deadline'"
509+
"\"ddm_boundaryfun_driftfun_deadline\""
510510
]
511511
},
512512
{
@@ -573,13 +573,19 @@
573573
"outputs": [],
574574
"source": [
575575
"out_dict = {}\n",
576-
"out_dict['choice_p'] = {}\n",
577-
"out_dict['choice_p_no_omission'] = {}\n",
578-
"out_dict['p_omission'] = {}\n",
579-
"for choice in simulations['metadata']['possible_choices']:\n",
580-
" out_dict['choice_p'][choice] = np.array([(simulations[\"choices\"] == choice).sum() / simulations[\"choices\"].flatten().shape[0]])\n",
581-
" out_dict['choice_p_no_omission'][choice] = np.array([(simulations[\"choices\"][simulations['rts'] != -999] == choice).sum() / simulations[\"choices\"].flatten().shape[0]])\n",
582-
" out_dict['p_omission'][choice] = np.array([(simulations[\"rts\"] == -999).sum() / simulations[\"choices\"].flatten().shape[0]])"
576+
"out_dict[\"choice_p\"] = {}\n",
577+
"out_dict[\"choice_p_no_omission\"] = {}\n",
578+
"out_dict[\"p_omission\"] = {}\n",
579+
"for choice in simulations[\"metadata\"][\"possible_choices\"]:\n",
580+
" out_dict[\"choice_p\"][choice] = np.array(\n",
581+
" [(simulations[\"choices\"] == choice).sum() / simulations[\"choices\"].flatten().shape[0]]\n",
582+
" )\n",
583+
" out_dict[\"choice_p_no_omission\"][choice] = np.array(\n",
584+
" [(simulations[\"choices\"][simulations[\"rts\"] != -999] == choice).sum() / simulations[\"choices\"].flatten().shape[0]]\n",
585+
" )\n",
586+
" out_dict[\"p_omission\"][choice] = np.array(\n",
587+
" [(simulations[\"rts\"] == -999).sum() / simulations[\"choices\"].flatten().shape[0]]\n",
588+
" )"
583589
]
584590
},
585591
{
@@ -600,7 +606,7 @@
600606
}
601607
],
602608
"source": [
603-
"simulations['rts'] != -999"
609+
"simulations[\"rts\"] != -999"
604610
]
605611
},
606612
{
@@ -664,7 +670,7 @@
664670
}
665671
],
666672
"source": [
667-
"out['metadata']"
673+
"out[\"metadata\"]"
668674
]
669675
},
670676
{
@@ -683,17 +689,18 @@
683689
],
684690
"source": [
685691
"from copy import deepcopy\n",
692+
"\n",
686693
"v = 1.0\n",
687694
"a = 2.0\n",
688695
"z = 0.5\n",
689696
"t = 0.0\n",
690697
"theta = 0.7\n",
691698
"deadline = 10\n",
692699
"out = simulator(model=\"angle_deadline\", theta=[v, a, z, t, theta, deadline], n_samples=10000, max_t=20)\n",
693-
"out_log = deepcopy(out) \n",
700+
"out_log = deepcopy(out)\n",
694701
"out_log[\"log_rts\"] = np.ones(out[\"rts\"].shape) * -999\n",
695-
"out_log[\"log_rts\"][out_log['rts'] != -999] = np.log(out_log[\"rts\"][out_log['rts'] != -999])\n",
696-
"del out_log['rts']"
702+
"out_log[\"log_rts\"][out_log[\"rts\"] != -999] = np.log(out_log[\"rts\"][out_log[\"rts\"] != -999])\n",
703+
"del out_log[\"rts\"]"
697704
]
698705
},
699706
{
@@ -751,9 +758,11 @@
751758
"sample_kde = my_kde.kde_sample(10000)\n",
752759
"sample_kde_shifted = my_kde_shifted.kde_sample(10000)\n",
753760
"sample_kde_shifted_log = my_kde_shifted_log.kde_sample(10000)\n",
754-
"plt.hist(sample_kde[0] * sample_kde[1], bins = 50, density=True, histtype='step', color='blue')\n",
755-
"plt.hist(sample_kde_shifted['rts'] * sample_kde_shifted['choices'], bins = 50, density=True, histtype='step', color='red')\n",
756-
"plt.hist(sample_kde_shifted_log['rts'] * sample_kde_shifted_log['choices'], bins = 50, density=True, histtype='step', color='green')"
761+
"plt.hist(sample_kde[0] * sample_kde[1], bins=50, density=True, histtype=\"step\", color=\"blue\")\n",
762+
"plt.hist(sample_kde_shifted[\"rts\"] * sample_kde_shifted[\"choices\"], bins=50, density=True, histtype=\"step\", color=\"red\")\n",
763+
"plt.hist(\n",
764+
" sample_kde_shifted_log[\"rts\"] * sample_kde_shifted_log[\"choices\"], bins=50, density=True, histtype=\"step\", color=\"green\"\n",
765+
")"
757766
]
758767
},
759768
{
@@ -855,8 +864,8 @@
855864
}
856865
],
857866
"source": [
858-
"my_kde.kde_eval((data_1['rts'], data_1['choices']))\n",
859-
"my_kde.kde_sample(n_samples = 10000)"
867+
"my_kde.kde_eval((data_1[\"rts\"], data_1[\"choices\"]))\n",
868+
"my_kde.kde_sample(n_samples=10000)"
860869
]
861870
},
862871
{
@@ -876,7 +885,7 @@
876885
}
877886
],
878887
"source": [
879-
"my_kde_shifted.kde_sample(n_samples=10000)['rts'].shape"
888+
"my_kde_shifted.kde_sample(n_samples=10000)[\"rts\"].shape"
880889
]
881890
},
882891
{
@@ -895,15 +904,11 @@
895904
"metadata": {},
896905
"outputs": [],
897906
"source": [
898-
"data_1 = {'rts': np.linspace(0.01, 10, 1000),\n",
899-
" 'choices': np.ones(1000)}\n",
900-
"data_m1 = {'rts': np.linspace(0.01, 10, 1000), \n",
901-
" 'choices': (-1) * np.ones(1000)}\n",
907+
"data_1 = {\"rts\": np.linspace(0.01, 10, 1000), \"choices\": np.ones(1000)}\n",
908+
"data_m1 = {\"rts\": np.linspace(0.01, 10, 1000), \"choices\": (-1) * np.ones(1000)}\n",
902909
"\n",
903-
"data_l1 = {'log_rts': np.log(np.linspace(0.01, 10, 1000)),\n",
904-
" 'choices': np.ones(1000)}\n",
905-
"data_lm1 = {'log_rts': np.log(np.linspace(0.01, 10, 1000)),\n",
906-
" 'choices': (-1) * np.ones(1000)}\n",
910+
"data_l1 = {\"log_rts\": np.log(np.linspace(0.01, 10, 1000)), \"choices\": np.ones(1000)}\n",
911+
"data_lm1 = {\"log_rts\": np.log(np.linspace(0.01, 10, 1000)), \"choices\": (-1) * np.ones(1000)}\n",
907912
"\n",
908913
"# data_m1 = (np.linspace(0.01, 10, 1000), np.ones(1000) * (-1))\n",
909914
"\n",
@@ -935,7 +940,7 @@
935940
"evals_m1 = my_kde_shifted.kde_eval(data_m1)\n",
936941
"\n",
937942
"evals_l1 = my_kde_shifted.kde_eval(data_1)\n",
938-
"print('this is the problem')\n",
943+
"print(\"this is the problem\")\n",
939944
"evals_lm1 = my_kde_shifted.kde_eval(data_lm1)\n",
940945
"\n",
941946
"# evals_1_shifted = my_kde_shifted.kde_eval(data_1_shifted)\n",
@@ -986,8 +991,9 @@
986991
"# my_kde_shifted.kde_eval(data_1)\n",
987992
"# my_kde_log_shifted.kde_eval(data_1)\n",
988993
"from matplotlib import pyplot as plt\n",
989-
"plt.plot(data_1['rts'], np.exp(my_kde_shifted.kde_eval(data_1)), color=\"blue\", label='')\n",
990-
"plt.plot(data_m1['rts'] * (-1), np.exp(my_kde_shifted.kde_eval(data_m1)), color=\"blue\", label='')"
994+
"\n",
995+
"plt.plot(data_1[\"rts\"], np.exp(my_kde_shifted.kde_eval(data_1)), color=\"blue\", label=\"\")\n",
996+
"plt.plot(data_m1[\"rts\"] * (-1), np.exp(my_kde_shifted.kde_eval(data_m1)), color=\"blue\", label=\"\")"
991997
]
992998
},
993999
{
@@ -1035,8 +1041,18 @@
10351041
}
10361042
],
10371043
"source": [
1038-
"plt.plot(np.exp(data_l1['log_rts']), np.exp(np.squeeze(my_kde_shifted.kde_eval(data_l1)) - np.log(np.exp(data_lm1['log_rts']) - t)), color=\"blue\", label='')\n",
1039-
"plt.plot(np.exp(data_lm1['log_rts']) * (-1), np.exp(np.squeeze(my_kde_shifted.kde_eval(data_lm1)) - np.log(np.exp(data_lm1['log_rts']) - t)), color=\"blue\", label='')"
1044+
"plt.plot(\n",
1045+
" np.exp(data_l1[\"log_rts\"]),\n",
1046+
" np.exp(np.squeeze(my_kde_shifted.kde_eval(data_l1)) - np.log(np.exp(data_lm1[\"log_rts\"]) - t)),\n",
1047+
" color=\"blue\",\n",
1048+
" label=\"\",\n",
1049+
")\n",
1050+
"plt.plot(\n",
1051+
" np.exp(data_lm1[\"log_rts\"]) * (-1),\n",
1052+
" np.exp(np.squeeze(my_kde_shifted.kde_eval(data_lm1)) - np.log(np.exp(data_lm1[\"log_rts\"]) - t)),\n",
1053+
" color=\"blue\",\n",
1054+
" label=\"\",\n",
1055+
")"
10401056
]
10411057
},
10421058
{
@@ -1329,7 +1345,7 @@
13291345
}
13301346
],
13311347
"source": [
1332-
"np.log(np.exp(data_lm1['log_rts']) - t)"
1348+
"np.log(np.exp(data_lm1[\"log_rts\"]) - t)"
13331349
]
13341350
},
13351351
{
@@ -1460,7 +1476,7 @@
14601476
}
14611477
],
14621478
"source": [
1463-
"np.exp(data_lm1['log_rts'])"
1479+
"np.exp(data_lm1[\"log_rts\"])"
14641480
]
14651481
},
14661482
{
@@ -1469,7 +1485,7 @@
14691485
"metadata": {},
14701486
"outputs": [],
14711487
"source": [
1472-
"data_l1 = np.log(data_1)\n"
1488+
"data_l1 = np.log(data_1)"
14731489
]
14741490
},
14751491
{
@@ -1576,8 +1592,8 @@
15761592
}
15771593
],
15781594
"source": [
1579-
"out_kde_shifted = my_kde_shifted.kde_sample(n_samples = 10000)\n",
1580-
"out_kde = my_kde.kde_sample(n_samples = 10000)"
1595+
"out_kde_shifted = my_kde_shifted.kde_sample(n_samples=10000)\n",
1596+
"out_kde = my_kde.kde_sample(n_samples=10000)"
15811597
]
15821598
},
15831599
{
@@ -1660,8 +1676,8 @@
16601676
"# plt.plot(data_m1[0] * (-1), np.exp(evals_m1), color=\"blue\")\n",
16611677
"# plt.plot(data_1_shifted[0], np.exp(evals_1_shifted), color=\"red\")\n",
16621678
"# plt.plot(data_m1_shifted[0] * (-1), np.exp(evals_m1_shifted), color=\"red\")\n",
1663-
"plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n",
1664-
"plt.hist(out_kde[0] * out_kde[1], bins = 100, histtype=\"step\", density=True)\n",
1679+
"plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins=100, histtype=\"step\", density=True)\n",
1680+
"plt.hist(out_kde[0] * out_kde[1], bins=100, histtype=\"step\", density=True)\n",
16651681
"\n",
16661682
"\n",
16671683
"plt.hist(out[\"rts\"][out[\"rts\"] != -999] * out[\"choices\"][out[\"rts\"] != -999], bins=40, histtype=\"step\", density=True)"
@@ -1735,8 +1751,8 @@
17351751
}
17361752
],
17371753
"source": [
1738-
"#plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n",
1739-
"plt.hist(np.log(out_kde[0][out_kde[1] == 1]), bins = 100, histtype=\"step\", density=True)\n"
1754+
"# plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n",
1755+
"plt.hist(np.log(out_kde[0][out_kde[1] == 1]), bins=100, histtype=\"step\", density=True)"
17401756
]
17411757
},
17421758
{
@@ -1771,7 +1787,7 @@
17711787
}
17721788
],
17731789
"source": [
1774-
"plt.hist(np.exp(np.log(np.random.uniform(size = 100000))))"
1790+
"plt.hist(np.exp(np.log(np.random.uniform(size=100000))))"
17751791
]
17761792
},
17771793
{
@@ -1842,9 +1858,7 @@
18421858
"from time import time\n",
18431859
"\n",
18441860
"start = time()\n",
1845-
"out_traj = simulator(\n",
1846-
" model=\"ddm_mic2_multinoise_no_bias\", theta=[1.0, 1.0, 1.0, 1.5, 0.5, 1.0], n_samples=100000, max_t=20\n",
1847-
")\n",
1861+
"out_traj = simulator(model=\"ddm_mic2_multinoise_no_bias\", theta=[1.0, 1.0, 1.0, 1.5, 0.5, 1.0], n_samples=100000, max_t=20)\n",
18481862
"\n",
18491863
"end = time()\n",
18501864
"\n",
@@ -1904,15 +1918,9 @@
19041918
}
19051919
],
19061920
"source": [
1907-
"plt.hist(\n",
1908-
" out_traj[\"rts\"][(out_traj[\"choices\"] == 0) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low low\"\n",
1909-
")\n",
1910-
"plt.hist(\n",
1911-
" out_traj[\"rts\"][(out_traj[\"choices\"] == 1) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low high\"\n",
1912-
")\n",
1913-
"plt.hist(\n",
1914-
" out_traj[\"rts\"][(out_traj[\"choices\"] == 2) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high low\"\n",
1915-
")\n",
1921+
"plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 0) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low low\")\n",
1922+
"plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 1) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low high\")\n",
1923+
"plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 2) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high low\")\n",
19161924
"plt.hist(\n",
19171925
" out_traj[\"rts\"][(out_traj[\"choices\"] == 3) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high high\"\n",
19181926
")\n",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ requires = ["setuptools", "wheel", "Cython>=0.29.23", "numpy >= 1.20"]
77

88
[project]
99
name= "ssm-simulators"
10-
version= "0.6.1"
10+
version= "0.7.0"
1111
authors= [{name = "Alexander Fenger", email = "[email protected]"}]
1212
description= "SSMS is a package collecting simulators and training data generators for a bunch of generative models of interest in the cognitive science / neuroscience and approximate bayesian computation communities"
1313
readme = "README.md"

ssms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
from . import config
55
from . import support_utils
66

7-
__version__ = "0.6.1" # importlib.metadata.version(__package__ or __name__)
7+
__version__ = "0.7.0" # importlib.metadata.version(__package__ or __name__)
88

99
__all__ = ["basic_simulators", "dataset_generators", "config", "support_utils"]

ssms/basic_simulators/boundary_functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,5 @@ def conflict_gamma(
105105
"""
106106

107107
return (
108-
scale * gamma.pdf(t, a=alpha_gamma, loc=0, scale=scale_gamma)
109-
+ np.multiply(t, (-np.sin(theta) / np.cos(theta))),
108+
scale * gamma.pdf(t, a=alpha_gamma, loc=0, scale=scale_gamma) + np.multiply(t, (-np.sin(theta) / np.cos(theta))),
110109
)

ssms/basic_simulators/drift_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
This module defines a collection of drift functions for the simulators in the package.
66
"""
77

8+
89
def constant(t=np.arange(0, 20, 0.1)):
910
"""constant drift function
1011
@@ -49,6 +50,7 @@ def gamma_drift(t=np.arange(0, 20, 0.1), shape=2, scale=0.01, c=1.5):
4950
div_ = np.power(shape - 1, shape - 1) * np.power(scale, shape - 1) * np.exp(-(shape - 1))
5051
return c * np.divide(num_, div_)
5152

53+
5254
def ds_support_analytic(t=np.arange(0, 10, 0.001), init_p=0, fix_point=1, slope=2):
5355
"""Solution to differential equation of the form:
5456
x' = slope*(fix_point - x),
@@ -75,6 +77,7 @@ def ds_support_analytic(t=np.arange(0, 10, 0.001), init_p=0, fix_point=1, slope=
7577

7678
return (init_p - fix_point) * np.exp(-(slope * t)) + fix_point
7779

80+
7881
def ds_conflict_drift(
7982
t=np.arange(0, 10, 0.001),
8083
init_p_t=0,

0 commit comments

Comments
 (0)