Skip to content

Commit

Permalink
Merge branch 'master' into ikrommyd/error-if-name-exists-in-analysis-…
Browse files Browse the repository at this point in the history
…tools
  • Loading branch information
ikrommyd authored Feb 12, 2025
2 parents 6e903be + 3d051d9 commit c1cf851
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 15 deletions.
35 changes: 31 additions & 4 deletions binder/mltools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"# Running inference tools\n",
"\n",
"As machine learning (ML) becomes more popular in HEP analysis, `coffea` also\n",
"provide tools to assist with using ML tools within the coffea framework. For\n",
"provides tools to assist with using ML tools within the `coffea` framework. For\n",
"training and validation, you would likely need custom data mangling tools to\n",
"convert HEP data formats ([NanoAOD][nanoaod], [PFNano][pfnano]) to a format that\n",
"best interfaces with the ML tool of choice, as for training and validation, you\n",
Expand Down Expand Up @@ -756,15 +756,15 @@
"source": [
"## Additional comments on common `prepare_awkward` patterns\n",
"\n",
"The key requirement of all wrapper classes in `ml_tools` pacakge, is that to convert\n",
"The key requirement of all wrapper classes in the `ml_tools` package is to convert\n",
"awkward arrays into `numpy`-compatible formats using just `awkward` related tools, \n",
"which ensures that no eager data conversion is performed on dask arrays. Below are\n",
"some common patterns that are useful when defining a user-level class.\n",
"\n",
"### Casting multiple fields of a collection to a separate axis\n",
"\n",
"Given our collection of particles of length $N$, our tool is interested in just a \n",
"sub-set of fields is to be represented as an $N\\time M$ array. You can do acheive this \n",
"sub-set of fields that is to be represented as an $N\\times M$ array. You can achieve this \n",
"using just `ak.concatenate` and dimension expansion with `np.newaxis`:\n",
"\n",
"```python\n",
Expand Down Expand Up @@ -801,7 +801,34 @@
"```python\n",
"part_padded = ak.flatten(part_padded)\n",
"part_padded = ak.unflatten(part_padded, 128) # Now this is a Nx128 array\n",
"```"
"```\n",
"\n",
"### Length-zero arrays\n",
"\n",
"In HEP analysis, a common routine will have you working with length-zero arrays in\n",
"individual chunks (e.g. when running the primary selection workflow on background\n",
"events). In these cases, you need to make sure the inference library that you are \n",
"using behaves as expected when processing length-zero arrays. Ideally, all the\n",
"upstream libraries should handle length-zero arrays correctly, but in the edge cases\n",
"where your model uses more exotic functions that cause issues, the wrappers in the\n",
"`ml_tools` package have some mechanisms that can help handle such situations:\n",
"\n",
"- `tf_wrapper/tensorflow`: the `skip_length_zero` flag can be passed to the `tf_wrapper` \n",
" constructor. When this is set to `True` and a length-0 array is detected as the input, \n",
" the wrapper will generate a length 0 numpy array instead of attempting to pass the \n",
" input for inference.\n",
"- `torch_wrapper/pytorch`: as the output format of `pytorch` models is difficult to \n",
" detect at runtime without additional input, if you need to skip length-zero\n",
" inputs, the analyst must provide the shape of the output in the `expected_output_shape` \n",
" argument of the `torch_wrapper` constructor. Notice that the shape should be in \n",
" the format of the `numpy.ndarray.shape` output of nominal return values, with the first\n",
" entry replaced by `None` to indicate arbitrary length.\n",
"- `triton_wrapper/triton`: Length-zero is handled automatically by the internal batching\n",
" process. No additional user input is required.\n",
"- `xgboost_wrapper/xgboost`: Length-zero should always be handled by the underlying \n",
" `xgboost` library. No additional user input is required.\n",
"\n",
"\n"
]
}
],
Expand Down
20 changes: 16 additions & 4 deletions src/coffea/ml_tools/tf_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class tf_wrapper(nonserializable_attribute, numpy_call_wrapper):
Wrapper for running tensorflow inference with awkward/dask-awkward inputs.
"""

def __init__(self, tf_model: str):
def __init__(self, tf_model: str, skip_length_zero: bool = False):
"""
As models are not guaranteed to be directly serializable, the user will
need to pass the model as files saved using the `tf.keras.save` method
Expand All @@ -27,9 +27,17 @@ def __init__(self, tf_model: str):
[1]
https://www.tensorflow.org/guide/keras/serialization_and_saving#saving
Parameters ----------
- tf_model: Path to the tensorflow model file to load
Parameters
----------
tf_model:
Path to the tensorflow model file for computation
skip_length_zero:
Generate a default length-0 numpy array if the input array is
detected to be length-0, instead of passing it into the tensorflow
model. This option should only be used if the model uses tensorflow
functions that do not properly implement behaviors for length-0
inputs.
"""
if _tf_import_error is not None:
warnings.warn(
Expand All @@ -43,6 +51,7 @@ def __init__(self, tf_model: str):

nonserializable_attribute.__init__(self, ["model"])
self.tf_model = tf_model
self.skip_length_zero = skip_length_zero

def _create_model(self):
"""
Expand Down Expand Up @@ -94,6 +103,9 @@ def numpy_call(self, *args: numpy.array, **kwargs: numpy.array) -> numpy.array:
[1]
https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
"""
first_arg = args[0] if len(args) else next(iter(kwargs.values()))
if len(first_arg) == 0 and self.skip_length_zero:
return numpy.zeros(shape=(0, *self.model.output_shape[1:]))
args = [
(
tensorflow.convert_to_tensor(arr)
Expand Down
25 changes: 24 additions & 1 deletion src/coffea/ml_tools/torch_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import Optional, Tuple

import numpy

Expand Down Expand Up @@ -43,9 +44,19 @@ class torch_wrapper(nonserializable_attribute, numpy_call_wrapper):
----------
torch_jit: str
Path to the TorchScript file to load
expected_output_shape: tuple(int)
A tuple representing the expected shape of the torch model return.
In case a length-0 input is detected and this value is not None,
the wrapper will return a length-0 numpy array of the same shape,
as there are methods in torch that are incompatible with length-0
inputs. Note that the leading entry in shape should be None to
indicate that the outer-most dimension is arbitrary. It will always
be ignored in the operation.
"""

def __init__(self, torch_jit: str):
def __init__(
self, torch_jit: str, expected_output_shape: Optional[Tuple[int]] = None
):
if _torch_import_error is not None:
warnings.warn(
"Users should make sure the torch package is installed before proceeding!\n"
Expand All @@ -58,6 +69,14 @@ def __init__(self, torch_jit: str):

nonserializable_attribute.__init__(self, ["model", "device"])
self.torch_jit = torch_jit
self.expected_output_shape = expected_output_shape
# The length-0 fallback only uses the entries after the first one
# (`expected_output_shape[1:]`), so a concrete leading dimension has no
# effect. Warn the user instead of silently discarding it.
if (
    self.expected_output_shape is not None
    and self.expected_output_shape[0] is not None
):
    warnings.warn(
        "The outermost dimension will be ignored for fallback situations, set leading dimension to None to avoid seeing this."
    )

def _create_device(self):
"""
Expand Down Expand Up @@ -90,6 +109,10 @@ def numpy_call(self, *args: numpy.array, **kwargs: numpy.array) -> numpy.array:
Evaluating the numpy inputs via the model. Returning the results also as
a numpy array.
"""
first_arg = args[0] if len(args) else next(iter(kwargs.values()))
if len(first_arg) == 0 and self.expected_output_shape is not None:
return numpy.zeros(shape=(0, *self.expected_output_shape[1:]))

args = [
(
torch.from_numpy(arr)
Expand Down
23 changes: 18 additions & 5 deletions src/coffea/ml_tools/triton_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,14 @@ def _create_model_inputs(self) -> Dict[str, Dict]:
for x in self.model_metadata["inputs"]
}

def _create_model_outputs(self) -> List[int]:
"""Getting a list of names of possible outputs"""
return [x["name"] for x in self.model_metadata["outputs"]]
def _create_model_outputs(self) -> Dict[str, Dict]:
"""
Extracting the model output data format.

Builds a dictionary keyed by the name of each entry listed in
`self.model_metadata["outputs"]`. Each value is itself a dictionary
whose "shape" entry holds that output's shape converted to a tuple of
ints.
"""
return {
x["name"]: {"shape": tuple(int(i) for i in x["shape"])}
for x in self.model_metadata["outputs"]
}

@property
def batch_size(self) -> int:
Expand All @@ -156,8 +161,7 @@ def batch_size(self) -> int:
self._batch_size = model_config["max_batch_size"]
else:
warnings.warn(
f"Batch size not set by model! Setting to default value {self.batch_size_fallback}. "
"Contact model maintainer to check if this is expected",
f"Batch size not set by model! Setting to default value {self.batch_size_fallback}. Contact model maintainer to check if this is expected",
UserWarning,
)
self._batch_size = self.batch_size_fallback
Expand Down Expand Up @@ -322,4 +326,13 @@ def _get_infer_shape(name):
output[o] = numpy.concatenate(
(output[o], request.as_numpy(o)), axis=0
)

if (
output is None
): # Input was a length-0, so we should generate the length-0 outputs with correct dimension
return {
o: numpy.zeros(shape=(0, *self.model_outputs[o]["shape"][1:]))
for o in output_list
}

return {k: v[:orig_len] for k, v in output.items()}
46 changes: 45 additions & 1 deletion tests/test_ml_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def prepare_awkward(self, output_list, jets):
columns = set(list(dak.necessary_columns(dak_res).values())[0])
assert columns == expected_columns

# Length 0 tests
ak_res = tw(["output"], ak_jets[ak_jets.eta < 0])
dak_res = tw(["output"], dak_jets[dak_jets.eta < 0])
for k in ak_res.keys():
assert len(ak_res[k]) == 0 and len(dak_res[k].compute()) == 0

client.close()


Expand All @@ -140,7 +146,6 @@ def prepare_awkward(self, jets):

tw = torch_wrapper_test("tests/samples/pn_demo.pt")
ak_jets, dak_jets = prepare_jets_array(njets=256)

ak_res = tw(ak_jets)
dak_res = tw(dak_jets)

Expand All @@ -156,6 +161,15 @@ def prepare_awkward(self, jets):
}
columns = set(list(dak.necessary_columns(dak_res).values())[0])
assert columns == expected_columns

# Length-0 testing: with `expected_output_shape` given, the wrapper should
# return an empty result instead of passing the empty input to the model.
tw = torch_wrapper_test("tests/samples/pn_demo.pt", expected_output_shape=(None,))
ak_jets, dak_jets = prepare_jets_array(njets=256)
ak_jets = ak_jets[ak_jets.eta < -100]  # Mimicking a low efficiency selection
dak_jets = dak_jets[dak_jets.eta < -100]
ak_res, dak_res = tw(ak_jets), tw(dak_jets)
# Check the wrapper *outputs* for both the eager and dask paths; asserting on
# `ak_jets` would only re-check the (trivially empty) input selection.
assert len(ak_res) == 0 and len(dak_res.compute()) == 0

client.close()


Expand Down Expand Up @@ -207,6 +221,30 @@ def postprocess_awkward(self, ret, jets):
expected_columns = {"ncands"} | {f"pfcands.feat{i}" for i in range(1, 19)}
columns = set(list(dak.necessary_columns(dak_res).values())[0])
assert columns == expected_columns

# Length 0 testing. we cannot use the unflatten module in this case, so a
# minimal wrapper that forwards the array unchanged is used instead.
class tf_wrapper_length0_test(tf_wrapper):
    def prepare_awkward(self, arr):
        return [arr], {}

tfw_length0_tester = tf_wrapper_length0_test(
    "tests/samples/tf_model.keras", skip_length_zero=True
)

# Making an explicit shape: sanity check that the minimal wrapper agrees
# between the eager and dask paths on non-empty input.
arr = ak.from_numpy(np.random.random(size=(10, 64, 18)))
ak.to_parquet(arr, "tf_length10.parquet")
darr = dak.from_parquet("tf_length10.parquet")
ak_res = tfw_length0_tester(arr)
dak_res = tfw_length0_tester(darr)
assert np.all(np.isclose(ak_res, dak_res.compute()))

# Reducing to length 0: with `skip_length_zero=True` the wrapper must return
# empty results for both the eager and dask paths.
arr = ak.from_numpy(np.zeros(shape=(0, 64, 18)))
ak.to_parquet(arr, "tf_length0.parquet")
darr = dak.from_parquet("tf_length0.parquet")
ak_res = tfw_length0_tester(arr)
dak_res = tfw_length0_tester(darr)
assert len(ak_res) == 0 and len(dak_res.compute()) == 0

client.close()


Expand Down Expand Up @@ -244,4 +282,10 @@ def prepare_awkward(self, events):
# Should only load required columns
columns = set(list(dak.necessary_columns(dak_res).values())[0])
assert columns == set(feature_list)

# Length 0 testing, xgboost always handles 0-length arrays elegantly
ak_res = xgb_wrap(ak_events[ak_events.feat0 < 0])
dak_res = xgb_wrap(dak_events[dak_events.feat0 < 0])
assert len(ak_res) == 0 and len(dak_res.compute()) == 0

client.close()

0 comments on commit c1cf851

Please sign in to comment.