Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow scalar input for 't', 'E' and 'data' arguments of Container class #301

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions python/snewpy/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ def __init__(self,
data: :class:`astropy.Quantity`
3D array of the stored quantity, must have dimensions compatible with (flavor, time, energy)

flavor: list of :class:`snewpy.neutrino.Flavor`
flavor: list or a single value of :class:`snewpy.neutrino.Flavor`
array of flavors (should be ``len(flavor)==data.shape[0]``

time: array of :class:`astropy.Quantity`
time: :class:`astropy.Quantity`
sampling points in time (then ``len(time)==data.shape[1]``)
or time bin edges (then ``len(time)==data.shape[1]+1``)

energy: array of :class:`astropy.Quantity`
energy: :class:`astropy.Quantity`
sampling points in energy (then ``len(energy)=data.shape[2]``)
or energy bin edges (then ``len(energy)=data.shape[2]+1``)

Expand All @@ -119,10 +119,25 @@ def __init__(self,
if self.unit is not None:
#try to convert to the unit
data = data.to(self.unit)
self.array = data
self.flavor = np.sort(flavor)
self.time = time
self.energy = energy
#convert the input values to arrays if they are scalar
self.array = u.Quantity(data)
self.time = u.Quantity(time, ndmin=1)
self.energy = u.Quantity(energy, ndmin=1)
self.flavor = np.sort(np.array(flavor, ndmin=1))

Nf,Nt,Ne = len(self.flavor), len(self.time), len(self.energy)
#list all valid shapes of the input array
expected_shapes=[(nf,nt,ne) for nf in (Nf,Nf-1) for nt in (Nt,Nt-1) for ne in (Ne,Ne-1)]
#treat special case if data is 1d array
if self.array.ndim==1:
#try to reshape the array to expected shape
for expected_shape in expected_shapes:
if np.prod(expected_shape)==self.array.size:
self.array = self.array.reshape(expected_shape)
break
#validate the data array shape
if self.array.shape not in expected_shapes:
raise ValueError(f"Data array of shape {data.shape} is inconsistent with any valid shapes {expected_shapes}")

if integrable_axes is not None:
#store which axes can be integrated
Expand Down
8 changes: 7 additions & 1 deletion python/snewpy/test/test_flux_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,19 @@ def test_construction_succesfull(flavor, energy, time, unit):
assert Container[unit](data, flavor, time, energy)

@given(flavor=flavors, time=times, energy=energies, unit=units)
def test_construction_with_wrong_units_raises_ValueError(flavor, energy, time, unit):
def test_construction_with_wrong_units_raises_ConversionError(flavor, energy, time, unit):
data = np.ones([len(flavor),len(time), len(energy)])<<unit
with pytest.raises(u.UnitConversionError):
Container[unit*u.kg](data, flavor, time, energy)
with pytest.raises(u.UnitConversionError):
Container[unit](data*u.kg, flavor, time, energy)

@given(flavor=flavors, time=times, energy=energies, unit=units)
def test_construction_with_wrong_dimensions_raises_ValueError(flavor, energy, time, unit):
data = np.ones([len(flavor)+1,len(time), len(energy)])<<unit
with pytest.raises(ValueError):
Container[unit](data, flavor, time, energy)

@given(f=random_flux_containers())
def test_summation_over_flavor(f:Container):
fS = f.sum('flavor')
Expand Down
Loading