diff --git a/src/tensorwaves/data/generate.py b/src/tensorwaves/data/generate.py index 96cc1ae80..c19882b1a 100644 --- a/src/tensorwaves/data/generate.py +++ b/src/tensorwaves/data/generate.py @@ -4,6 +4,7 @@ from typing import Callable, Optional import numpy as np +from expertsystem.amplitude.data import MomentumPool from expertsystem.amplitude.kinematics import HelicityKinematics, ReactionInfo from tqdm import tqdm @@ -36,7 +37,7 @@ def _generate_data_bunch( np_phsp_sample = np.array(phsp_sample.values()) np_phsp_sample = np_phsp_sample.transpose(1, 0, 2) - return (np_phsp_sample[weights * intensities > uniform_randoms], maxvalue) + return np_phsp_sample[weights * intensities > uniform_randoms], maxvalue def generate_data( @@ -48,7 +49,7 @@ def generate_data( ] = TFPhaseSpaceGenerator, random_generator: Optional[UniformRealNumberGenerator] = None, bunch_size: int = 50000, -) -> np.ndarray: +) -> MomentumPool: """Facade function for creating data samples based on an intensities. Args: @@ -101,7 +102,9 @@ def generate_data( events = bunch progress_bar.update() progress_bar.close() - return events[0:size].transpose(1, 0, 2) + events = events[0:size].transpose(1, 0, 2) + pos_to_state_id = dict(enumerate(kinematics.reaction_info.final_state)) + return MomentumPool({pos_to_state_id[i]: events[i] for i in events}) def generate_phsp( @@ -112,7 +115,7 @@ def generate_phsp( ] = TFPhaseSpaceGenerator, random_generator: Optional[UniformRealNumberGenerator] = None, bunch_size: int = 50000, -) -> np.ndarray: +) -> MomentumPool: """Facade function for creating (unweighted) phase space samples. Args: @@ -157,4 +160,6 @@ def generate_phsp( events = bunch progress_bar.update() progress_bar.close() - return events[0:size].transpose(1, 0, 2) + events = events[0:size].transpose(1, 0, 2) + pos_to_state_id = dict(enumerate(kinematics.reaction_info.final_state)) + return MomentumPool({pos_to_state_id[i]: events[i] for i in events}) diff --git a/src/tensorwaves/data/tf_phasespace.py b/src/tensorwaves/data/tf_phasespace.py index 9c973b9a4..c473fded2 100644 --- a/src/tensorwaves/data/tf_phasespace.py +++ b/src/tensorwaves/data/tf_phasespace.py @@ -1,11 +1,11 @@ """Phase space generation using tensorflow.""" -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import expertsystem.amplitude.kinematics as es -import numpy as np import phasespace import tensorflow as tf +from expertsystem.amplitude.data import MomentumPool, ScalarSequence from phasespace.random import get_rng from tensorwaves.interfaces import ( @@ -30,7 +30,7 @@ def __init__(self, reaction_info: es.ReactionInfo) -> None: def generate( self, size: int, rng: UniformRealNumberGenerator - ) -> Tuple[Dict[int, np.ndarray], np.ndarray]: + ) -> Tuple[MomentumPool, ScalarSequence]: if not isinstance(rng, TFUniformRealNumberGenerator): raise TypeError( f"{TFPhaseSpaceGenerator.__name__} requires a " @@ -40,11 +40,13 @@ def generate( weights, particles = self.phsp_gen.generate( n_events=size, seed=rng.generator ) - momentum_pool = { - int(label): momenta.numpy().T - for label, momenta in particles.items() - } - return momentum_pool, weights.numpy() + momentum_pool = MomentumPool( + { + int(label): momenta.numpy().T + for label, momenta in particles.items() + } + ) + return momentum_pool, ScalarSequence(weights.numpy()) class TFUniformRealNumberGenerator(UniformRealNumberGenerator): @@ -56,13 +58,15 @@ def __init__(self, seed: Optional[float] = None): def __call__( self, size: int, min_value: float = 0.0, max_value: float = 1.0 - ) -> np.ndarray: - return self.generator.uniform( - shape=[size], - minval=min_value, - maxval=max_value, - dtype=self.dtype, - ).numpy() + ) -> ScalarSequence: + return ScalarSequence( + self.generator.uniform( + shape=[size], + minval=min_value, + maxval=max_value, + dtype=self.dtype, + ).numpy() + ) @property def seed(self) -> Optional[float]: diff --git a/src/tensorwaves/estimator.py b/src/tensorwaves/estimator.py index 717fbeb57..046545d08 100644 --- a/src/tensorwaves/estimator.py +++ b/src/tensorwaves/estimator.py @@ -2,21 +2,23 @@ All estimators have to implement the `~.interfaces.Estimator` interface. """ -from typing import Callable, Dict, List, Union +from typing import Callable, Dict, List, Mapping, Union + +from expertsystem.amplitude.data import ScalarSequence from tensorwaves.interfaces import Estimator, Model from tensorwaves.physics.amplitude import get_backend_modules def gradient_creator( - function: Callable[[Dict[str, Union[float, complex]]], float], + function: Callable[[Mapping[str, Union[float, complex]]], float], backend: Union[str, tuple, dict], ) -> Callable[ - [Dict[str, Union[float, complex]]], Dict[str, Union[float, complex]] + [Mapping[str, Union[float, complex]]], Dict[str, Union[float, complex]] ]: # pylint: disable=import-outside-toplevel def not_implemented( - parameters: Dict[str, Union[float, complex]] + parameters: Mapping[str, Union[float, complex]] ) -> Dict[str, Union[float, complex]]: raise NotImplementedError("Gradient not implemented.") @@ -50,8 +52,8 @@ class SympyUnbinnedNLL( # pylint: disable=too-many-instance-attributes def __init__( self, model: Model, - dataset: dict, - phsp_dataset: dict, + dataset: Mapping[str, Union[ScalarSequence, complex, float]], + phsp_dataset: Mapping[str, Union[ScalarSequence, complex, float]], phsp_volume: float = 1.0, backend: Union[str, tuple, dict] = "numpy", ) -> None: @@ -77,14 +79,14 @@ def find_function_in_backend(name: str) -> Callable: self.__phsp_volume = phsp_volume - self.__data_args = [] - self.__phsp_args = [] + self.__data_args: Dict[str, Union[ScalarSequence, complex, float]] = {} + self.__phsp_args: Dict[str, Union[ScalarSequence, complex, float]] = {} self.__parameter_index_mapping: Dict[str, int] = {} for i, var_name in enumerate(model.variables): if var_name in dataset and var_name in phsp_dataset: - self.__data_args.append(dataset[var_name]) - self.__phsp_args.append(phsp_dataset[var_name]) + self.__data_args[var_name] = dataset[var_name] + self.__phsp_args[var_name] = phsp_dataset[var_name] elif var_name in dataset: raise ValueError( f"Datasets do not match! {var_name} exists in dataset but " @@ -96,35 +98,36 @@ def find_function_in_backend(name: str) -> Callable: "dataset but not in dataset." ) else: - self.__data_args.append(model.parameters[var_name]) - self.__phsp_args.append(model.parameters[var_name]) + self.__data_args[var_name] = model.parameters[var_name] + self.__phsp_args[var_name] = model.parameters[var_name] self.__parameter_index_mapping[var_name] = i - def __call__(self, parameters: Dict[str, Union[float, complex]]) -> float: + def __call__( + self, parameters: Mapping[str, Union[float, complex]] + ) -> float: self.__update_parameters(parameters) - bare_intensities = self.__bare_model(*self.__data_args) + bare_intensities = self.__bare_model(self.__data_args) normalization_factor = 1.0 / ( self.__phsp_volume - * self.__mean_function(self.__bare_model(*self.__phsp_args)) + * self.__mean_function(self.__bare_model(self.__phsp_args)) ) likelihoods = normalization_factor * bare_intensities return -self.__sum_function(self.__log_function(likelihoods)) def __update_parameters( - self, parameters: Dict[str, Union[float, complex]] + self, parameters: Mapping[str, Union[float, complex]] ) -> None: for par_name, value in parameters.items(): if par_name in self.__parameter_index_mapping: - index = self.__parameter_index_mapping[par_name] - self.__data_args[index] = value - self.__phsp_args[index] = value + self.__data_args[par_name] = value + self.__phsp_args[par_name] = value @property def parameters(self) -> List[str]: return list(self.__parameter_index_mapping) def gradient( - self, parameters: Dict[str, Union[float, complex]] + self, parameters: Mapping[str, Union[float, complex]] ) -> Dict[str, Union[float, complex]]: return self.__gradient(parameters) diff --git a/src/tensorwaves/interfaces.py b/src/tensorwaves/interfaces.py index c3ca9ffa6..f7f377ff8 100644 --- a/src/tensorwaves/interfaces.py +++ b/src/tensorwaves/interfaces.py @@ -5,22 +5,18 @@ Any, Dict, FrozenSet, - Generic, Iterable, + Mapping, Optional, Protocol, Tuple, - TypeVar, Union, ) -import numpy as np +from expertsystem.amplitude.data import DataSet, MomentumPool, ScalarSequence -DataType = TypeVar("DataType") -"""Type of the data that is returned by `.Function.__call__`.""" - -class Function(Protocol, Generic[DataType]): +class Function(Protocol): """Interface of a callable function. The parameters of the model are separated from the domain variables. This @@ -32,7 +28,9 @@ class Function(Protocol, Generic[DataType]): is to facilitate the events when parameters have changed. """ - def __call__(self, dataset: Dict[str, DataType]) -> DataType: + def __call__( + self, dataset: Mapping[str, Union[ScalarSequence, complex, float]] + ) -> ScalarSequence: """Evaluate the function. Args: @@ -60,13 +58,13 @@ def lambdify(self, backend: Union[str, tuple, dict]) -> Function: """ @abstractmethod - def performance_optimize(self, fix_inputs: Dict[str, Any]) -> "Model": + def performance_optimize(self, fix_inputs: DataSet) -> "Model": """Create a performance optimized model, based on fixed inputs.""" @property @abstractmethod def parameters(self) -> Dict[str, Union[float, complex]]: - """Get `dict` of parameters.""" + """Get mapping of parameters to suggested initial values.""" @property @abstractmethod @@ -78,7 +76,9 @@ class Estimator(ABC): """Estimator for discrepancy model and data.""" @abstractmethod - def __call__(self, parameters: Dict[str, Union[float, complex]]) -> float: + def __call__( + self, parameters: Mapping[str, Union[float, complex]] + ) -> float: """Evaluate discrepancy.""" @property @@ -88,7 +88,7 @@ def parameters(self) -> Iterable[str]: @abstractmethod def gradient( - self, parameters: Dict[str, Union[float, complex]] + self, parameters: Mapping[str, Union[float, complex]] ) -> Dict[str, Union[float, complex]]: """Calculate gradient for given parameter mapping.""" @@ -97,11 +97,11 @@ class Kinematics(ABC): """Abstract interface for computation of kinematic variables.""" @abstractmethod - def convert(self, events: dict) -> dict: + def convert(self, events: MomentumPool) -> DataSet: """Convert a set of momentum tuples (events) to kinematic variables.""" @abstractmethod - def is_within_phase_space(self, events: dict) -> Tuple[bool]: + def is_within_phase_space(self, events: MomentumPool) -> Tuple[bool]: """Check which events lie within phase space.""" @property @@ -114,7 +114,11 @@ class Optimizer(ABC): """Optimize a fit model to a data set.""" @abstractmethod - def optimize(self, estimator: Estimator, initial_parameters: dict) -> dict: + def optimize( + self, + estimator: Estimator, + initial_parameters: Mapping[str, Union[float, complex]], + ) -> Dict[str, Any]: """Execute optimization.""" @@ -124,7 +128,7 @@ class UniformRealNumberGenerator(ABC): @abstractmethod def __call__( self, size: int, min_value: float = 0.0, max_value: float = 1.0 - ) -> Union[float, list]: + ) -> ScalarSequence: """Generate random floats in the range from [min_value,max_value).""" @property # type: ignore @@ -144,7 +148,7 @@ class PhaseSpaceGenerator(ABC): @abstractmethod def generate( self, size: int, rng: UniformRealNumberGenerator - ) -> Tuple[Dict[int, np.ndarray], np.ndarray]: + ) -> Tuple[MomentumPool, ScalarSequence]: """Generate phase space sample. Returns a `tuple` of a mapping of final state IDs to `numpy.array` s diff --git a/src/tensorwaves/optimizer/minuit.py b/src/tensorwaves/optimizer/minuit.py index 1eda638ed..37bec1120 100644 --- a/src/tensorwaves/optimizer/minuit.py +++ b/src/tensorwaves/optimizer/minuit.py @@ -5,7 +5,7 @@ import logging import time from datetime import datetime -from typing import Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Mapping, Optional, Union from iminuit import Minuit from tqdm import tqdm @@ -16,7 +16,9 @@ class ParameterFlattener: - def __init__(self, parameters: Dict[str, Union[float, complex]]) -> None: + def __init__( + self, parameters: Mapping[str, Union[float, complex]] + ) -> None: self.__real_imag_to_complex_name = {} self.__complex_to_real_imag_name = {} for name, val in parameters.items(): @@ -46,7 +48,7 @@ def unflatten( return parameters def flatten( - self, parameters: Dict[str, Union[float, complex]] + self, parameters: Mapping[str, Union[float, complex]] ) -> Dict[str, float]: flattened_parameters = {} for par_name, value in parameters.items(): @@ -80,8 +82,8 @@ def __init__( def optimize( # pylint: disable=too-many-locals self, estimator: Estimator, - initial_parameters: Dict[str, Union[complex, float]], - ) -> dict: + initial_parameters: Mapping[str, Union[complex, float]], + ) -> Dict[str, Any]: parameter_handler = ParameterFlattener(initial_parameters) flattened_parameters = parameter_handler.flatten(initial_parameters) diff --git a/src/tensorwaves/physics/amplitude.py b/src/tensorwaves/physics/amplitude.py index 6b9093901..20c97168b 100644 --- a/src/tensorwaves/physics/amplitude.py +++ b/src/tensorwaves/physics/amplitude.py @@ -1,8 +1,9 @@ """`.Function` Adapter for `sympy`-based models.""" -from typing import Any, Callable, Dict, FrozenSet, Optional, Tuple, Union +from typing import Callable, Dict, FrozenSet, Mapping, Optional, Tuple, Union import sympy as sp +from expertsystem.amplitude.data import DataSet, ScalarSequence from tensorwaves.interfaces import Function, Model @@ -68,24 +69,26 @@ def __init__( if symbol.name not in self.parameters } ) + if not all(map(lambda p: isinstance(p, sp.Symbol), parameters)): + raise TypeError(f"Not all parameters are of type {sp.Symbol}") def lambdify(self, backend: Union[str, tuple, dict]) -> Function: """Lambdify the model using `~sympy.utilities.lambdify.lambdify`.""" # pylint: disable=import-outside-toplevel - variables = tuple(self.__expression.free_symbols) + all_symbols = tuple(self.__expression.free_symbols) def jax_lambdify() -> Callable: from jax import jit return jit( sp.lambdify( - variables, + all_symbols, self.__expression, modules=backend_modules, ) ) - callable_model: Optional[Callable] = None + callable_model: Optional[Function] = None if isinstance(backend, str): if backend == "jax": callable_model = jax_lambdify() @@ -95,7 +98,7 @@ def jax_lambdify() -> Callable: callable_model = jit( sp.lambdify( - variables, + all_symbols, self.__expression, modules="numpy", ), @@ -107,7 +110,7 @@ def jax_lambdify() -> Callable: if callable_model is None: # default backend_modules = get_backend_modules(backend) callable_model = sp.lambdify( - variables, + all_symbols, self.__expression, modules=backend_modules, ) @@ -115,22 +118,22 @@ def jax_lambdify() -> Callable: raise ValueError(f"Failed to lambdify model for backend {backend}") input_variable_order: Tuple[str, ...] = tuple( - x.name for x in self.__expression.free_symbols + x.name for x in all_symbols ) - def function_wrapper(dataset: Dict[str, Any]) -> Any: + def function_wrapper( + dataset: Mapping[str, Union[ScalarSequence, complex, float]] + ) -> ScalarSequence: return callable_model( # type: ignore - *( - dataset[var_name] - if var_name in dataset - else self.parameters[var_name] + { + var_name: dataset.get(var_name, self.parameters[var_name]) for var_name in input_variable_order - ) + } ) return function_wrapper - def performance_optimize(self, fix_inputs: Dict[str, Any]) -> "Model": + def performance_optimize(self, fix_inputs: DataSet) -> "Model": raise NotImplementedError @property diff --git a/tests/conftest.py b/tests/conftest.py index 086a36248..753722541 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ # pylint: disable=redefined-outer-name +from typing import Any, Dict + import expertsystem as es -import numpy as np import pytest +from expertsystem.amplitude.data import DataSet, MomentumPool from expertsystem.amplitude.dynamics.builder import ( create_relativistic_breit_wigner_with_ff, ) @@ -61,12 +63,14 @@ def kinematics() -> HelicityKinematics: @pytest.fixture(scope="session") -def phsp_sample(kinematics: HelicityKinematics) -> np.ndarray: +def phsp_sample(kinematics: HelicityKinematics) -> MomentumPool: return generate_phsp(N_PHSP_EVENTS, kinematics, random_generator=RNG) @pytest.fixture(scope="session") -def phsp_set(kinematics: HelicityKinematics, phsp_sample: np.ndarray) -> dict: +def phsp_set( + kinematics: HelicityKinematics, phsp_sample: MomentumPool +) -> DataSet: return kinematics.convert(phsp_sample) @@ -74,7 +78,7 @@ def phsp_set(kinematics: HelicityKinematics, phsp_sample: np.ndarray) -> dict: def data_sample( kinematics: HelicityKinematics, helicity_model: SympyModel, -) -> np.ndarray: +) -> MomentumPool: callable_model = helicity_model.lambdify(backend="numpy") return generate_data( N_DATA_EVENTS, kinematics, callable_model, random_generator=RNG @@ -84,14 +88,14 @@ def data_sample( @pytest.fixture(scope="session") def data_set( kinematics: HelicityKinematics, - data_sample: np.ndarray, -) -> dict: + data_sample: MomentumPool, +) -> DataSet: return kinematics.convert(data_sample) @pytest.fixture(scope="session") def estimator( - helicity_model: SympyModel, data_set: dict, phsp_set: dict + helicity_model: SympyModel, data_set: DataSet, phsp_set: DataSet ) -> SympyUnbinnedNLL: return SympyUnbinnedNLL( helicity_model, @@ -101,7 +105,7 @@ def estimator( @pytest.fixture(scope="session") -def free_parameters() -> dict: +def free_parameters() -> Dict[str, float]: return { "Gamma_f(0)(500)": 0.3, "m_f(0)(980)": 1, @@ -110,8 +114,10 @@ def free_parameters() -> dict: @pytest.fixture(scope="session") def fit_result( - estimator: SympyUnbinnedNLL, free_parameters: dict, output_dir: str -) -> dict: + estimator: SympyUnbinnedNLL, + free_parameters: Dict[str, float], + output_dir: str, +) -> Dict[str, Any]: optimizer = Minuit2( callback=CallbackList( [ diff --git a/tests/data/test_generate.py b/tests/data/test_generate.py index c23e41cee..a32e1dcc2 100644 --- a/tests/data/test_generate.py +++ b/tests/data/test_generate.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from expertsystem.amplitude.data import MomentumPool from expertsystem.amplitude.kinematics import HelicityKinematics, ReactionInfo from tensorwaves.data.generate import generate_phsp @@ -44,80 +45,86 @@ def test_generate_data(data_sample: np.ndarray): ( "J/psi(1S)", ("pi0", "pi0", "pi0"), - [ - [ - [0.799667989, 0.159823862, 0.156340839, 0.841233472], - [-0.364360112, -0.371962329, 0.347228344, 0.640234742], - [0.403805561, 0.417294074, -0.208401449, 0.631540320], - ], - [ - [-0.053789754, -0.535237707, -0.947232044, 1.097652050], - [1.168326711, -0.060296302, -0.805136016, 1.426564296], - [0.014812643, 0.081738919, 1.233338364, 1.243480165], - ], - [ - [-0.745878234, 0.375413844, 0.790891204, 1.158014477], - [-0.803966599, 0.432258632, 0.457907671, 1.030100961], - [-0.418618204, -0.499032994, -1.024936914, 1.221879513], - ], - ], + MomentumPool( + { + 0: [ + [0.841233472, 0.799667989, 0.159823862, 0.156340839], + [0.640234742, -0.364360112, -0.371962329, 0.347228344], + [0.631540320, 0.403805561, 0.417294074, -0.208401449], + ], + 1: [ + [1.09765205, -0.05378975, -0.53523771, -0.94723204], + [1.426564296, 1.168326711, -0.060296302, -0.805136016], + [1.243480165, 0.014812643, 0.081738919, 1.233338364], + ], + 2: [ + [1.158014477, -0.745878234, 0.375413844, 0.790891204], + [1.030100961, -0.803966599, 0.432258632, 0.457907671], + [1.22187951, -0.41861820, -0.49903210, -1.02493691], + ], + } + ), ), ( ("J/psi(1S)"), ("pi0", "pi0", "pi0", "gamma"), - [ - [ - [0.037458949, 0.339629143, -0.369297399, 0.520913076], - [-0.569078090, 0.687702756, -0.760836072, 1.180624927], - [0.543652274, 0.220242315, -0.077206475, 0.606831154], - ], - [ - [0.130561009, 0.299006221, -0.012444727, 0.353305116], - [0.123009165, 0.057692537, 0.033979586, 0.194507152], - [0.224048290, -0.156048645, 0.130817046, 0.331482507], - ], - [ - [0.236609937, -0.366594420, 1.192296945, 1.276779728], - [0.571746863, -0.586304492, 1.051145223, 1.339317905], - [0.402982692, -0.697161285, 0.083274400, 0.820720580], - ], - [ - [-0.404629896, -0.272040943, -0.810554818, 0.945902078], - [-0.125677938, -0.159090801, -0.324288738, 0.382450013], - [-1.170683257, 0.632967615, -0.136884971, 1.337865758], - ], - ], + MomentumPool( + { + 0: [ + [0.520913076, 0.037458949, 0.339629143, -0.369297399], + [1.180624927, -0.569078090, 0.687702756, -0.760836072], + [0.606831154, 0.543652274, 0.220242315, -0.077206475], + ], + 1: [ + [0.353305116, 0.130561009, 0.299006221, -0.012444727], + [0.194507152, 0.123009165, 0.057692537, 0.033979586], + [0.331482507, 0.224048290, -0.156048645, 0.130817046], + ], + 2: [ + [1.276779728, 0.236609937, -0.366594420, 1.192296945], + [1.339317905, 0.571746863, -0.586304492, 1.051145223], + [0.820720580, 0.402982692, -0.697161285, 0.083274400], + ], + 3: [ + [0.945902080, -0.40462990, -0.27204094, -0.81055482], + [0.38245001, -0.12567794, -0.15909080, -0.32428874], + [1.337865758, -1.170683257, 0.632967615, -0.136884971], + ], + } + ), ), ( "J/psi(1S)", ("pi0", "pi0", "pi0", "pi0", "gamma"), - [ - [ - [0.715439409, -0.284844373, -0.623772405, 1.000150296], - [0.134562969, 0.189723778, 0.229578969, 0.353592342], - [0.655088513, -0.205095150, -0.222905673, 0.734241552], - ], - [ - [-0.062423993, 0.008278542, -0.516645045, 0.537685901], - [-0.075102421, -0.215361523, 0.351626927, 0.440319420], - [-0.569846157, -0.063070826, 0.199036046, 0.621720722], - ], - [ - [-0.190428491, -0.002167052, 0.540188288, 0.588463958], - [-0.114856586, -0.554777459, -0.515051054, 0.777474366], - [-0.120958419, 0.236101553, -0.455239823, 0.543908922], - ], - [ - [-0.286712460, -0.089479316, 0.393698133, 0.513251926], - [0.536198573, -0.215753382, -0.007385008, 0.593575359], - [-0.442948181, -0.261969339, 0.187557768, 0.564116725], - ], - [ - [-0.175874464, 0.368212199, 0.206531028, 0.457347916], - [-0.480802535, 0.796168585, -0.058769834, 0.931938511], - [0.478664245, 0.294033763, 0.291551681, 0.632912076], - ], - ], + MomentumPool( + { + 0: [ + [1.000150296, 0.715439409, -0.284844373, -0.623772405], + [0.353592342, 0.134562969, 0.189723778, 0.229578969], + [0.734241552, 0.655088513, -0.205095150, -0.222905673], + ], + 1: [ + [0.537685901, -0.062423993, 0.008278542, -0.516645045], + [0.440319420, -0.075102421, -0.215361523, 0.351626927], + [0.621720722, -0.569846157, -0.063070826, 0.199036046], + ], + 2: [ + [0.588463958, -0.190428491, -0.002167052, 0.540188288], + [0.77747437, -0.11485659, -0.55477746, -0.51505105], + [0.543908922, -0.120958419, 0.236101553, -0.455239823], + ], + 3: [ + [0.513251926, -0.286712460, -0.089479316, 0.393698133], + [0.593575359, 0.536198573, -0.215753382, -0.007385008], + [0.564116725, -0.442948181, -0.261969339, 0.187557768], + ], + 4: [ + [0.457347916, -0.175874464, 0.368212199, 0.206531028], + [0.931938511, -0.480802535, 0.796168585, -0.058769834], + [0.632912076, 0.478664245, 0.294033763, 0.291551681], + ], + } + ), ), ], ) @@ -142,5 +149,7 @@ def test_generate_phsp( sample_size = 3 rng = TFUniformRealNumberGenerator(seed=0) sample = generate_phsp(sample_size, kin, random_generator=rng) - assert sample.shape == (len(final_state_names), sample_size, 4) + assert len(sample) == len(final_state_names) + for four_momenta in sample.values(): + assert len(four_momenta) == sample_size assert pytest.approx(sample, abs=1e-8) == expected_sample diff --git a/tests/data/test_tf_phasespace.py b/tests/data/test_tf_phasespace.py index 7e0962175..a52fc950b 100644 --- a/tests/data/test_tf_phasespace.py +++ b/tests/data/test_tf_phasespace.py @@ -32,8 +32,8 @@ def test_generate_deterministic(pdg): phsp_generator = TFPhaseSpaceGenerator(reaction_info) four_momenta, weights = phsp_generator.generate(sample_size, rng) for values in four_momenta.values(): - assert values.shape == (sample_size, 4) - assert weights.shape == (sample_size,) + assert len(values) == sample_size + assert len(weights) == sample_size assert pytest.approx(four_momenta, abs=1e-6) == [ [ [0.357209, 0.251997, 0.244128, 0.705915], diff --git a/tests/test_estimator.py b/tests/test_estimator.py index 0cf232c37..5943d8721 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -1,10 +1,12 @@ # pylint: disable=invalid-name, redefined-outer-name import math +from typing import Dict, Mapping, Union import numpy as np import pytest import sympy as sp +from expertsystem.amplitude.data import ScalarSequence from tensorwaves.estimator import SympyUnbinnedNLL from tensorwaves.optimizer.minuit import Minuit2 @@ -58,7 +60,7 @@ def gaussian_sum( @pytest.fixture(scope="module") -def phsp_dataset(): +def phsp_dataset() -> Dict[str, np.ndarray]: rng = np.random.default_rng(12345) return {"x": rng.uniform(low=-2.0, high=5.0, size=10000)} @@ -142,7 +144,10 @@ def phsp_dataset(): ], ) def test_sympy_unbinned_nll( - model: SympyModel, dataset: dict, true_params: dict, phsp_dataset: dict + model: SympyModel, + dataset: Mapping[str, Union[ScalarSequence, complex, float]], + true_params: Dict[str, Union[complex, float]], + phsp_dataset: Mapping[str, Union[ScalarSequence, complex, float]], ): estimator = SympyUnbinnedNLL( model,