From 2337326392a4b9e847e961833defe1f070e08932 Mon Sep 17 00:00:00 2001 From: ggmarshall Date: Wed, 9 Oct 2024 13:44:27 +0200 Subject: [PATCH 1/7] make more generic to handle new tiers --- src/pygama/evt/utils.py | 91 ++++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/src/pygama/evt/utils.py b/src/pygama/evt/utils.py index 4f8391353..fa559f405 100644 --- a/src/pygama/evt/utils.py +++ b/src/pygama/evt/utils.py @@ -16,16 +16,16 @@ H5DataLoc = namedtuple( "H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,) ) - -DataInfo = namedtuple( - "DataInfo", ("raw", "tcm", "dsp", "hit", "evt"), defaults=5 * (None,) -) +DataInfo = namedtuple("DataInfo", ("raw"), defaults=1 * (None,)) TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length")) def make_files_config(data: dict): - if not isinstance(data, DataInfo): + if not isinstance(data, tuple): + DataInfo = namedtuple( + "DataInfo", tuple(data.keys()), defaults=len(data.keys()) * (None,) + ) return DataInfo( *[ H5DataLoc(*data[tier]) if tier in data else H5DataLoc() @@ -72,7 +72,7 @@ def find_parameters( idx_ch, field_list, ) -> dict: - """Finds and returns parameters from `hit` and `dsp` tiers. + """Finds and returns parameters from non `tcm`, `evt` tiers. Parameters ---------- @@ -83,43 +83,38 @@ def find_parameters( idx_ch index array of entries to be read from datainfo. field_list - list of tuples ``(tier, field)`` to be found in the `hit/dsp` tiers. + list of tuples ``(tier, field)`` to be found in non `tcm`, `evt` tiers. """ f = make_files_config(datainfo) - # find fields in either dsp, hit - dsp_flds = [e[1] for e in field_list if e[0] == f.dsp.group] - hit_flds = [e[1] for e in field_list if e[0] == f.hit.group] + final_dict = {} - hit_dict, dsp_dict = {}, {} + for name, tier in f._asdict().items(): + if name not in ["tcm", "evt"] and tier.file is not None: # skip other tables + keys = [ + k.split("/")[-1] + for k in lh5.ls(tier.file, f"{ch.replace('/', '')}/{tier.group}/") + ] + flds = [e[1] for e in field_list if e[0] == name and e[1] in keys] - if len(hit_flds) > 0: - hit_ak = lh5.read_as( - f"{ch.replace('/', '')}/{f.hit.group}/", - f.hit.file, - field_mask=hit_flds, - idx=idx_ch, - library="ak", - ) + if len(flds) > 0: + tier_ak = lh5.read_as( + f"{ch.replace('/', '')}/{tier.group}/", + tier.file, + field_mask=flds, + idx=idx_ch, + library="ak", + ) - hit_dict = dict( - zip([f"{f.hit.group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak)) - ) + tier_dict = dict( + zip( + [f"{name}_" + e for e in ak.fields(tier_ak)], + ak.unzip(tier_ak), + ) + ) + final_dict = final_dict | tier_dict - if len(dsp_flds) > 0: - dsp_ak = lh5.read_as( - f"{ch.replace('/', '')}/{f.dsp.group}/", - f.dsp.file, - field_mask=dsp_flds, - idx=idx_ch, - library="ak", - ) - - dsp_dict = dict( - zip([f"{f.dsp.group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak)) - ) - - return hit_dict | dsp_dict + return final_dict def get_data_at_channel( @@ -178,10 +173,14 @@ def get_data_at_channel( # evaluate expression # move tier+dots in expression to underscores (e.g. evt.foo -> evt_foo) + + new_expr = expr + for name in f._asdict(): + if name not in ["tcm", "raw"]: + new_expr = new_expr.replace(f"{name}.", f"{name}_") + res = eval( - expr.replace(f"{f.dsp.group}.", f"{f.dsp.group}_") - .replace(f"{f.hit.group}.", f"{f.hit.group}_") - .replace(f"{f.evt.group}.", ""), + new_expr, var, ) @@ -231,17 +230,23 @@ def get_mask_from_query( # get sub evt based query condition if needed if isinstance(query, str): - query_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", query) + query_lst = re.findall( + rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)", query + ) query_var = find_parameters( datainfo=datainfo, ch=ch, idx_ch=idx_ch, field_list=query_lst, ) + + new_query = query + for name in f._asdict(): + if name not in ["tcm", "evt"]: + new_query = new_query.replace(f"{name}.", f"{name}_") + limarr = eval( - query.replace(f"{f.dsp.group}.", f"{f.dsp.group}_").replace( - f"{f.hit.group}.", f"{f.hit.group}_" - ), + new_query, query_var, ) From 8e823c0e42c22c252bcb819a3a0ff616fcc159d6 Mon Sep 17 00:00:00 2001 From: ggmarshall Date: Wed, 9 Oct 2024 14:39:23 +0200 Subject: [PATCH 2/7] fix for tests --- src/pygama/evt/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pygama/evt/utils.py b/src/pygama/evt/utils.py index fa559f405..3644e141d 100644 --- a/src/pygama/evt/utils.py +++ b/src/pygama/evt/utils.py @@ -16,13 +16,19 @@ H5DataLoc = namedtuple( "H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,) ) -DataInfo = namedtuple("DataInfo", ("raw"), defaults=1 * (None,)) +DataInfo = namedtuple("DataInfo", ("raw", "tcm", "evt"), defaults=3 * (None,)) TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length")) def make_files_config(data: dict): if not isinstance(data, tuple): + if "raw" not in data: + data["raw"] = (None,) + if "tcm" not in data: + data["tcm"] = (None,) + if "evt" not in data: + data["evt"] = (None,) DataInfo = namedtuple( "DataInfo", tuple(data.keys()), defaults=len(data.keys()) * (None,) ) From 2d143cd327f0e7fa6925b53f2485386ecc740c0b Mon Sep 17 00:00:00 2001 From: ggmarshall Date: Wed, 9 Oct 2024 14:52:17 +0200 Subject: [PATCH 3/7] evt replaced by empty string --- src/pygama/evt/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pygama/evt/utils.py b/src/pygama/evt/utils.py index 3644e141d..4aedc1438 100644 --- a/src/pygama/evt/utils.py +++ b/src/pygama/evt/utils.py @@ -182,7 +182,9 @@ def get_data_at_channel( new_expr = expr for name in f._asdict(): - if name not in ["tcm", "raw"]: + if name == "evt": + new_expr = new_expr.replace(f"{name}.", "") + elif name not in ["tcm", "raw"]: new_expr = new_expr.replace(f"{name}.", f"{name}_") res = eval( From 29e17a6bb11528f8b3928a751768242e273ee631 Mon Sep 17 00:00:00 2001 From: iguinn Date: Fri, 1 Nov 2024 10:15:22 -0700 Subject: [PATCH 4/7] Default parallel to False --- src/pygama/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pygama/utils.py b/src/pygama/utils.py index 888ca396c..b6b73d26c 100644 --- a/src/pygama/utils.py +++ b/src/pygama/utils.py @@ -51,7 +51,7 @@ class NumbaPygamaDefaults(MutableMapping): """ def __init__(self) -> None: - self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=True) + self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=False) self.fastmath: bool = getenv_bool("PYGAMA_FASTMATH", default=True) def __getitem__(self, item: str) -> Any: From 2db25b8aba2d1bd0b7346e4471a2ed050f7f10cb Mon Sep 17 00:00:00 2001 From: iguinn Date: Fri, 1 Nov 2024 11:27:39 -0700 Subject: [PATCH 5/7] Changes needed for changes to LGDO iterator --- src/pygama/flow/data_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pygama/flow/data_loader.py b/src/pygama/flow/data_loader.py index aa65d2a3c..53c103c57 100644 --- a/src/pygama/flow/data_loader.py +++ b/src/pygama/flow/data_loader.py @@ -1411,7 +1411,7 @@ def load_iterator( ) for file in gb.groups.keys() ] - tb_names += [self.filedb.get_table_name(tier, tb)] * len(gb) + tb_names += [ [self.filedb.get_table_name(tier, tb)] ]* len(gb) idx_list += [ list(entry_list.loc[i, f"{level}_idx"]) for i in gb.groups.values() @@ -1509,7 +1509,7 @@ def browse( ) for file in gb.groups.keys() ] - tb_names += [self.filedb.get_table_name(tier, tb)] * len(gb) + tb_names += [ [self.filedb.get_table_name(tier, tb)] ]* len(gb) idx_list += [ list(entry_list.loc[i, f"{parent}_idx"]) for i in gb.groups.values() ] From ebed9fd4854ef74185b544d42e4c1db6679545dc Mon Sep 17 00:00:00 2001 From: iguinn Date: Fri, 1 Nov 2024 11:27:56 -0700 Subject: [PATCH 6/7] Fixed math defaults test --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index ba70ad382..95e3f1b2c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ def test_math_numba_defaults(): assert pgu.numba_math_defaults_kwargs.fastmath - assert pgu.numba_math_defaults_kwargs.parallel + assert not pgu.numba_math_defaults_kwargs.parallel pgu.numba_math_defaults.fastmath = False - assert ~pgu.numba_math_defaults.fastmath + assert not pgu.numba_math_defaults.fastmath From ddb25b7cb9ef865f78f500f2713ac97c953bac22 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 18:28:36 +0000 Subject: [PATCH 7/7] style: pre-commit fixes --- src/pygama/flow/data_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pygama/flow/data_loader.py b/src/pygama/flow/data_loader.py index 53c103c57..e2c9f6c5f 100644 --- a/src/pygama/flow/data_loader.py +++ b/src/pygama/flow/data_loader.py @@ -1411,7 +1411,7 @@ def load_iterator( ) for file in gb.groups.keys() ] - tb_names += [ [self.filedb.get_table_name(tier, tb)] ]* len(gb) + tb_names += [[self.filedb.get_table_name(tier, tb)]] * len(gb) idx_list += [ list(entry_list.loc[i, f"{level}_idx"]) for i in gb.groups.values() @@ -1509,7 +1509,7 @@ def browse( ) for file in gb.groups.keys() ] - tb_names += [ [self.filedb.get_table_name(tier, tb)] ]* len(gb) + tb_names += [[self.filedb.get_table_name(tier, tb)]] * len(gb) idx_list += [ list(entry_list.loc[i, f"{parent}_idx"]) for i in gb.groups.values() ]