Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: error if name already exists in analysis_tools's Weights and PackedSelection #1274

Merged
merged 2 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/coffea/analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, size, storeIndividual=False):
self._modifiers = {}
self._weightStats = {}
self._storeIndividual = storeIndividual
self._names = []

@property
def weightStatistics(self):
Expand All @@ -127,6 +128,7 @@ def __add_eager(self, name, weight, weightUp, weightDown, shift):
weight.max(),
weight.size,
)
self._names.append(name)

def __add_delayed(self, name, weight, weightUp, weightDown, shift):
"""Add a new weight with delayed calculation"""
Expand All @@ -148,6 +150,7 @@ def __add_delayed(self, name, weight, weightUp, weightDown, shift):
"minw": dask_awkward.min(weight),
"maxw": dask_awkward.max(weight),
}
self._names.append(name)

def add(self, name, weight, weightUp=None, weightDown=None, shift=False):
"""Add a new weight
Expand All @@ -173,6 +176,8 @@ def add(self, name, weight, weightUp=None, weightDown=None, shift=False):

.. note:: ``weightUp`` and ``weightDown`` are assumed to be rvalue-like and may be modified in-place by this function
"""
if name in self._names:
raise ValueError(f"Weight '{name}' already exists")
if name.endswith("Up") or name.endswith("Down"):
raise ValueError(
"Avoid using 'Up' and 'Down' in weight names, instead pass appropriate shifts to add() call"
Expand Down Expand Up @@ -223,6 +228,7 @@ def __add_multivariation_eager(
weight.max(),
weight.size,
)
self._names.append(name)

def __add_multivariation_delayed(
self, name, weight, modifierNames, weightsUp, weightsDown, shift=False
Expand Down Expand Up @@ -258,6 +264,7 @@ def __add_multivariation_delayed(
"minw": dask_awkward.min(weight),
"maxw": dask_awkward.max(weight),
}
self._names.append(name)

def add_multivariation(
self, name, weight, modifierNames, weightsUp, weightsDown, shift=False
Expand Down Expand Up @@ -287,6 +294,8 @@ def add_multivariation(

.. note:: ``weightUp`` and ``weightDown`` are assumed to be rvalue-like and may be modified in-place by this function
"""
if name in self._names:
raise ValueError(f"Weight '{name}' already exists")
if name.endswith("Up") or name.endswith("Down"):
raise ValueError(
"Avoid using 'Up' and 'Down' in weight names, instead pass appropriate shifts to add() call"
Expand Down Expand Up @@ -1234,6 +1243,8 @@ def add(self, name, selection, fill_value=False):
fill_value : bool, optional
All masked entries will be filled as specified (default: ``False``)
"""
if name in self._names:
raise ValueError(f"Selection '{name}' already exists")
if isinstance(selection, dask.array.Array):
raise ValueError(
"Dask arrays are not supported, please convert them to dask_awkward.Array by using dask_awkward.from_dask_array()"
Expand Down
43 changes: 41 additions & 2 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_weights():
shift=True,
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add("test", scale_central, weightUp=scale_up, weightDown=scale_down)

var_names = weight.variations
expected_names = ["testShiftUp", "testShiftDown", "testUp", "testDown"]
for name in expected_names:
Expand Down Expand Up @@ -105,6 +108,9 @@ def test_weights_dak(optimization_enabled):
shift=True,
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add("test", scale_central, weightUp=scale_up, weightDown=scale_down)

var_names = weight.variations
expected_names = ["testShiftUp", "testShiftDown", "testUp", "testDown"]
for name in expected_names:
Expand Down Expand Up @@ -153,6 +159,15 @@ def test_weights_multivariation():
weightsDown=[scale_down, scale_down_2],
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add_multivariation(
"test",
scale_central,
modifierNames=["A", "B"],
weightsUp=[scale_up, scale_up_2],
weightsDown=[scale_down, scale_down_2],
)

var_names = weight.variations
expected_names = ["test_AUp", "test_ADown", "test_BUp", "test_BDown"]
for name in expected_names:
Expand Down Expand Up @@ -211,6 +226,15 @@ def test_weights_multivariation_dak(optimization_enabled):
weightsDown=[scale_down, scale_down_2],
)

with pytest.raises(ValueError, match="Weight 'test' already exists"):
weight.add_multivariation(
"test",
scale_central,
modifierNames=["A", "B"],
weightsUp=[scale_up, scale_up_2],
weightsDown=[scale_down, scale_down_2],
)

var_names = weight.variations
expected_names = ["test_AUp", "test_ADown", "test_BUp", "test_BDown"]
for name in expected_names:
Expand Down Expand Up @@ -253,6 +277,11 @@ def test_weights_partial():
weights.add("w1", w1)
weights.add("w2", w2)

with pytest.raises(ValueError, match="Weight 'w1' already exists"):
weights.add("w1", w1)
with pytest.raises(ValueError, match="Weight 'w2' already exists"):
weights.add("w2", w2)

test_exclude_none = weights.weight()
assert np.all(np.abs(test_exclude_none - w1 * w2) < 1e-6)

Expand Down Expand Up @@ -321,6 +350,11 @@ def test_weights_partial_dak(optimization_enabled):
weights.add("w1", w1)
weights.add("w2", w2)

with pytest.raises(ValueError, match="Weight 'w1' already exists"):
weights.add("w1", w1)
with pytest.raises(ValueError, match="Weight 'w2' already exists"):
weights.add("w2", w2)

test_exclude_none = weights.weight()
assert np.all(np.abs(test_exclude_none - w1 * w2).compute() < 1e-6)

Expand Down Expand Up @@ -397,6 +431,11 @@ def test_packed_selection_basic(dtype):
sel.add("fizz", fizz)
sel.add("buzz", buzz)

with pytest.raises(ValueError, match="Selection 'fizz' already exists"):
sel.add("fizz", fizz)
with pytest.raises(ValueError, match="Selection 'buzz' already exists"):
sel.add("buzz", buzz)

assert np.all(
sel.all()
== np.array(
Expand Down Expand Up @@ -449,7 +488,7 @@ def test_packed_selection_basic(dtype):
with pytest.raises(RuntimeError):
overpack = PackedSelection(dtype=dtype)
for i in range(65):
overpack.add("sel_%d", all_true)
overpack.add(f"sel_{i}", all_true)

with pytest.raises(
ValueError,
Expand Down Expand Up @@ -787,7 +826,7 @@ def test_packed_selection_basic_dak(optimization_enabled, dtype):
with pytest.raises(RuntimeError):
overpack = PackedSelection(dtype=dtype)
for i in range(65):
overpack.add("sel_%d", all_true)
overpack.add(f"sel_{i}", all_true)

with pytest.raises(
ValueError,
Expand Down
Loading