From 47a61ef4f8d3289ddd9b2ab78a478be2163aa4bc Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Wed, 18 Dec 2024 12:43:29 +0100 Subject: [PATCH] change default of displace t (#59) --- pyproject.toml | 2 +- ssms/__init__.py | 2 +- ssms/support_utils/kde_class.py | 25 +++++++++++++++++++++---- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cc50614..e72b98b 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = ["setuptools", "wheel", "Cython>=0.29.23", "numpy >= 1.20"] [project] name= "ssm-simulators" -version= "0.7.9" +version= "0.8.0" authors= [{name = "Alexander Fenger", email = "alexander_fengler@brown.edu"}] description= "SSMS is a package collecting simulators and training data generators for a bunch of generative models of interest in the cognitive science / neuroscience and approximate bayesian computation communities" readme = "README.md" diff --git a/ssms/__init__.py b/ssms/__init__.py index 34b554f..94ea215 100755 --- a/ssms/__init__.py +++ b/ssms/__init__.py @@ -4,6 +4,6 @@ from . import config from . import support_utils -__version__ = "0.7.9" # importlib.metadata.version(__package__ or __name__) +__version__ = "0.8.0" # importlib.metadata.version(__package__ or __name__) __all__ = ["basic_simulators", "dataset_generators", "config", "support_utils"] diff --git a/ssms/support_utils/kde_class.py b/ssms/support_utils/kde_class.py index 626c9a7..8a39d95 100755 --- a/ssms/support_utils/kde_class.py +++ b/ssms/support_utils/kde_class.py @@ -46,11 +46,28 @@ class LogKDE: # Initialize the class def __init__( self, - simulator_data, # as returned by simulator function - bandwidth_type="silverman", - auto_bandwidth=True, - displace_t=True, + simulator_data: dict, # as returned by simulator function + bandwidth_type: str = "silverman", + auto_bandwidth: bool = True, + displace_t: bool = False, ): + """Initialize LogKDE class. + + Arguments: + ---------- + simulator_data: Dictionary containing simulation data with keys 'rts', 'choices', and 'metadata'. + Follows the format returned by simulator functions in this package. + bandwidth_type: Type of bandwidth to use for KDE. Currently only 'silverman' is supported. + Defaults to 'silverman'. + auto_bandwidth: Whether to automatically compute bandwidths based on the data. + If False, bandwidths must be set manually. Defaults to True. + displace_t: Whether to shift RTs by the t parameter from metadata. + Only works if all trials have the same t value. Defaults to False. + + Raises: + ------- + AssertionError: If displace_t is True but metadata contains multiple t values. + """ self.simulator_info = simulator_data["metadata"] if displace_t: