|
506 | 506 | "metadata": {},
|
507 | 507 | "outputs": [],
|
508 | 508 | "source": [
|
509 |
| - "'ddm_boundaryfun_driftfun_deadline'" |
| 509 | + "\"ddm_boundaryfun_driftfun_deadline\"" |
510 | 510 | ]
|
511 | 511 | },
|
512 | 512 | {
|
|
573 | 573 | "outputs": [],
|
574 | 574 | "source": [
|
575 | 575 | "out_dict = {}\n",
|
576 |
| - "out_dict['choice_p'] = {}\n", |
577 |
| - "out_dict['choice_p_no_omission'] = {}\n", |
578 |
| - "out_dict['p_omission'] = {}\n", |
579 |
| - "for choice in simulations['metadata']['possible_choices']:\n", |
580 |
| - " out_dict['choice_p'][choice] = np.array([(simulations[\"choices\"] == choice).sum() / simulations[\"choices\"].flatten().shape[0]])\n", |
581 |
| - " out_dict['choice_p_no_omission'][choice] = np.array([(simulations[\"choices\"][simulations['rts'] != -999] == choice).sum() / simulations[\"choices\"].flatten().shape[0]])\n", |
582 |
| - " out_dict['p_omission'][choice] = np.array([(simulations[\"rts\"] == -999).sum() / simulations[\"choices\"].flatten().shape[0]])" |
| 576 | + "out_dict[\"choice_p\"] = {}\n", |
| 577 | + "out_dict[\"choice_p_no_omission\"] = {}\n", |
| 578 | + "out_dict[\"p_omission\"] = {}\n", |
| 579 | + "for choice in simulations[\"metadata\"][\"possible_choices\"]:\n", |
| 580 | + " out_dict[\"choice_p\"][choice] = np.array(\n", |
| 581 | + " [(simulations[\"choices\"] == choice).sum() / simulations[\"choices\"].flatten().shape[0]]\n", |
| 582 | + " )\n", |
| 583 | + " out_dict[\"choice_p_no_omission\"][choice] = np.array(\n", |
| 584 | + " [(simulations[\"choices\"][simulations[\"rts\"] != -999] == choice).sum() / simulations[\"choices\"].flatten().shape[0]]\n", |
| 585 | + " )\n", |
| 586 | + " out_dict[\"p_omission\"][choice] = np.array(\n", |
| 587 | + " [(simulations[\"rts\"] == -999).sum() / simulations[\"choices\"].flatten().shape[0]]\n", |
| 588 | + " )" |
583 | 589 | ]
|
584 | 590 | },
|
585 | 591 | {
|
|
600 | 606 | }
|
601 | 607 | ],
|
602 | 608 | "source": [
|
603 |
| - "simulations['rts'] != -999" |
| 609 | + "simulations[\"rts\"] != -999" |
604 | 610 | ]
|
605 | 611 | },
|
606 | 612 | {
|
|
664 | 670 | }
|
665 | 671 | ],
|
666 | 672 | "source": [
|
667 |
| - "out['metadata']" |
| 673 | + "out[\"metadata\"]" |
668 | 674 | ]
|
669 | 675 | },
|
670 | 676 | {
|
|
683 | 689 | ],
|
684 | 690 | "source": [
|
685 | 691 | "from copy import deepcopy\n",
|
| 692 | + "\n", |
686 | 693 | "v = 1.0\n",
|
687 | 694 | "a = 2.0\n",
|
688 | 695 | "z = 0.5\n",
|
689 | 696 | "t = 0.0\n",
|
690 | 697 | "theta = 0.7\n",
|
691 | 698 | "deadline = 10\n",
|
692 | 699 | "out = simulator(model=\"angle_deadline\", theta=[v, a, z, t, theta, deadline], n_samples=10000, max_t=20)\n",
|
693 |
| - "out_log = deepcopy(out) \n", |
| 700 | + "out_log = deepcopy(out)\n", |
694 | 701 | "out_log[\"log_rts\"] = np.ones(out[\"rts\"].shape) * -999\n",
|
695 |
| - "out_log[\"log_rts\"][out_log['rts'] != -999] = np.log(out_log[\"rts\"][out_log['rts'] != -999])\n", |
696 |
| - "del out_log['rts']" |
| 702 | + "out_log[\"log_rts\"][out_log[\"rts\"] != -999] = np.log(out_log[\"rts\"][out_log[\"rts\"] != -999])\n", |
| 703 | + "del out_log[\"rts\"]" |
697 | 704 | ]
|
698 | 705 | },
|
699 | 706 | {
|
|
751 | 758 | "sample_kde = my_kde.kde_sample(10000)\n",
|
752 | 759 | "sample_kde_shifted = my_kde_shifted.kde_sample(10000)\n",
|
753 | 760 | "sample_kde_shifted_log = my_kde_shifted_log.kde_sample(10000)\n",
|
754 |
| - "plt.hist(sample_kde[0] * sample_kde[1], bins = 50, density=True, histtype='step', color='blue')\n", |
755 |
| - "plt.hist(sample_kde_shifted['rts'] * sample_kde_shifted['choices'], bins = 50, density=True, histtype='step', color='red')\n", |
756 |
| - "plt.hist(sample_kde_shifted_log['rts'] * sample_kde_shifted_log['choices'], bins = 50, density=True, histtype='step', color='green')" |
| 761 | + "plt.hist(sample_kde[0] * sample_kde[1], bins=50, density=True, histtype=\"step\", color=\"blue\")\n", |
| 762 | + "plt.hist(sample_kde_shifted[\"rts\"] * sample_kde_shifted[\"choices\"], bins=50, density=True, histtype=\"step\", color=\"red\")\n", |
| 763 | + "plt.hist(\n", |
| 764 | + " sample_kde_shifted_log[\"rts\"] * sample_kde_shifted_log[\"choices\"], bins=50, density=True, histtype=\"step\", color=\"green\"\n", |
| 765 | + ")" |
757 | 766 | ]
|
758 | 767 | },
|
759 | 768 | {
|
|
855 | 864 | }
|
856 | 865 | ],
|
857 | 866 | "source": [
|
858 |
| - "my_kde.kde_eval((data_1['rts'], data_1['choices']))\n", |
859 |
| - "my_kde.kde_sample(n_samples = 10000)" |
| 867 | + "my_kde.kde_eval((data_1[\"rts\"], data_1[\"choices\"]))\n", |
| 868 | + "my_kde.kde_sample(n_samples=10000)" |
860 | 869 | ]
|
861 | 870 | },
|
862 | 871 | {
|
|
876 | 885 | }
|
877 | 886 | ],
|
878 | 887 | "source": [
|
879 |
| - "my_kde_shifted.kde_sample(n_samples=10000)['rts'].shape" |
| 888 | + "my_kde_shifted.kde_sample(n_samples=10000)[\"rts\"].shape" |
880 | 889 | ]
|
881 | 890 | },
|
882 | 891 | {
|
|
895 | 904 | "metadata": {},
|
896 | 905 | "outputs": [],
|
897 | 906 | "source": [
|
898 |
| - "data_1 = {'rts': np.linspace(0.01, 10, 1000),\n", |
899 |
| - " 'choices': np.ones(1000)}\n", |
900 |
| - "data_m1 = {'rts': np.linspace(0.01, 10, 1000), \n", |
901 |
| - " 'choices': (-1) * np.ones(1000)}\n", |
| 907 | + "data_1 = {\"rts\": np.linspace(0.01, 10, 1000), \"choices\": np.ones(1000)}\n", |
| 908 | + "data_m1 = {\"rts\": np.linspace(0.01, 10, 1000), \"choices\": (-1) * np.ones(1000)}\n", |
902 | 909 | "\n",
|
903 |
| - "data_l1 = {'log_rts': np.log(np.linspace(0.01, 10, 1000)),\n", |
904 |
| - " 'choices': np.ones(1000)}\n", |
905 |
| - "data_lm1 = {'log_rts': np.log(np.linspace(0.01, 10, 1000)),\n", |
906 |
| - " 'choices': (-1) * np.ones(1000)}\n", |
| 910 | + "data_l1 = {\"log_rts\": np.log(np.linspace(0.01, 10, 1000)), \"choices\": np.ones(1000)}\n", |
| 911 | + "data_lm1 = {\"log_rts\": np.log(np.linspace(0.01, 10, 1000)), \"choices\": (-1) * np.ones(1000)}\n", |
907 | 912 | "\n",
|
908 | 913 | "# data_m1 = (np.linspace(0.01, 10, 1000), np.ones(1000) * (-1))\n",
|
909 | 914 | "\n",
|
|
935 | 940 | "evals_m1 = my_kde_shifted.kde_eval(data_m1)\n",
|
936 | 941 | "\n",
|
937 | 942 | "evals_l1 = my_kde_shifted.kde_eval(data_1)\n",
|
938 |
| - "print('this is the problem')\n", |
| 943 | + "print(\"this is the problem\")\n", |
939 | 944 | "evals_lm1 = my_kde_shifted.kde_eval(data_lm1)\n",
|
940 | 945 | "\n",
|
941 | 946 | "# evals_1_shifted = my_kde_shifted.kde_eval(data_1_shifted)\n",
|
|
986 | 991 | "# my_kde_shifted.kde_eval(data_1)\n",
|
987 | 992 | "# my_kde_log_shifted.kde_eval(data_1)\n",
|
988 | 993 | "from matplotlib import pyplot as plt\n",
|
989 |
| - "plt.plot(data_1['rts'], np.exp(my_kde_shifted.kde_eval(data_1)), color=\"blue\", label='')\n", |
990 |
| - "plt.plot(data_m1['rts'] * (-1), np.exp(my_kde_shifted.kde_eval(data_m1)), color=\"blue\", label='')" |
| 994 | + "\n", |
| 995 | + "plt.plot(data_1[\"rts\"], np.exp(my_kde_shifted.kde_eval(data_1)), color=\"blue\", label=\"\")\n", |
| 996 | + "plt.plot(data_m1[\"rts\"] * (-1), np.exp(my_kde_shifted.kde_eval(data_m1)), color=\"blue\", label=\"\")" |
991 | 997 | ]
|
992 | 998 | },
|
993 | 999 | {
|
|
1035 | 1041 | }
|
1036 | 1042 | ],
|
1037 | 1043 | "source": [
|
1038 |
| - "plt.plot(np.exp(data_l1['log_rts']), np.exp(np.squeeze(my_kde_shifted.kde_eval(data_l1)) - np.log(np.exp(data_lm1['log_rts']) - t)), color=\"blue\", label='')\n", |
1039 |
| - "plt.plot(np.exp(data_lm1['log_rts']) * (-1), np.exp(np.squeeze(my_kde_shifted.kde_eval(data_lm1)) - np.log(np.exp(data_lm1['log_rts']) - t)), color=\"blue\", label='')" |
| 1044 | + "plt.plot(\n", |
| 1045 | + " np.exp(data_l1[\"log_rts\"]),\n", |
| 1046 | + " np.exp(np.squeeze(my_kde_shifted.kde_eval(data_l1)) - np.log(np.exp(data_lm1[\"log_rts\"]) - t)),\n", |
| 1047 | + " color=\"blue\",\n", |
| 1048 | + " label=\"\",\n", |
| 1049 | + ")\n", |
| 1050 | + "plt.plot(\n", |
| 1051 | + " np.exp(data_lm1[\"log_rts\"]) * (-1),\n", |
| 1052 | + " np.exp(np.squeeze(my_kde_shifted.kde_eval(data_lm1)) - np.log(np.exp(data_lm1[\"log_rts\"]) - t)),\n", |
| 1053 | + " color=\"blue\",\n", |
| 1054 | + " label=\"\",\n", |
| 1055 | + ")" |
1040 | 1056 | ]
|
1041 | 1057 | },
|
1042 | 1058 | {
|
|
1329 | 1345 | }
|
1330 | 1346 | ],
|
1331 | 1347 | "source": [
|
1332 |
| - "np.log(np.exp(data_lm1['log_rts']) - t)" |
| 1348 | + "np.log(np.exp(data_lm1[\"log_rts\"]) - t)" |
1333 | 1349 | ]
|
1334 | 1350 | },
|
1335 | 1351 | {
|
|
1460 | 1476 | }
|
1461 | 1477 | ],
|
1462 | 1478 | "source": [
|
1463 |
| - "np.exp(data_lm1['log_rts'])" |
| 1479 | + "np.exp(data_lm1[\"log_rts\"])" |
1464 | 1480 | ]
|
1465 | 1481 | },
|
1466 | 1482 | {
|
|
1469 | 1485 | "metadata": {},
|
1470 | 1486 | "outputs": [],
|
1471 | 1487 | "source": [
|
1472 |
| - "data_l1 = np.log(data_1)\n" |
| 1488 | + "data_l1 = np.log(data_1)" |
1473 | 1489 | ]
|
1474 | 1490 | },
|
1475 | 1491 | {
|
|
1576 | 1592 | }
|
1577 | 1593 | ],
|
1578 | 1594 | "source": [
|
1579 |
| - "out_kde_shifted = my_kde_shifted.kde_sample(n_samples = 10000)\n", |
1580 |
| - "out_kde = my_kde.kde_sample(n_samples = 10000)" |
| 1595 | + "out_kde_shifted = my_kde_shifted.kde_sample(n_samples=10000)\n", |
| 1596 | + "out_kde = my_kde.kde_sample(n_samples=10000)" |
1581 | 1597 | ]
|
1582 | 1598 | },
|
1583 | 1599 | {
|
|
1660 | 1676 | "# plt.plot(data_m1[0] * (-1), np.exp(evals_m1), color=\"blue\")\n",
|
1661 | 1677 | "# plt.plot(data_1_shifted[0], np.exp(evals_1_shifted), color=\"red\")\n",
|
1662 | 1678 | "# plt.plot(data_m1_shifted[0] * (-1), np.exp(evals_m1_shifted), color=\"red\")\n",
|
1663 |
| - "plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n", |
1664 |
| - "plt.hist(out_kde[0] * out_kde[1], bins = 100, histtype=\"step\", density=True)\n", |
| 1679 | + "plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins=100, histtype=\"step\", density=True)\n", |
| 1680 | + "plt.hist(out_kde[0] * out_kde[1], bins=100, histtype=\"step\", density=True)\n", |
1665 | 1681 | "\n",
|
1666 | 1682 | "\n",
|
1667 | 1683 | "plt.hist(out[\"rts\"][out[\"rts\"] != -999] * out[\"choices\"][out[\"rts\"] != -999], bins=40, histtype=\"step\", density=True)"
|
|
1735 | 1751 | }
|
1736 | 1752 | ],
|
1737 | 1753 | "source": [
|
1738 |
| - "#plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n", |
1739 |
| - "plt.hist(np.log(out_kde[0][out_kde[1] == 1]), bins = 100, histtype=\"step\", density=True)\n" |
| 1754 | + "# plt.hist(out_kde_shifted[0] * out_kde_shifted[1], bins = 100, histtype=\"step\", density=True)\n", |
| 1755 | + "plt.hist(np.log(out_kde[0][out_kde[1] == 1]), bins=100, histtype=\"step\", density=True)" |
1740 | 1756 | ]
|
1741 | 1757 | },
|
1742 | 1758 | {
|
|
1771 | 1787 | }
|
1772 | 1788 | ],
|
1773 | 1789 | "source": [
|
1774 |
| - "plt.hist(np.exp(np.log(np.random.uniform(size = 100000))))" |
| 1790 | + "plt.hist(np.exp(np.log(np.random.uniform(size=100000))))" |
1775 | 1791 | ]
|
1776 | 1792 | },
|
1777 | 1793 | {
|
|
1842 | 1858 | "from time import time\n",
|
1843 | 1859 | "\n",
|
1844 | 1860 | "start = time()\n",
|
1845 |
| - "out_traj = simulator(\n", |
1846 |
| - " model=\"ddm_mic2_multinoise_no_bias\", theta=[1.0, 1.0, 1.0, 1.5, 0.5, 1.0], n_samples=100000, max_t=20\n", |
1847 |
| - ")\n", |
| 1861 | + "out_traj = simulator(model=\"ddm_mic2_multinoise_no_bias\", theta=[1.0, 1.0, 1.0, 1.5, 0.5, 1.0], n_samples=100000, max_t=20)\n", |
1848 | 1862 | "\n",
|
1849 | 1863 | "end = time()\n",
|
1850 | 1864 | "\n",
|
|
1904 | 1918 | }
|
1905 | 1919 | ],
|
1906 | 1920 | "source": [
|
1907 |
| - "plt.hist(\n", |
1908 |
| - " out_traj[\"rts\"][(out_traj[\"choices\"] == 0) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low low\"\n", |
1909 |
| - ")\n", |
1910 |
| - "plt.hist(\n", |
1911 |
| - " out_traj[\"rts\"][(out_traj[\"choices\"] == 1) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low high\"\n", |
1912 |
| - ")\n", |
1913 |
| - "plt.hist(\n", |
1914 |
| - " out_traj[\"rts\"][(out_traj[\"choices\"] == 2) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high low\"\n", |
1915 |
| - ")\n", |
| 1921 | + "plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 0) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low low\")\n", |
| 1922 | + "plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 1) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"low high\")\n", |
| 1923 | + "plt.hist(out_traj[\"rts\"][(out_traj[\"choices\"] == 2) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high low\")\n", |
1916 | 1924 | "plt.hist(\n",
|
1917 | 1925 | " out_traj[\"rts\"][(out_traj[\"choices\"] == 3) & (out_traj[\"rts\"] != -999)], histtype=\"step\", bins=40, label=\"high high\"\n",
|
1918 | 1926 | ")\n",
|
|
0 commit comments