Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
gipert authored Jan 10, 2025
2 parents cf8b181 + cdca6ca commit 5ca3dff
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 48 deletions.
99 changes: 56 additions & 43 deletions src/pygama/evt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,22 @@
H5DataLoc = namedtuple(
"H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,)
)

DataInfo = namedtuple(
"DataInfo", ("raw", "tcm", "dsp", "hit", "evt"), defaults=5 * (None,)
)
DataInfo = namedtuple("DataInfo", ("raw", "tcm", "evt"), defaults=3 * (None,))

TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length"))


def make_files_config(data: dict):
if not isinstance(data, DataInfo):
if not isinstance(data, tuple):
if "raw" not in data:
data["raw"] = (None,)
if "tcm" not in data:
data["tcm"] = (None,)
if "evt" not in data:
data["evt"] = (None,)
DataInfo = namedtuple(
"DataInfo", tuple(data.keys()), defaults=len(data.keys()) * (None,)
)
return DataInfo(
*[
H5DataLoc(*data[tier]) if tier in data else H5DataLoc()
Expand Down Expand Up @@ -72,7 +78,7 @@ def find_parameters(
idx_ch,
field_list,
) -> dict:
"""Finds and returns parameters from `hit` and `dsp` tiers.
"""Finds and returns parameters from non `tcm`, `evt` tiers.
Parameters
----------
Expand All @@ -83,43 +89,38 @@ def find_parameters(
idx_ch
index array of entries to be read from datainfo.
field_list
list of tuples ``(tier, field)`` to be found in the `hit/dsp` tiers.
list of tuples ``(tier, field)`` to be found in any tier other than `tcm` and `evt`.
"""
f = make_files_config(datainfo)

# find fields in either dsp, hit
dsp_flds = [e[1] for e in field_list if e[0] == f.dsp.group]
hit_flds = [e[1] for e in field_list if e[0] == f.hit.group]
final_dict = {}

hit_dict, dsp_dict = {}, {}
for name, tier in f._asdict().items():
if name not in ["tcm", "evt"] and tier.file is not None: # skip other tables
keys = [
k.split("/")[-1]
for k in lh5.ls(tier.file, f"{ch.replace('/', '')}/{tier.group}/")
]
flds = [e[1] for e in field_list if e[0] == name and e[1] in keys]

if len(hit_flds) > 0:
hit_ak = lh5.read_as(
f"{ch.replace('/', '')}/{f.hit.group}/",
f.hit.file,
field_mask=hit_flds,
idx=idx_ch,
library="ak",
)
if len(flds) > 0:
tier_ak = lh5.read_as(
f"{ch.replace('/', '')}/{tier.group}/",
tier.file,
field_mask=flds,
idx=idx_ch,
library="ak",
)

hit_dict = dict(
zip([f"{f.hit.group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak))
)
tier_dict = dict(
zip(
[f"{name}_" + e for e in ak.fields(tier_ak)],
ak.unzip(tier_ak),
)
)
final_dict = final_dict | tier_dict

if len(dsp_flds) > 0:
dsp_ak = lh5.read_as(
f"{ch.replace('/', '')}/{f.dsp.group}/",
f.dsp.file,
field_mask=dsp_flds,
idx=idx_ch,
library="ak",
)

dsp_dict = dict(
zip([f"{f.dsp.group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak))
)

return hit_dict | dsp_dict
return final_dict


def get_data_at_channel(
Expand Down Expand Up @@ -178,10 +179,16 @@ def get_data_at_channel(

# evaluate expression
# strip the "evt." prefix (evt.foo -> foo); for other tiers except tcm/raw, turn the dot into an underscore (e.g. hit.foo -> hit_foo)

new_expr = expr
for name in f._asdict():
if name == "evt":
new_expr = new_expr.replace(f"{name}.", "")
elif name not in ["tcm", "raw"]:
new_expr = new_expr.replace(f"{name}.", f"{name}_")

res = eval(
expr.replace(f"{f.dsp.group}.", f"{f.dsp.group}_")
.replace(f"{f.hit.group}.", f"{f.hit.group}_")
.replace(f"{f.evt.group}.", ""),
new_expr,
var,
)

Expand Down Expand Up @@ -231,17 +238,23 @@ def get_mask_from_query(

# get sub evt based query condition if needed
if isinstance(query, str):
query_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", query)
query_lst = re.findall(
rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)", query
)
query_var = find_parameters(
datainfo=datainfo,
ch=ch,
idx_ch=idx_ch,
field_list=query_lst,
)

new_query = query
for name in f._asdict():
if name not in ["tcm", "evt"]:
new_query = new_query.replace(f"{name}.", f"{name}_")

limarr = eval(
query.replace(f"{f.dsp.group}.", f"{f.dsp.group}_").replace(
f"{f.hit.group}.", f"{f.hit.group}_"
),
new_query,
query_var,
)

Expand Down
4 changes: 2 additions & 2 deletions src/pygama/flow/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ def load_iterator(
)
for file in gb.groups.keys()
]
tb_names += [self.filedb.get_table_name(tier, tb)] * len(gb)
tb_names += [[self.filedb.get_table_name(tier, tb)]] * len(gb)
idx_list += [
list(entry_list.loc[i, f"{level}_idx"])
for i in gb.groups.values()
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def browse(
)
for file in gb.groups.keys()
]
tb_names += [self.filedb.get_table_name(tier, tb)] * len(gb)
tb_names += [[self.filedb.get_table_name(tier, tb)]] * len(gb)
idx_list += [
list(entry_list.loc[i, f"{parent}_idx"]) for i in gb.groups.values()
]
Expand Down
2 changes: 1 addition & 1 deletion src/pygama/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class NumbaPygamaDefaults(MutableMapping):
"""

def __init__(self) -> None:
self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=True)
self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=False)
self.fastmath: bool = getenv_bool("PYGAMA_FASTMATH", default=True)

def __getitem__(self, item: str) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def test_math_numba_defaults():
assert pgu.numba_math_defaults_kwargs.fastmath
assert pgu.numba_math_defaults_kwargs.parallel
assert not pgu.numba_math_defaults_kwargs.parallel

pgu.numba_math_defaults.fastmath = False
assert ~pgu.numba_math_defaults.fastmath
assert not pgu.numba_math_defaults.fastmath

0 comments on commit 5ca3dff

Please sign in to comment.