diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 6f2538d52..00b97f7ab 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -49,8 +49,8 @@ default_keep_reduced_events: True # slightly to the left to avoid them being excluded from the last bin; None leads to automatic mode default_histogram_last_edge_inclusive: None -# boolean flag that, if True, sets the *hists* output of cf.SelectEvents and cf.MergeSelectionStats to optional -default_selection_hists_optional: True +# boolean flag that, if True, configures cf.SelectEvents to create statistics histograms +default_create_selection_hists: False # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False diff --git a/columnflow/calibration/cms/egamma.py b/columnflow/calibration/cms/egamma.py new file mode 100644 index 000000000..5b0bf7ed7 --- /dev/null +++ b/columnflow/calibration/cms/egamma.py @@ -0,0 +1,613 @@ +# coding: utf-8 + +""" +Egamma energy correction methods. +Source: https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Correctionli +""" + +from __future__ import annotations + +import abc +import functools +import law +from dataclasses import dataclass, field + +from columnflow.calibration import Calibrator, calibrator +from columnflow.calibration.util import ak_random +from columnflow.util import maybe_import, InsertableDict +from columnflow.columnar_util import ( + set_ak_column, flat_np_view, ak_copy, optional_column, +) +from columnflow.types import Any + +ak = maybe_import("awkward") +np = maybe_import("numpy") + + +# helper +set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +@dataclass +class EGammaCorrectionConfig: + correction_set: str = "Scale" + corrector_kwargs: dict[str, Any] = field(default_factory=dict) + + +class egamma_scale_corrector(Calibrator): + + with_uncertainties = True + """Switch to control whether uncertainties are calculated.""" + + @property + @abc.abstractmethod + def source_field(self) -> str: + """Fields required for the current calibrator.""" + ... + + @abc.abstractmethod + def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFile: + """Function to retrieve the correction file from the external files. + + :param external_files: File target containing the files as requested + in the current config instance under ``config_inst.x.external_files`` + """ + ... + + @abc.abstractmethod + def get_scale_config(self) -> EGammaCorrectionConfig: + """Function to retrieve the configuration for the photon energy correction.""" + ... + + def call_func( + self, + events: ak.Array, + **kwargs, + ) -> ak.Array: + """ + Apply energy corrections to EGamma objects in the events array. + + This implementation follows the recommendations from the EGamma POG: + https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Example + + Derivatives of this base class require additional member variables and + functions: + + - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` or `Photon`). + - *get_correction_file*: Function to retrieve the correction file, e.g. + from the list of external files in the current `config_inst`. + - *get_scale_config*: Function to retrieve the configuration for the energy correction. + This config must be an instance of :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. 
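+
+        For illustration, a concrete photon scale calibrator can be derived from this base class
+        (this mirrors the :py:data:`pec` calibrator defined further below in this module)::
+
+            pec = egamma_scale_corrector.derive("pec", cls_dict={
+                "source_field": "Photon",
+                "with_uncertainties": True,
+                "get_correction_file": (lambda self, external_files: external_files.photon_ss),
+                "get_scale_config": (lambda self: self.config_inst.x.pec),
+            })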
+ + If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. + The correction tool only supports flat arrays, so inputs are converted to a flat numpy view first. + Corrections are always applied to the raw pt, which is important if more than one correction is applied in a + row. The final corrections must be applied to the current pt. + + If :py:attr:`with_uncertainties` is set to `True`, the scale uncertainties are calculated. + The scale uncertainties are only available for simulated data. + + :param events: The events array containing EGamma objects. + :return: The events array with applied scale corrections. + + :notes: + - Varied corrections are only applied to Monte Carlo (MC) data. + - EGamma energy correction is only applied to real data. + - Changes are applied to the views and directly propagate to the original awkward arrays. + """ + + # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt + + if "rawPt" not in events[self.source_field].fields: + events = set_ak_column_f32( + events, f"{self.source_field}.rawPt", events[self.source_field].pt, + ) + # the correction tool only supports flat arrays, so convert inputs to flat np view first + # corrections are always applied to the raw pt - this is important if more than + # one correction is applied in a row + pt_eval = flat_np_view(events[self.source_field].rawPt, axis=1) + + # the final corrections must be applied to the current pt though + pt_application = flat_np_view(events[self.source_field].pt, axis=1) + + broadcasted_run = ak.broadcast_arrays( + events[self.source_field].pt, events.run, + ) + run = flat_np_view(broadcasted_run[1], axis=1) + gain = flat_np_view(events[self.source_field].seedGain, axis=1) + sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) + r9 = flat_np_view(events[self.source_field].r9, axis=1) + + # prepare arguments + # we use pt as et since there depends in linear (following the recoomendations) + # (energy is part of the LorentzVector behavior) + variable_map = { + "et": pt_eval, + "eta": sceta, + "gain": gain, + "r9": r9, + "run": run, + **self.scale_config.corrector_kwargs, + } + args = tuple( + variable_map[inp.name] for inp in self.scale_corrector.inputs + if inp.name in variable_map + ) + + # varied corrections are only applied to MC + if self.with_uncertainties and self.dataset_inst.is_mc: + scale_uncertainties = self.scale_corrector("total_uncertainty", *args) + scales_up = (1 + scale_uncertainties) + scales_down = (1 - scale_uncertainties) + + for (direction, scales) in [("up", scales_up), ("down", scales_down)]: + # copy pt and mass + pt_varied = ak_copy(events[self.source_field].pt) + pt_view = flat_np_view(pt_varied, axis=1) + + # apply the scale variation + pt_view *= scales + + # save columns + postfix = f"scale_{direction}" + events = set_ak_column_f32( + events, f"{self.source_field}.pt_{postfix}", pt_varied, + ) + + # apply the nominal correction + # note: changes are applied to the views and directly propagate to the original ak arrays + # and do not need to be inserted into the events chunk again + # EGamma energy correction is ONLY applied to DATA + if self.dataset_inst.is_data: + scales_nom = self.scale_corrector("total_correction", *args) + pt_application *= scales_nom + + return events + + def init_func(self) -> None: + """Function to initialize the calibrator. + + Sets the required and produced columns for the calibrator. 
+ """ + self.uses |= { + # nano columns + f"{self.source_field}.{{seedGain,pt,superclusterEta,r9}}", + "run", + optional_column(f"{self.source_field}.rawPt"), + } + self.produces |= { + f"{self.source_field}.pt", + optional_column(f"{self.source_field}.rawPt"), + } + + # if we do not calculate uncertainties, this module + # should only run on observed DATA + self.data_only = not self.with_uncertainties + + # add columns with unceratinties if requested + # photon scale _uncertainties_ are only available for MC + if self.with_uncertainties and getattr(self, "dataset_inst", None): + if self.dataset_inst.is_mc: + self.produces |= {f"{self.source_field}.pt_scale_{{up,down}}"} + + def requires_func(self, reqs: dict) -> None: + """Function to add necessary requirements. + + This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` + task to the requirements. + + :param reqs: Dictionary of requirements. + """ + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(self.task) + + def setup_func( + self, + reqs: dict, + inputs: dict, + reader_targets: InsertableDict, + ) -> None: + """Setup function before event chunk loop. + + This function loads the correction file and sets up the correction tool. + Additionally, the *scale_config* is retrieved. + + :param reqs: Dictionary with resolved requirements. + :param inputs: Dictionary with inputs (not used). + :param reader_targets: Dictionary for optional additional columns to load + (not used). + """ + bundle = reqs["external_files"] + self.scale_config = self.get_scale_config() + + # create the egamma corrector + import correctionlib + correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate + correction_set = correctionlib.CorrectionSet.from_string( + self.get_correction_file(bundle.files).load(formatter="gzip").decode("utf-8"), + ) + self.scale_corrector = correction_set[self.scale_config.correction_set] + + # check versions + assert self.scale_corrector.version in [0, 1, 2] + + +class egamma_resolution_corrector(Calibrator): + + with_uncertainties = True + """Switch to control whether uncertainties are calculated.""" + + # smearing of the energy resolution is only applied to MC + mc_only = True + """This calibrator is only applied to simulated data.""" + + deterministic_seed_index = -1 + """ use deterministic seeds for random smearing and + take the "index"-th random number per seed when not -1 + """ + + @property + @abc.abstractmethod + def source_field(self) -> str: + """Fields required for the current calibrator.""" + ... + + @abc.abstractmethod + def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFile: + """Function to retrieve the correction file from the external files. + + :param external_files: File target containing the files as requested + in the current config instance under ``config_inst.x.external_files`` + """ + ... + + @abc.abstractmethod + def get_resolution_config(self) -> EGammaCorrectionConfig: + """Function to retrieve the configuration for the photon energy correction.""" + ... + + def call_func( + self, + events: ak.Array, + **kwargs, + ) -> ak.Array: + """ + Apply energy resolution corrections to EGamma objects in the events array. 
+ + This implementation follows the recommendations from the EGamma POG: + https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Example + + Derivatives of this base class require additional member variables and + functions: + + - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` or `Photon`). + - *get_correction_file*: Function to retrieve the correction file, e.g. + from the list of external files in the current `config_inst`. + - *get_resolution_config*: Function to retrieve the configuration for the energy resolution correction. + This config must be an instance of :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. + + If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. + The correction tool only supports flat arrays, so inputs are converted to a flat numpy view first. + Corrections are always applied to the raw pt, which is important if more than one correction is applied in a + row. The final corrections must be applied to the current pt. + + If :py:attr:`with_uncertainties` is set to `True`, the resolution uncertainties are calculated. + + If :py:attr:`deterministic_seed_index` is set to a value greater than or equal to 0, deterministic seeds + are used for random smearing. The "index"-th random number per seed is taken for the nominal resolution + correction. The "index+1"-th random number per seed is taken for the up variation and the "index+2"-th random + number per seed is taken for the down variation. + + :param events: The events array containing EGamma objects. + :return: The events array with applied resolution corrections. + + :notes: + - Energy resolution correction are only to be applied to simulation. + - Changes are applied to the views and directly propagate to the original awkward arrays. + """ + + # if no raw pt (i.e. 
pt for any corrections) is available, use the nominal pt + if "rawPt" not in events[self.source_field].fields: + events = set_ak_column_f32( + events, f"{self.source_field}.rawPt", ak_copy(events[self.source_field].pt), + ) + + # the correction tool only supports flat arrays, so convert inputs to flat np view first + + sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) + r9 = flat_np_view(events[self.source_field].r9, axis=1) + flat_seeds = flat_np_view(events[self.source_field].deterministic_seed, axis=1) + + # prepare arguments + # we use pt as et since there depends in linear (following the recoomendations) + # (energy is part of the LorentzVector behavior) + variable_map = { + "eta": sceta, + "r9": r9, + **self.resolution_config.corrector_kwargs, + } + args = tuple( + variable_map[inp.name] for inp in self.resolution_corrector.inputs + if inp.name in variable_map + ) + + # calculate the smearing scale + rho = self.resolution_corrector("rho", *args) + + # -- stochastic smearing + # normally distributed random numbers according to EGamma resolution + + # varied corrections + if self.with_uncertainties and self.dataset_inst.is_mc: + rho_unc = self.resolution_corrector("err_rho", *args) + smearing_up = ( + ak_random( + 1, rho + rho_unc, flat_seeds, + rand_func=self.deterministic_normal_up, + ) + if self.deterministic_seed_index >= 0 + else ak_random(1, rho + rho_unc, rand_func=np.random.Generator( + np.random.SFC64(events.event.to_list())).normal, + ) + ) + + smearing_down = ( + ak_random( + 1, rho - rho_unc, flat_seeds, + rand_func=self.deterministic_normal_down, + ) + if self.deterministic_seed_index >= 0 + else ak_random(1, rho - rho_unc, rand_func=np.random.Generator( + np.random.SFC64(events.event.to_list())).normal, + ) + ) + + for (direction, smear) in [("up", smearing_up), ("down", smearing_down)]: + # copy pt and mass + pt_varied = ak_copy(events[self.source_field].pt) + pt_view = flat_np_view(pt_varied, axis=1) + + # apply the scale variation + # cast ak to numpy array for convenient usage of *= + pt_view *= smear.to_numpy() + + # save columns + postfix = f"res_{direction}" + events = set_ak_column_f32( + events, f"{self.source_field}.pt_{postfix}", pt_varied, + ) + + # apply the nominal correction + # note: changes are applied to the views and directly propagate to the original ak arrays + # and do not need to be inserted into the events chunk again + # EGamma energy resolution correction is ONLY applied to MC + if self.dataset_inst.is_mc: + smearing = ( + ak_random(1, rho, flat_seeds, rand_func=self.deterministic_normal) + if self.deterministic_seed_index >= 0 + else ak_random(1, rho, rand_func=np.random.Generator( + np.random.SFC64(events.event.to_list())).normal, + ) + ) + # the final corrections must be applied to the current pt though + pt = flat_np_view(events[self.source_field].pt, axis=1) + pt *= smearing.to_numpy() + + return events + + def init_func(self) -> None: + """Function to initialize the calibrator. + + Sets the required and produced columns for the calibrator. 
+ """ + self.uses |= { + # nano columns + f"{self.source_field}.{{pt,superclusterEta,r9}}", + optional_column(f"{self.source_field}.rawPt"), + } + self.produces |= { + f"{self.source_field}.pt", + optional_column(f"{self.source_field}.rawPt"), + } + + # add columns with unceratinties if requested + if self.with_uncertainties and getattr(self, "dataset_inst", None): + if self.dataset_inst.is_mc: + self.produces |= {f"{self.source_field}.pt_res_{{up,down}}"} + + def requires_func(self, reqs: dict) -> None: + """Function to add necessary requirements. + + This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` + task to the requirements. + + :param reqs: Dictionary of requirements. + """ + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(self.task) + + def setup_func(self, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: + """Setup function before event chunk loop. + + This function loads the correction file and sets up the correction tool. + Additionally, the *resolution_config* is retrieved. + If :py:attr:`deterministic_seed_index` is set to a value greater than or equal to 0, + random generator based on object-specific random seeds are setup. + + :param reqs: Dictionary with resolved requirements. + :param inputs: Dictionary with inputs (not used). + :param reader_targets: Dictionary for optional additional columns to load + (not used). + """ + bundle = reqs["external_files"] + self.resolution_config = self.get_resolution_config() + + # create the egamma corrector + import correctionlib + correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate + correction_set = correctionlib.CorrectionSet.from_string( + self.get_correction_file(bundle.files).load(formatter="gzip").decode("utf-8"), + ) + self.resolution_corrector = correction_set[self.resolution_config.correction_set] + + # check versions + assert self.resolution_corrector.version in [0, 1, 2] + + # use deterministic seeds for random smearing if requested + if self.deterministic_seed_index >= 0: + idx = self.deterministic_seed_index + bit_generator = np.random.SFC64 + def deterministic_normal(loc, scale, seed, idx_offset=0): + return np.asarray([ + np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1 + idx_offset)[-1] + for _loc, _scale, _seed in zip(loc, scale, seed) + ]) + self.deterministic_normal = functools.partial(deterministic_normal, idx_offset=0) + self.deterministic_normal_up = functools.partial(deterministic_normal, idx_offset=1) + self.deterministic_normal_down = functools.partial(deterministic_normal, idx_offset=2) + + +pec = egamma_scale_corrector.derive( + "pec", cls_dict={ + "source_field": "Photon", + "with_uncertainties": True, + "get_correction_file": (lambda self, external_files: external_files.photon_ss), + "get_scale_config": (lambda self: self.config_inst.x.pec), + }, +) + +per = egamma_resolution_corrector.derive( + "per", cls_dict={ + "source_field": "Photon", + "with_uncertainties": True, + # function to determine the correction file + "get_correction_file": (lambda self, external_files: external_files.photon_ss), + # function to determine the tec config + "get_resolution_config": (lambda self: self.config_inst.x.per), + }, +) + + +@calibrator( + uses={per, pec}, + produces={per, pec}, + with_uncertainties=True, + get_correction_file=None, + get_scale_config=None, + get_resolution_config=None, + deterministic_seed_index=-1, +) +def photons(self, events: 
ak.Array, **kwargs) -> ak.Array: + """ + Calibrator for photons. This calibrator runs the energy scale and resolution calibrators + for photons. + + Careful! Always apply resolution before scale corrections for MC. + """ + if self.dataset_inst.is_mc: + events = self[per](events, **kwargs) + + if self.with_uncertainties or self.dataset_inst.is_data: + events = self[pec](events, **kwargs) + + return events + + +@photons.init +def photons_init(self) -> None: + # forward argument to the producers + + if pec not in self.deps_kwargs: + self.deps_kwargs[pec] = dict() + if per not in self.deps_kwargs: + self.deps_kwargs[per] = dict() + self.deps_kwargs[pec]["with_uncertainties"] = self.with_uncertainties + self.deps_kwargs[per]["with_uncertainties"] = self.with_uncertainties + + self.deps_kwargs[per]["deterministic_seed_index"] = self.deterministic_seed_index + if self.get_correction_file is not None: + self.deps_kwargs[pec]["get_correction_file"] = self.get_correction_file + self.deps_kwargs[per]["get_correction_file"] = self.get_correction_file + + if self.get_resolution_config is not None: + self.deps_kwargs[per]["get_resolution_config"] = self.get_resolution_config + if self.get_scale_config is not None: + self.deps_kwargs[pec]["get_scale_config"] = self.get_scale_config + + +photons_nominal = photons.derive("photons_nominal", cls_dict={"with_uncertainties": False}) + + +eer = egamma_resolution_corrector.derive( + "eer", cls_dict={ + "source_field": "Electron", + # calculation of superclusterEta for electrons requires the deltaEtaSC + "uses": {"Electron.deltaEtaSC"}, + "with_uncertainties": True, + # function to determine the correction file + "get_correction_file": (lambda self, external_files: external_files.electron_ss), + # function to determine the tec config + "get_resolution_config": (lambda self: self.config_inst.x.eer), + }, +) + +eec = egamma_scale_corrector.derive( + "eec", cls_dict={ + "source_field": "Electron", + # calculation of superclusterEta for electrons requires the deltaEtaSC + "uses": {"Electron.deltaEtaSC"}, + "with_uncertainties": True, + "get_correction_file": (lambda self, external_files: external_files.electron_ss), + "get_scale_config": (lambda self: self.config_inst.x.eec), + }, +) + + +@calibrator( + uses={eer, eec}, + produces={eer, eec}, + with_uncertainties=True, + get_correction_file=None, + get_scale_config=None, + get_resolution_config=None, + deterministic_seed_index=-1, +) +def electrons(self, events: ak.Array, **kwargs) -> ak.Array: + """ + Calibrator for electrons. This calibrator runs the energy scale and resolution calibrators + for electrons. + + Careful! Always apply resolution before scale corrections for MC. 
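+
+    The derived calibrators expect an ``electron_ss`` entry in ``config_inst.x.external_files``
+    as well as :py:class:`EGammaCorrectionConfig` instances in the ``config_inst.x.eec`` and
+    ``config_inst.x.eer`` auxiliaries. A minimal sketch, where ``cfg`` refers to the config
+    instance and the file location and correction set names are campaign-dependent placeholders::
+
+        cfg.x.external_files = {
+            "electron_ss": "<path or url of the EGM scale & smearing json.gz>",
+        }
+        cfg.x.eec = EGammaCorrectionConfig(correction_set="Scale")
+        cfg.x.eer = EGammaCorrectionConfig(correction_set="<name of the smearing correction>")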
+ """ + if self.dataset_inst.is_mc: + events = self[eer](events, **kwargs) + + if self.with_uncertainties or self.dataset_inst.is_data: + events = self[eec](events, **kwargs) + + return events + + +@electrons.init +def electrons_init(self) -> None: + # forward argument to the producers + + if eec not in self.deps_kwargs: + self.deps_kwargs[eec] = dict() + if eer not in self.deps_kwargs: + self.deps_kwargs[eer] = dict() + self.deps_kwargs[eec]["with_uncertainties"] = self.with_uncertainties + self.deps_kwargs[eer]["with_uncertainties"] = self.with_uncertainties + + self.deps_kwargs[eer]["deterministic_seed_index"] = self.deterministic_seed_index + if self.get_correction_file is not None: + self.deps_kwargs[eec]["get_correction_file"] = self.get_correction_file + self.deps_kwargs[eer]["get_correction_file"] = self.get_correction_file + + if self.get_resolution_config is not None: + self.deps_kwargs[eer]["get_resolution_config"] = self.get_resolution_config + if self.get_scale_config is not None: + self.deps_kwargs[eec]["get_scale_config"] = self.get_scale_config + + +electrons_nominal = photons.derive("electrons_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 171ab3661..7470b2459 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -20,7 +20,7 @@ import multiprocessing import multiprocessing.pool from functools import partial -from collections import namedtuple, OrderedDict, deque +from collections import namedtuple, OrderedDict, deque, defaultdict import law import order as od @@ -1848,6 +1848,9 @@ def __init__( f"to set: {e.args[0]}", ) raise e + + # remove keyword from further processing + kwargs.pop(attr) else: try: deps = set(law.util.make_list(getattr(self.__class__, attr))) @@ -1862,11 +1865,15 @@ def __init__( # also register a set for storing instances, filled in create_dependencies setattr(self, f"{attr}_instances", set()) + # set all other keyword arguments as instance attributes + for attr, value in kwargs.items(): + setattr(self, attr, value) + # dictionary of dependency class to instance, set in create_dependencies self.deps = DotDict() # dictionary of keyword arguments mapped to dependenc classes to be forwarded to their init - self.deps_kwargs = DotDict() + self.deps_kwargs = defaultdict(dict) # TODO: avoid using `defaultdict` # deferred part of the initialization if deferred_init: @@ -1931,8 +1938,15 @@ def deferred_init(self, instance_cache: dict | None = None) -> dict: self.init_func() # instantiate dependencies again, but only perform updates - self.create_dependencies(instance_cache, only_update=True) + # self.create_dependencies(instance_cache, only_update=True) + + # NOTE: the above does not correctly propagate `deps_kwargs` to the dependencies. + # As a workaround, instantiate all dependencies fully a second time by + # invalidating the instance cache and setting `only_update` to False + instance_cache = {} + self.create_dependencies(instance_cache, only_update=False) + # NOTE: return value currently not being used anywhere -> remove? 
return instance_cache def create_dependencies( diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py index 92a9ed42a..0f7eeaffb 100644 --- a/columnflow/hist_util.py +++ b/columnflow/hist_util.py @@ -6,6 +6,8 @@ from __future__ import annotations +__all__ = [] + import law import order as od diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 7926a9f78..224e7818e 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -22,6 +22,11 @@ class ParameterType(enum.Enum): """ Parameter type flag. + + :cvar rate_gauss: Gaussian rate parameter. + :cvar rate_uniform: Uniform rate parameter. + :cvar rate_unconstrained: Unconstrained rate parameter. + :cvar shape: Shape parameter. """ rate_gauss = "rate_gauss" @@ -30,10 +35,20 @@ class ParameterType(enum.Enum): shape = "shape" def __str__(self: ParameterType) -> str: + """ + Returns the string representation of the parameter type. + + :returns: The string representation of the parameter type. + """ return self.value @property def is_rate(self: ParameterType) -> bool: + """ + Checks if the parameter type is a rate type. + + :returns: *True* if the parameter type is a rate type, *False* otherwise. + """ return self in ( self.rate_gauss, self.rate_uniform, @@ -42,6 +57,11 @@ def is_rate(self: ParameterType) -> bool: @property def is_shape(self: ParameterType) -> bool: + """ + Checks if the parameter type is a shape type. + + :returns: *True* if the parameter type is a shape type, *False* otherwise. + """ return self in ( self.shape, ) @@ -50,6 +70,15 @@ def is_shape(self: ParameterType) -> bool: class ParameterTransformation(enum.Enum): """ Flags denoting transformations to be applied on parameters. + + :cvar none: No transformation. + :cvar centralize: Centralize the parameter. + :cvar symmetrize: Symmetrize the parameter. + :cvar asymmetrize: Asymmetrize the parameter. + :cvar asymmetrize_if_large: Asymmetrize the parameter if it is large. + :cvar normalize: Normalize the parameter. + :cvar effect_from_shape: Derive effect from shape. + :cvar effect_from_rate: Derive effect from rate. """ none = "none" @@ -62,16 +91,31 @@ class ParameterTransformation(enum.Enum): effect_from_rate = "effect_from_rate" def __str__(self: ParameterTransformation) -> str: + """ + Returns the string representation of the parameter transformation. + + :returns: The string representation of the parameter transformation. + """ return self.value @property def from_shape(self: ParameterTransformation) -> bool: + """ + Checks if the transformation is derived from shape. + + :returns: *True* if the transformation is derived from shape, *False* otherwise. + """ return self in ( self.effect_from_shape, ) @property def from_rate(self: ParameterTransformation) -> bool: + """ + Checks if the transformation is derived from rate. + + :returns: *True* if the transformation is derived from rate, *False* otherwise. + """ return self in ( self.effect_from_rate, ) @@ -81,12 +125,20 @@ class ParameterTransformations(tuple): """ Container around a sequence of :py:class:`ParameterTransformation`'s with a few convenience methods. + + :param transformations: A sequence of :py:class:`ParameterTransformation` or their string names. """ def __new__( cls, transformations: Sequence[ParameterTransformation | str], ) -> ParameterTransformations: + """ + Creates a new instance of :py:class:`ParameterTransformations`. + + :param transformations: A sequence of :py:class:`ParameterTransformation` or their string names. 
+ :returns: A new instance of :py:class:`ParameterTransformations`. + """ # TODO: at this point one could object / complain in case incompatible transfos are used transformations = [ (t if isinstance(t, ParameterTransformation) else ParameterTransformation[t]) @@ -98,10 +150,20 @@ def __new__( @property def any_from_shape(self: ParameterTransformations) -> bool: + """ + Checks if any transformation is derived from shape. + + :returns: *True* if any transformation is derived from shape, *False* otherwise. + """ return any(t.from_shape for t in self) @property def any_from_rate(self: ParameterTransformations) -> bool: + """ + Checks if any transformation is derived from rate. + + :returns: *True* if any transformation is derived from rate, *False* otherwise. + """ return any(t.from_rate for t in self) @@ -143,6 +205,7 @@ class InferenceModel(Derivable): is_signal: True config_mc_datasets: [hh_ggf] scale: 1.0 + is_dynamic: False parameters: - name: lumi type: rate_gauss @@ -161,6 +224,7 @@ class InferenceModel(Derivable): config_process: ttbar config_mc_datasets: [tt_sl, tt_dl, tt_fh] scale: 1.0 + is_dynamic: False parameters: - name: lumi type: rate_gauss @@ -197,7 +261,7 @@ class InferenceModel(Derivable): type: DotDict - The internal data structure representing the model. + The internal data structure representing the model, see :py:meth:`InferenceModel.model_spec`. """ # optional initialization method @@ -244,8 +308,12 @@ def inference_model( ) -> DerivableMeta | Callable: """ Decorator for creating a new :py:class:`InferenceModel` subclass with additional, optional - *bases* and attaching the decorated function to it as ``init_func``. All additional *kwargs* are - added as class members of the new subclasses. + *bases* and attaching the decorated function to it as ``init_func``. All additional *kwargs* + are added as class members of the new subclass. + + :param func: The function to be decorated and attached as ``init_func``. + :param bases: Optional tuple of base classes for the new subclass. + :returns: The new subclass or a decorator function. """ def decorator(func: Callable) -> DerivableMeta: # create the class dict @@ -290,18 +358,20 @@ def category_spec( Returns a dictionary representing a category (interchangeably called bin or channel in other tools), forwarding all arguments. - - *name*: The name of the category in the model. - - *config_category*: The name of the source category in the config to use. - - *config_variable*: The name of the variable in the config to use. - - *config_data_datasets*: List of names or patterns of datasets in the config to use for - real data. - - *data_from_processes*: Optional list of names of :py:meth:`process_spec` objects that, - when *config_data_datasets* is not defined, make of a fake data contribution. - - *flow_strategy*: A :py:class:`FlowStrategy` instance describing the strategy to handle + :param name: The name of the category in the model. + :param config_category: The name of the source category in the config to use. + :param config_variable: The name of the variable in the config to use. + :param config_data_datasets: List of names or patterns of datasets in the config to use for + real data. + :param data_from_processes: Optional list of names of :py:meth:`process_spec` objects that, + when *config_data_datasets* is not defined, make up a fake data contribution. + :param flow_strategy: A :py:class:`FlowStrategy` instance describing the strategy to handle over- and underflow bin contents. 
- - *mc_stats*: Either *None* to disable MC stat uncertainties, or an integer, a float or - a tuple of thereof to control the options of MC stat options. - - *empty_bin_value*: When bins are no content, they are filled with this value. + :param mc_stats: Either *None* to disable MC stat uncertainties, or an integer, a float or + a tuple thereof to control the options of MC stat options. + :param empty_bin_value: When bins have no content, they are filled with this value. + :returns: A dictionary representing the category. + """ return DotDict([ ("name", str(name)), @@ -327,15 +397,19 @@ def process_spec( is_signal: bool = False, config_mc_datasets: Sequence[str] | None = None, scale: float | int = 1.0, + is_dynamic: bool = False, ) -> DotDict: """ Returns a dictionary representing a process, forwarding all arguments. - - *name*: The name of the process in the model. - - *is_signal*: A boolean flag deciding whether this process describes signal. - - *config_process*: The name of the source process in the config to use. - - *config_mc_datasets*: List of names or patterns of MC datasets in the config to use. - - *scale*: A float value to scale the process, defaulting to 1.0. + :param name: The name of the process in the model. + :param is_signal: A boolean flag deciding whether this process describes signal. + :param config_process: The name of the source process in the config to use. + :param config_mc_datasets: List of names or patterns of MC datasets in the config to use. + :param scale: A float value to scale the process, defaulting to 1.0. + :param is_dynamic: A boolean flag deciding whether this process is dynamic, i.e., whether it + is created on-the-fly. + :returns: A dictionary representing the process. """ return DotDict([ ("name", str(name)), @@ -343,6 +417,7 @@ def process_spec( ("config_process", str(config_process) if config_process else None), ("config_mc_datasets", list(map(str, config_mc_datasets or []))), ("scale", float(scale)), + ("is_dynamic", bool(is_dynamic)), ("parameters", []), ]) @@ -358,14 +433,15 @@ def parameter_spec( """ Returns a dictionary representing a (nuisance) parameter, forwarding all arguments. - - *name*: The name of the parameter in the model. - - *type*: A :py:class:`ParameterType` instance describing the type of this parameter. - - *transformations*: A sequence of :py:class:`ParameterTransformation` instances - describing transformations to be applied to the effect of this parameter. - - *config_shift_source*: The name of a systematic shift source in the config that this - parameter corresponds to. - - *effect*: An arbitrary object describing the effect of the parameter (e.g. float for - symmetric rate effects, 2-tuple for down/up variation, etc). + :param name: The name of the parameter in the model. + :param type: A :py:class:`ParameterType` instance describing the type of this parameter. + :param transformations: A sequence of :py:class:`ParameterTransformation` instances + describing transformations to be applied to the effect of this parameter. + :param config_shift_source: The name of a systematic shift source in the config that this + parameter corresponds to. + :param effect: An arbitrary object describing the effect of the parameter (e.g. float for + symmetric rate effects, 2-tuple for down/up variation, etc). + :returns: A dictionary representing the parameter. """ return DotDict([ ("name", str(name)), @@ -384,8 +460,9 @@ def parameter_group_spec( """ Returns a dictionary representing a group of parameter names. 
- - *name*: The name of the parameter group in the model. - - *parameter_names*: Names of parameter objects this group contains. + :param name: The name of the parameter group in the model. + :param parameter_names: Names of parameter objects this group contains. + :returns: A dictionary representing the group of parameter names. """ return DotDict([ ("name", str(name)), @@ -395,8 +472,11 @@ def parameter_group_spec( @classmethod def require_shapes_for_parameter(self, param_obj: dict) -> bool: """ - Returns *True* if for a certain parameter object *param_obj* varied shapes are needed, and - *False* otherwise. + Function to check if for a certain parameter object *param_obj* varied + shapes are needed. + + :param param_obj: The parameter object to check. + :returns: *True* if varied shapes are needed, *False* otherwise. """ if param_obj.type.is_shape: # the shape might be build from a rate, in which case input shapes are not required @@ -435,6 +515,9 @@ def to_yaml(self, stream: TextIO | None = None) -> str | None: """ Writes the content of the :py:attr:`model` into a file-like object *stream* when given, and returns a string representation otherwise. + + :param stream: A file-like object to write the model content into. + :returns: A string representation of the model content if *stream* is not provided. """ return yaml.dump(self.model, stream=stream, Dumper=self.YamlDumper) @@ -469,6 +552,10 @@ def get_categories( Returns a list of categories whose name match *category*. *category* can be a string, a pattern, or sequence of them. When *only_names* is *True*, only names of categories are returned rather than structured dictionaries. + + :param category: A string, pattern, or sequence of them to match category names. + :param only_names: A boolean flag to return only names of categories if set to *True*. + :returns: A list of matching categories or their names. """ # rename arguments to make their meaning explicit category_pattern = category @@ -491,6 +578,12 @@ def get_category( pattern, or sequence of them. An exception is raised if no or more than one category is found, unless *silent* is *True* in which case *None* is returned. When *only_name* is *True*, only the name of the category is returned rather than a structured dictionary. + + :param category: A string, pattern, or sequence of them to match category names. + :param silent: A boolean flag to return *None* instead of raising an exception if no or + more than one category is found. + :param only_name: A boolean flag to return only the name of the category if set to *True*. + :returns: A single matching category or its name. """ # rename arguments to make their meaning explicit category_name = category @@ -516,6 +609,9 @@ def has_category( """ Returns *True* if a category whose name matches *category* is existing, and *False* otherwise. *category* can be a string, a pattern, or sequence of them. + + :param category: A string, pattern, or sequence of them to match category names. + :returns: *True* if a matching category exists, *False* otherwise. """ # rename arguments to make their meaning explicit category_pattern = category @@ -528,6 +624,8 @@ def add_category(self, *args, **kwargs) -> None: Adds a new category with all *args* and *kwargs* used to create the structured category dictionary via :py:meth:`category_spec`. If a category with the same name already exists, an exception is raised. + + :raises ValueError: If a category with the same name already exists. 
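+
+        Example (category, variable and dataset names are illustrative)::
+
+            self.add_category(
+                "cat1",
+                config_category="incl",
+                config_variable="jet1_pt",
+                config_data_datasets=["data_mu_*"],
+                mc_stats=10,
+            )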
""" # create the object category = self.category_spec(*args, **kwargs) @@ -544,9 +642,10 @@ def remove_category( category: str | Sequence[str], ) -> bool: """ - Removes one or more categories whose names match *category*. Returns *True* if at least one - category was removed, and *False* otherwise. *category* can be a string, a pattern, or - sequence of them. + Removes one or more categories whose names match *category*. + + :param category: A string, pattern, or sequence of them to match category names. + :returns: *True* if at least one category was removed, *False* otherwise. """ # rename arguments to make their meaning explicit category_pattern = category @@ -584,6 +683,12 @@ def get_processes( When *only_names* is *True*, only names of processes are returned rather than structured dictionaries. When *flat* is *True*, a flat, unique list of process names is returned. + + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to filter categories. + :param only_names: A boolean flag to return only names of processes if set to *True*. + :param flat: A boolean flag to return a flat, unique list of process names if set to *True*. + :returns: A dictionary of processes mapped to the category name, or a list of process names. """ # rename arguments to make their meaning explicit process_pattern = process @@ -624,13 +729,21 @@ def get_process( silent: bool = False, ) -> DotDict | str: """ - Returns a single process whose name matches *process*, and optionally, whose category's name - matches *category*. Both *process* and *category* can be a string, a pattern, or sequence of - them. + Returns a single process whose name matches *process*, and optionally, whose category's + name matches *category*. Both *process* and *category* can be a string, a pattern, or + sequence of them. - An exception is raised if no or more than one process is found, unless - *silent* is *True* in which case *None* is returned. When *only_name* is *True*, only the - name of the process is returned rather than a structured dictionary. + An exception is raised if no or more than one process is found, unless *silent* is *True* + in which case *None* is returned. When *only_name* is *True*, only the name of the + process is returned rather than a structured dictionary. + + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :param silent: A boolean flag to return *None* instead of raising an exception if no or + more than one process is found. + :param only_name: A boolean flag to return only the name of the process if set to *True*. + :returns: A single matching process or its name. + :raises ValueError: If no process or more than one process is found and *silent* is *False*. """ # rename arguments to make their meaning explicit process_name = process @@ -676,8 +789,12 @@ def has_process( ) -> bool: """ Returns *True* if a process whose name matches *process*, and optionally whose category's - name matches *category*, is existing, and *False* otherwise. Both *process* and *category* - can be a string, a pattern, or sequence of them. + name matches *category*, exists, and *False* otherwise. Both *process* and *category* can + be a string, a pattern, or sequence of them. + + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. 
+ :returns: *True* if a matching process exists, *False* otherwise. """ # rename arguments to make their meaning explicit process_pattern = process @@ -700,6 +817,14 @@ def add_process( If a process with the same name already exists in one of the categories, an exception is raised unless *silent* is *True*. + + :param args: Positional arguments used to create the process. + :param category: A string, pattern, or sequence of them to match category names. + :param silent: A boolean flag to suppress exceptions if a process with the same name + already exists. + :param kwargs: Keyword arguments used to create the process. + :raises ValueError: If a process with the same name already exists in one of the + categories and *silent* is *False*. """ # rename arguments to make their meaning explicit category_pattern = category @@ -732,8 +857,12 @@ def remove_process( ) -> bool: """ Removes one or more processes whose names match *process*, and optionally whose category's - name match *category*. Both *process* and *category* can be a string, a pattern, or sequence - of them. Returns *True* if at least one process was removed, and *False* otherwise. + name matches *category*. Both *process* and *category* can be a string, a pattern, or + sequence of them. Returns *True* if at least one process was removed, and *False* otherwise. + + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :returns: *True* if at least one process was removed, *False* otherwise. """ # rename arguments to make their meaning explicit process_pattern = process @@ -772,13 +901,21 @@ def get_parameters( flat: bool = False, ) -> dict[str, dict[str, DotDict | str]] | list[str]: """ - Returns a dictionary of parameter whose names match *parameter*, mapped twice to the name of - the category and the name of the process they belong to. Categories and processes can + Returns a dictionary of parameters whose names match *parameter*, mapped twice to the name + of the category and the name of the process they belong to. Categories and processes can optionally be filtered through *category* and *process*. All three, *parameter*, *process* and *category* can be a string, a pattern, or sequence of them. When *only_names* is *True*, only names of parameters are returned rather than structured dictionaries. When *flat* is *True*, a flat, unique list of parameter names is returned. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :param only_names: A boolean flag to return only names of parameters if set to *True*. + :param flat: A boolean flag to return a flat, unique list of parameter names if set to *True*. + :returns: A dictionary of parameters mapped to category and process names, or a list of + parameter names. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -835,6 +972,15 @@ def get_parameter( An exception is raised if no or more than one parameter is found, unless *silent* is *True* in which case *None* is returned. When *only_name* is *True*, only the name of the parameter is returned rather than a structured dictionary. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param process: A string, pattern, or sequence of them to match process names. 
+ :param category: A string, pattern, or sequence of them to match category names. + :param only_name: A boolean flag to return only the name of the parameter if set to *True*. + :param silent: A boolean flag to return *None* instead of raising an exception if no or more + than one parameter is found. + :returns: A single matching parameter or its name. + :raises ValueError: If no parameter or more than one parameter is found and *silent* is *False*. """ # rename arguments to make their meaning explicit parameter_name = parameter @@ -895,9 +1041,14 @@ def has_parameter( ) -> bool: """ Returns *True* if a parameter whose name matches *parameter*, and optionally whose - category's and process' name match *category* and *process*, is existing, and *False* - otherwise. All three, *parameter*, *process* and *category* can be a string, a pattern, or - sequence of them. + category's and process' name match *category* and *process*, exists, and *False* + otherwise. All three, *parameter*, *process* and *category* can be a string, a pattern, + or sequence of them. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :returns: *True* if a matching parameter exists, *False* otherwise. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -930,6 +1081,15 @@ def add_parameter( If a parameter with the same name already exists in one of the processes throughout the categories, an exception is raised. + + :param args: Positional arguments used to create the parameter. + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :param group: A string, pattern, or sequence of them to specify parameter groups. + :param kwargs: Keyword arguments used to create the parameter. + :returns: The created parameter. + :raises ValueError: If a parameter with the same name already exists in one of the processes + throughout the categories. """ # rename arguments to make their meaning explicit process_pattern = process @@ -969,8 +1129,12 @@ def remove_parameter( """ Removes one or more parameters whose names match *parameter*, and optionally whose category's and process' name match *category* and *process*. All three, *parameter*, - *process* and *category* can be a string, a pattern, or sequence of them. Returns *True* if - at least one parameter was removed, and *False* otherwise. + *process* and *category* can be a string, a pattern, or sequence of them. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :returns: *True* if at least one parameter was removed, *False* otherwise. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -1008,11 +1172,15 @@ def get_parameter_groups( only_names: bool = False, ) -> list[DotDict | str]: """ - Returns a list of parameter groups whose name match *group*. *group* can be a string, a + Returns a list of parameter groups whose names match *group*. *group* can be a string, a pattern, or sequence of them. When *only_names* is *True*, only names of parameter groups are returned rather than structured dictionaries. 
+ + :param group: A string, pattern, or sequence of them to match group names. + :param only_names: A boolean flag to return only names of parameter groups if set to *True*. + :returns: A list of parameter groups or their names. """ # rename arguments to make their meaning explicit group_pattern = group @@ -1036,6 +1204,11 @@ def get_parameter_group( An exception is raised in case no or more than one parameter group is found. When *only_name* is *True*, only the name of the parameter group is returned rather than a structured dictionary. + + :param group: A string, pattern, or sequence of them to match group names. + :param only_name: A boolean flag to return only the name of the parameter group if set to *True*. + :returns: A single matching parameter group or its name. + :raises ValueError: If no parameter group or more than one parameter group is found. """ # rename arguments to make their meaning explicit group_name = group @@ -1055,8 +1228,11 @@ def has_parameter_group( group: str | Sequence[str], ) -> bool: """ - Returns *True* if a parameter group whose name matches *group* is existing, and *False* + Returns *True* if a parameter group whose name matches *group* exists, and *False* otherwise. *group* can be a string, a pattern, or sequence of them. + + :param group: A string, pattern, or sequence of them to match group names. + :returns: *True* if a matching parameter group exists, *False* otherwise. """ # rename arguments to make their meaning explicit group_pattern = group @@ -1069,6 +1245,10 @@ def add_parameter_group(self, *args, **kwargs) -> None: Adds a new parameter group with all *args* and *kwargs* used to create the structured parameter group dictionary via :py:meth:`parameter_group_spec`. If a group with the same name already exists, an exception is raised. + + :param args: Positional arguments used to create the parameter group. + :param kwargs: Keyword arguments used to create the parameter group. + :raises ValueError: If a parameter group with the same name already exists. """ # create the instance group = self.parameter_group_spec(*args, **kwargs) @@ -1087,6 +1267,9 @@ def remove_parameter_group( Removes one or more parameter groups whose names match *group*. *group* can be a string, a pattern, or sequence of them. Returns *True* if at least one group was removed, and *False* otherwise. + + :param group: A string, pattern, or sequence of them to match group names. + :returns: *True* if at least one group was removed, *False* otherwise. """ # rename arguments to make their meaning explicit group_pattern = group @@ -1108,11 +1291,15 @@ def add_parameter_to_group( group: str | Sequence[str], ) -> bool: """ - Adds a parameter named *parameter* to one or multiple parameter groups whose name match + Adds a parameter named *parameter* to one or multiple parameter groups whose names match *group*. *group* can be a string, a pattern, or sequence of them. When *parameter* is a pattern or regular expression, all previously added, matching parameters are added. - Otherwise, *parameter* is added as as. If a parameter was added to at least one group, + Otherwise, *parameter* is added as is. If a parameter was added to at least one group, *True* is returned and *False* otherwise. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param group: A string, pattern, or sequence of them to match group names. + :returns: *True* if at least one parameter was added to a group, *False* otherwise. 
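+
+        Example (parameter and group names are illustrative)::
+
+            # add a single, known parameter
+            self.add_parameter_to_group("lumi", "experiment")
+
+            # add all previously added parameters matching a pattern
+            self.add_parameter_to_group("btag_*", "experiment")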
""" # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -1150,6 +1337,10 @@ def remove_parameter_from_groups( Removes all parameters matching *parameter* from parameter groups whose names match *group*. Both *parameter* and *group* can be a string, a pattern, or sequence of them. Returns *True* if at least one parameter was removed, and *False* otherwise. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param group: A string, pattern, or sequence of them to match group names. + :returns: *True* if at least one parameter was removed, *False* otherwise. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -1184,6 +1375,9 @@ def get_categories_with_process( """ Returns a flat list of category names that contain processes matching *process*. *process* can be a string, a pattern, or sequence of them. + + :param process: A string, pattern, or sequence of them to match process names. + :returns: A list of category names containing matching processes. """ # rename arguments to make their meaning explicit process_pattern = process @@ -1199,10 +1393,15 @@ def get_processes_with_parameter( ) -> list[str] | dict[str, list[str]]: """ Returns a dictionary of names of processes that contain a parameter whose names match - *parameter*, mapped to categories names. Categories can optionally be filtered through + *parameter*, mapped to category names. Categories can optionally be filtered through *category*. Both *parameter* and *category* can be a string, a pattern, or sequence of them. When *flat* is *True*, a flat, unique list of process names is returned. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param category: A string, pattern, or sequence of them to match category names. + :param flat: A boolean flag to return a flat, unique list of process names if set to *True*. + :returns: A dictionary of process names mapped to category names, or a flat list of process names. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -1234,10 +1433,15 @@ def get_categories_with_parameter( ) -> list[str] | dict[str, list[str]]: """ Returns a dictionary of category names mapping to process names that contain parameters - whose name match *parameter*. Processes can optionally be filtered through *process*. Both + whose names match *parameter*. Processes can optionally be filtered through *process*. Both *parameter* and *process* can be a string, a pattern, or sequence of them. When *flat* is *True*, a flat, unique list of category names is returned. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param process: A string, pattern, or sequence of them to match process names. + :param flat: A boolean flag to return a flat, unique list of category names if set to *True*. + :returns: A dictionary of category names mapped to process names, or a flat list of category names. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -1267,7 +1471,10 @@ def get_groups_with_parameter( ) -> list[str]: """ Returns a list of names of parameter groups that contain a parameter whose name matches - *parameter*, which can be a string, a pattern, or sequence of them. + *parameter*. *parameter* can be a string, a pattern, or sequence of them. + + :param parameter: A string, pattern, or sequence of them to match parameter names. 
+ :returns: A list of names of parameter groups containing the matching parameter. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -1291,6 +1498,8 @@ def cleanup( Cleans the internal model structure by removing empty and dangling objects by calling :py:meth:`remove_empty_categories`, :py:meth:`remove_dangling_parameters_from_groups` (receiving *keep_parameters*), and :py:meth:`remove_empty_parameter_groups` in that order. + + :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep. """ self.remove_empty_categories() self.remove_dangling_parameters_from_groups(keep_parameters=keep_parameters) @@ -1313,6 +1522,8 @@ def remove_dangling_parameters_from_groups( """ Removes names of parameters from parameter groups that are not assigned to any process in any category. + + :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep. """ # get a list of all parameters parameter_names = self.get_parameters("*", flat=True) @@ -1354,7 +1565,11 @@ def iter_processes( """ Generator that iteratively yields all processes whose names match *process*, optionally in all categories whose names match *category*. The yielded value is a 2-tuple containing - the cagegory name and the process object. + the category name and the process object. + + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :returns: A generator yielding 2-tuples of category name and process object. """ processes = self.get_processes(process=process, category=category) for category_name, processes in processes.items(): @@ -1370,7 +1585,12 @@ def iter_parameters( """ Generator that iteratively yields all parameters whose names match *parameter*, optionally in all processes and categories whose names match *process* and *category*. The yielded - value is a 3-tuple containing the cagegory name, the process name and the parameter object. + value is a 3-tuple containing the category name, the process name, and the parameter object. + + :param parameter: A string, pattern, or sequence of them to match parameter names. + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :returns: A generator yielding 3-tuples of category name, process name, and parameter object. """ parameters = self.get_parameters(parameter=parameter, process=process, category=category) for category_name, parameters in parameters.items(): @@ -1390,8 +1610,12 @@ def scale_process( ) -> bool: """ Sets the scale attribute of all processes whose names match *process*, optionally in all - categories whose names match *category*, to *scale*. Returns *True* if at least one process - was found and scale, and *False* otherwise. + categories whose names match *category*, to *scale*. + + :param scale: The scale value to set for the matching processes. + :param process: A string, pattern, or sequence of them to match process names. + :param category: A string, pattern, or sequence of them to match category names. + :returns: *True* if at least one process was found and scaled, *False* otherwise. 
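+
+        Example (process name is illustrative)::
+
+            # scale the signal process by a factor of two in all categories
+            self.scale_process(2.0, process="HH")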
""" found = False for _, process in self.iter_processes(process=process, category=category): diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 241834b69..aa41180e5 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -16,7 +16,7 @@ from columnflow.inference import ( InferenceModel, ParameterType, ParameterTransformation, FlowStrategy, ) -from columnflow.util import DotDict, maybe_import, real_path, ensure_dir, safe_div +from columnflow.util import DotDict, maybe_import, real_path, ensure_dir, safe_div, maybe_int np = maybe_import("np") hist = maybe_import("hist") @@ -104,7 +104,7 @@ def write( blocks.observations = [ ("bin", list(rates)), ("observation", [ - int(round(_rates["data"], self.rate_precision)) + maybe_int(round(_rates["data"], self.rate_precision)) for _rates in rates.values() ]), ] @@ -539,8 +539,27 @@ def get_shapes(param_name): safe_div(integral(h_up), integral(h_nom)), ) - # dedicated data handling - if cat_obj.config_data_datasets: + # data handling, first checking if data should be faked, then if real data exists + if cat_obj.data_from_processes: + # fake data from processes + h_data = [] + for proc_name in cat_obj.data_from_processes: + if proc_name not in hists: + logger.warning( + f"process '{proc_name}' not found in histograms for created fake data, " + "skipping", + ) + continue + h_data.append(hists[proc_name]["nominal"]) + if not h_data: + proc_str = ",".join(map(str, cat_obj.data_from_processes)) + raise Exception(f"no requested process '{proc_str}' found to create fake data") + h_data = sum(h_data[1:], h_data[0].copy()) + data_name = data_pattern.format(category=cat_name) + out_file[data_name] = h_data + _rates["data"] = float(h_data.sum().value) + + elif cat_obj.config_data_datasets: if "data" not in hists: raise Exception( f"the inference model '{self.inference_model_inst.name}' is configured to " @@ -555,14 +574,6 @@ def get_shapes(param_name): out_file[data_name] = h_data _rates["data"] = h_data.sum().value - elif cat_obj.data_from_processes: - # fake data from processes - h_data = [hists[proc_name]["nominal"] for proc_name in cat_obj.data_from_processes] - h_data = sum(h_data[1:], h_data[0].copy()) - data_name = data_pattern.format(category=cat_name) - out_file[data_name] = h_data - _rates["data"] = int(round(h_data.sum().value)) - return (rates, effects, nom_pattern_comb, syst_pattern_comb) @classmethod diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index a30db1e58..5de7d1591 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -6,9 +6,15 @@ from __future__ import annotations +__all__ = [] + from columnflow.types import Sequence from columnflow.util import maybe_import, try_float -from columnflow.plotting.plot_util import get_position, get_cms_label +from columnflow.plotting.plot_util import ( + get_position, + get_cms_label, + remove_label_placeholders, +) hist = maybe_import("hist") np = maybe_import("numpy") @@ -175,6 +181,7 @@ def plot_all( "ratio_kwargs": dict (optional), The *style_config* expects fields (all optional): + "gridspec_cfg": dict, "ax_cfg": dict, "rax_cfg": dict, "legend_cfg": dict, @@ -192,37 +199,43 @@ def plot_all( with a logarithmic scale. 
    :return: tuple of plot figure and axes
     """
-    # available plot methods mapped to their names
-    plot_methods = {
-        func.__name__: func
-        for func in [draw_error_bands, draw_stack, draw_hist, draw_profile, draw_errorbars]
-    }
-
+    # general mplhep style
     plt.style.use(mplhep.style.CMS)
 
+    # setup figure and axes
     rax = None
+    grid_spec = {"left": 0.15, "right": 0.95, "top": 0.95, "bottom": 0.1}
+    grid_spec |= style_config.get("gridspec_cfg", {})
     if not skip_ratio:
-        fig, axs = plt.subplots(2, 1, gridspec_kw=dict(height_ratios=[3, 1], hspace=0), sharex=True)
+        grid_spec |= {"height_ratios": [3, 1], "hspace": 0}
+        fig, axs = plt.subplots(2, 1, gridspec_kw=grid_spec, sharex=True)
         (ax, rax) = axs
     else:
-        fig, ax = plt.subplots()
+        fig, ax = plt.subplots(gridspec_kw=grid_spec)
         axs = (ax,)
 
+    # invoke all plot methods
+    plot_methods = {
+        func.__name__: func
+        for func in [draw_error_bands, draw_stack, draw_hist, draw_profile, draw_errorbars]
+    }
     for key, cfg in plot_config.items():
+        # check if required fields are present
         if "method" not in cfg:
             raise ValueError(f"no method given in plot_cfg entry {key}")
-        method = cfg["method"]
-
         if "hist" not in cfg:
             raise ValueError(f"no histogram(s) given in plot_cfg entry {key}")
-        hist = cfg["hist"]
-
-        kwargs = cfg.get("kwargs", {})
-        plot_methods[method](ax, hist, **kwargs)
 
+        # invoke the method
+        method = cfg["method"]
+        h = cfg["hist"]
+        plot_methods[method](ax, h, **cfg.get("kwargs", {}))
+
+        # repeat for ratio axes if configured
         if not skip_ratio and "ratio_kwargs" in cfg:
             # take ratio_method if the ratio plot requires a different plotting method
             method = cfg.get("ratio_method", method)
-            plot_methods[method](rax, hist, **cfg["ratio_kwargs"])
+            plot_methods[method](rax, h, **cfg.get("ratio_kwargs", {}))
 
     # axis styling
     ax_kwargs = {
@@ -241,17 +254,26 @@ def plot_all(
     # prioritize style_config ax settings
     ax_kwargs.update(style_config.get("ax_cfg", {}))
 
-    # ax configs that can not be handled by `ax.set`
-    minorxticks = ax_kwargs.pop("minorxticks", None)
-    minoryticks = ax_kwargs.pop("minoryticks", None)
+    # some settings cannot be handled by ax.set
+    xminorticks = ax_kwargs.pop("xminorticks", ax_kwargs.pop("minorxticks", None))
+    yminorticks = ax_kwargs.pop("yminorticks", ax_kwargs.pop("minoryticks", None))
+    xloc = ax_kwargs.pop("xloc", None)
+    yloc = ax_kwargs.pop("yloc", None)
 
+    # set all values
     ax.set(**ax_kwargs)
 
-    if minorxticks is not None:
-        ax.set_xticks(minorxticks, minor=True)
-    if minoryticks is not None:
-        ax.set_xticks(minoryticks, minor=True)
-
+    # set manual configs
+    if xminorticks is not None:
+        ax.set_xticks(xminorticks, minor=True)
+    if yminorticks is not None:
+        ax.set_yticks(yminorticks, minor=True)
+    if xloc is not None:
+        ax.set_xlabel(ax.get_xlabel(), loc=xloc)
+    if yloc is not None:
+        ax.set_ylabel(ax.get_ylabel(), loc=yloc)
+
+    # ratio plot
     if not skip_ratio:
         # hard-coded line at 1
         rax.axhline(y=1.0, linestyle="dashed", color="gray")
@@ -262,11 +284,25 @@
             "yscale": "linear",
         }
         rax_kwargs.update(style_config.get("rax_cfg", {}))
+
+        # some settings cannot be handled by ax.set
+        xloc = rax_kwargs.pop("xloc", None)
+        yloc = rax_kwargs.pop("yloc", None)
+
+        # set all values
         rax.set(**rax_kwargs)
 
+        # set manual configs
+        if xloc is not None:
+            rax.set_xlabel(rax.get_xlabel(), loc=xloc)
+        if yloc is not None:
+            rax.set_ylabel(rax.get_ylabel(), loc=yloc)
+
+        # remove x-label from main axis
         if "xlabel" in rax_kwargs:
             ax.set_xlabel("")
 
+    # label alignment
     fig.align_labels()
 
     # legend
@@ -284,6 +320,9 @@ def plot_all(
legend_kwargs.update(style_config.get("legend_cfg", {})) + if "title" in legend_kwargs: + legend_kwargs["title"] = remove_label_placeholders(legend_kwargs["title"]) + # retrieve the legend handles and their labels handles, labels = ax.get_legend_handles_labels() @@ -310,22 +349,6 @@ def plot_all( if callable(update_handles_labels): update_handles_labels(ax, handles, labels, n_cols) - # assume all `StepPatch` objects are part of MC stack - in_stack = [ - isinstance(handle, mpl.patches.StepPatch) - for handle in handles - ] - - # reverse order of entries that are part of the stack - if any(in_stack): - def revere_entries(entries, mask): - entries = np.array(entries, dtype=object) - entries[mask] = entries[mask][::-1] - return list(entries) - - handles = revere_entries(handles, in_stack) - labels = revere_entries(labels, in_stack) - # make legend using ordered handles/labels ax.legend(handles, labels, **legend_kwargs) @@ -353,6 +376,7 @@ def revere_entries(entries, mask): cms_label_kwargs.update(style_config.get("cms_label_cfg", {})) mplhep.cms.label(**cms_label_kwargs) + # finalization fig.tight_layout() return fig, axs diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 0c60ff0fb..8a0911b24 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -6,6 +6,8 @@ from __future__ import annotations +__all__ = [] + from collections import OrderedDict import law @@ -14,7 +16,7 @@ from columnflow.util import maybe_import from columnflow.plotting.plot_all import plot_all from columnflow.plotting.plot_util import ( - prepare_plot_config, + prepare_stack_plot_config, prepare_style_config, remove_residual_axis, apply_variable_settings, @@ -45,7 +47,6 @@ def plot_variable_per_process( density: bool | None = False, shape_norm: bool | None = False, yscale: str | None = "", - hide_errors: bool | None = None, process_settings: dict | None = None, variable_settings: dict | None = None, **kwargs, @@ -336,7 +337,6 @@ def plot_cutflow( density: bool | None = False, shape_norm: bool = False, yscale: str | None = None, - hide_errors: bool | None = None, process_settings: dict | None = None, **kwargs, ) -> plt.Figure: @@ -350,11 +350,7 @@ def plot_cutflow( hists = hists_merge_cutflow_steps(hists) # setup plotting config - plot_config = prepare_plot_config( - hists, - shape_norm=shape_norm, - hide_errors=hide_errors, - ) + plot_config = prepare_stack_plot_config(hists, shape_norm=shape_norm, **kwargs) if shape_norm: # switch normalization to normalizing to `initial step` bin diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index 0dc3627ef..bb9dccb70 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -6,6 +6,8 @@ from __future__ import annotations +__all__ = [] + import re import operator import functools @@ -266,11 +268,15 @@ def get_stack_integral() -> float: ) # remove remaining placeholders - proc_inst.label = re.sub("__[A-Z0-9]+__", "", proc_inst.label) + proc_inst.label = remove_label_placeholders(proc_inst.label) return hists +def remove_label_placeholders(label: str) -> str: + return re.sub("__[A-Z0-9]+__", "", label) + + def apply_variable_settings( hists: dict, variable_insts: list[od.Variable], @@ -450,7 +456,7 @@ def prepare_style_config( "legend_cfg": {}, "annotate_cfg": {"text": cat_label or ""}, "cms_label_cfg": { - "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 2), # /pb -> /fb + "lumi": f"{0.001 * 
config_inst.x.luminosity.get('nominal'):.1f}", # /pb -> /fb "com": config_inst.campaign.ecm, }, } @@ -467,17 +473,17 @@ def prepare_style_config( return style_config -def prepare_plot_config( +def prepare_stack_plot_config( hists: OrderedDict, shape_norm: bool | None = False, hide_errors: bool | None = None, + **kwargs, ) -> OrderedDict: """ Prepares a plot config with one entry to create plots containing a stack of backgrounds with uncertainty bands, unstacked processes as lines and data entrys with errorbars. """ - # separate histograms into stack, lines and data hists mc_hists, mc_colors, mc_edgecolors, mc_labels = [], [], [], [] line_hists, line_colors, line_labels, line_hide_errors = [], [], [], [] @@ -511,9 +517,7 @@ def prepare_plot_config( h_data = sum(data_hists[1:], data_hists[0].copy()) if mc_hists: h_mc = sum(mc_hists[1:], mc_hists[0].copy()) - # reverse hists when building MC stack so that the - # first process is on top - h_mc_stack = hist.Stack(*mc_hists[::-1]) + h_mc_stack = hist.Stack(*mc_hists) # setup plotting configs plot_config = OrderedDict() @@ -526,10 +530,10 @@ def prepare_plot_config( "hist": h_mc_stack, "kwargs": { "norm": mc_norm, - "label": mc_labels[::-1], - "color": mc_colors[::-1], - "edgecolor": mc_edgecolors[::-1], - "linewidth": [(0 if c is None else 1) for c in mc_colors[::-1]], + "label": mc_labels, + "color": mc_colors, + "edgecolor": mc_edgecolors, + "linewidth": [(0 if c is None else 1) for c in mc_colors], }, } diff --git a/columnflow/production/cms/electron.py b/columnflow/production/cms/electron.py index 635b14124..0773124c8 100644 --- a/columnflow/production/cms/electron.py +++ b/columnflow/production/cms/electron.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, InsertableDict, load_correction_set from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array np = maybe_import("numpy") @@ -37,12 +37,11 @@ def new( # purely for backwards compatibility with the old tuple format if isinstance(obj, cls): return obj - elif isinstance(obj, list) or isinstance(obj, tuple) or isinstance(obj, set): + if isinstance(obj, list) or isinstance(obj, tuple) or isinstance(obj, set): return cls(*obj) - elif isinstance(obj, dict): + if isinstance(obj, dict): return cls(**obj) - else: - raise ValueError(f"cannot convert {obj} to ElectronSFConfig") + raise ValueError(f"cannot convert {obj} to ElectronSFConfig") @producer( @@ -111,11 +110,7 @@ def electron_weights( } # loop over systematics - for syst, postfix in [ - ("sf", ""), - ("sfup", "_up"), - ("sfdown", "_down"), - ]: + for syst, postfix in zip(self.sf_variations, ["", "_up", "_down"]): # get the inputs for this type of variation variable_map_syst = { **variable_map, @@ -160,15 +155,18 @@ def electron_weights_setup( ) -> None: bundle = reqs["external_files"] - # create the corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_electron_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) + # load the corrector + correction_set = load_correction_set(self.get_electron_file(bundle.files)) + self.electron_config: ElectronSFConfig = self.get_electron_config() self.electron_sf_corrector = correction_set[self.electron_config.correction] + # the ValType key accepts different arguments for 
efficiencies and scale factors + if self.electron_config.correction.endswith("Eff"): + self.sf_variations = ["nom", "up", "down"] + else: + self.sf_variations = ["sf", "sfup", "sfdown"] + # check versions if self.supported_versions and self.electron_sf_corrector.version not in self.supported_versions: raise Exception(f"unsupported electron sf corrector version {self.electron_sf_corrector.version}") diff --git a/columnflow/production/cms/muon.py b/columnflow/production/cms/muon.py index 908517840..683c34975 100644 --- a/columnflow/production/cms/muon.py +++ b/columnflow/production/cms/muon.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, InsertableDict, load_correction_set from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array np = maybe_import("numpy") @@ -146,12 +146,9 @@ def muon_weights_setup( ) -> None: bundle = reqs["external_files"] - # create the corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_muon_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) + # load the corrector + correction_set = load_correction_set(self.get_muon_file(bundle.files)) + self.muon_config: MuonSFConfig = self.get_muon_config() self.muon_sf_corrector = correction_set[self.muon_config.correction] diff --git a/columnflow/production/cms/seeds.py b/columnflow/production/cms/seeds.py index a199c0d60..75c684a56 100644 --- a/columnflow/production/cms/seeds.py +++ b/columnflow/production/cms/seeds.py @@ -7,6 +7,7 @@ from __future__ import annotations import hashlib +import abc import law @@ -54,7 +55,7 @@ def create_seed(val: int, n_hex: int = 16) -> int: "Jet.nConstituents", "Jet.nElectrons", "Jet.nMuons", ])), ) -def deterministic_event_seeds(self: Producer, events: ak.Array, **kwargs) -> ak.Array: +def deterministic_event_seeds(self, events: ak.Array, **kwargs) -> ak.Array: """ Produces deterministic event seeds and stores them in *events* which is also returned. @@ -136,7 +137,7 @@ def deterministic_event_seeds(self: Producer, events: ak.Array, **kwargs) -> ak. @deterministic_event_seeds.init -def deterministic_event_seeds_init(self: Producer) -> None: +def deterministic_event_seeds_init(self) -> None: """ Producer initialization that adds columns to the set of *used* columns based on the *event_columns*, *object_count_columns*, and *object_columns* lists. @@ -148,7 +149,7 @@ def deterministic_event_seeds_init(self: Producer) -> None: @deterministic_event_seeds.setup def deterministic_event_seeds_setup( - self: Producer, + self, reqs: dict, inputs: dict, reader_targets: InsertableDict, @@ -175,58 +176,109 @@ def apply_route(ak_array: ak.Array, route: Route) -> ak.Array | None: self.apply_route = apply_route -@producer( - uses={"Jet.pt"}, - produces={"Jet.deterministic_seed"}, -) -def deterministic_jet_seeds(self: Producer, events: ak.Array, **kwargs) -> ak.Array: - """ - Produces deterministic seeds for each jet and stores them in *events* which is also returned. - The jet seeds are based on the event seeds like the ones produced by - :py:func:`deterministic_event_seeds` which is not called by this producer for the purpose of - of modularity. The strategy for producing seeds is identical. +class deterministic_object_seeds(Producer): - .. 
note:: + @property + @abc.abstractmethod + def object_field(self) -> str: + ... - The jet seeds depend on the position of the particular jet in the event. It is up to the - user to bring them into the desired order before invoking this producer. - """ - # create the seeds - primes = self.primes[events.deterministic_seed % len(self.primes)] - jet_seed = events.deterministic_seed + ( - primes * ak.values_astype(ak.local_index(events.Jet, axis=1) + self.primes[50], np.uint64) - ) - np_jet_seed = np.asarray(ak.flatten(jet_seed)) - np_jet_seed[:] = create_seed_vec(np_jet_seed) + @property + @abc.abstractmethod + def prime_offset(self) -> int: + ... - # store them - events = set_ak_column(events, "Jet.deterministic_seed", jet_seed, value_type=np.uint64) + def call_func(self, events: ak.Array, **kwargs) -> ak.Array: + """Base class to produce object-specific random seeds. - # uniqueness test across all jets in the chunk for debugging - # n_jets = ak.sum(ak.num(events.Jet, axis=1)) - # n_seeds = len(set(np_jet_seed)) - # match_text = "yes" if n_jets == n_seeds else "NO !!!" - # print(f"jets: {n_jets}, unique seeds: {n_seeds}, match: {match_text}") + Produces deterministic seeds for each object in :py:attr:`object_field` + and stores them in *events* which is also returned. + The object-specific seeds are based on the event seeds like the ones produced by + :py:func:`deterministic_event_seeds` which is not called by this producer for the purpose of + of modularity. The strategy for producing seeds is identical. - return events + :param events: The events array. + :return: The events array with the object seeds stored in *object_field.deterministic_seed*. + .. note:: -@deterministic_jet_seeds.setup -def deterministic_jet_seeds_setup( - self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, -) -> None: - # store primes in array - self.primes = np.array(primes, dtype=np.uint64) + The object seeds depend on the position of the particular object in the event. It is up to the + user to bring them into the desired order before invoking this producer. + """ + # create the seeds + primes = self.primes[events.deterministic_seed % len(self.primes)] + object_seed = events.deterministic_seed + ( + primes * ak.values_astype( + ak.local_index(events[self.object_field], axis=1) + self.primes[self.prime_offset], + np.uint64, + ) + ) + np_object_seed = np.asarray(ak.flatten(object_seed)) + np_object_seed[:] = create_seed_vec(np_object_seed) + + # store them + events = set_ak_column(events, f"{self.object_field}.deterministic_seed", object_seed, value_type=np.uint64) + + # uniqueness test across all jets in the chunk for debugging + # n_objects = ak.sum(ak.num(events[self.object_field], axis=1)) + # n_seeds = len(set(np_object_seed)) + # match_text = "yes" if n_jets == n_seeds else "NO !!!" + # print(f"{self.object_field}: {n_jets}, unique seeds: {n_seeds}, match: {match_text}") + + return events + + def init_func(self) -> None: + self.uses |= {f"{self.object_field}.pt"} + self.produces |= {f"{self.object_field}.deterministic_seed"} + + def setup_func( + self, + reqs: dict, + inputs: dict, + reader_targets: InsertableDict, + ) -> None: + """Setup before entering the event chunk loop. + + Saves the :py:attr:`~columnflow.util.primes` in an numpy array for later use. + + :param reqs: Resolved requirements (not used). + :param inputs: Dictionary for inputs (not used). + :param reader_targets: Dictionary for additional column to retrieve (not used). 
+ """ + # store primes in array + self.primes = np.array(primes, dtype=np.uint64) + + +deterministic_jet_seeds = deterministic_object_seeds.derive( + "deterministic_jet_seeds", + cls_dict={ + "object_field": "Jet", + "prime_offset": 50, + }, +) + +deterministic_electron_seeds = deterministic_object_seeds.derive( + "deterministic_electron_seeds", + cls_dict={ + "object_field": "Electron", + "prime_offset": 60, + }, +) + +deterministic_photon_seeds = deterministic_object_seeds.derive( + "deterministic_photon_seeds", + cls_dict={ + "object_field": "Photon", + "prime_offset": 70, + }, +) @producer( uses={deterministic_event_seeds, deterministic_jet_seeds}, produces={deterministic_event_seeds, deterministic_jet_seeds}, ) -def deterministic_seeds(self: Producer, events: ak.Array, **kwargs) -> ak.Array: +def deterministic_seeds(self, events: ak.Array, **kwargs) -> ak.Array: """ Wrapper producer that invokes :py:func:`deterministic_event_seeds` and :py:func:`deterministic_jet_seeds`. diff --git a/columnflow/production/cms/supercluster_eta.py b/columnflow/production/cms/supercluster_eta.py new file mode 100644 index 000000000..b1285b724 --- /dev/null +++ b/columnflow/production/cms/supercluster_eta.py @@ -0,0 +1,34 @@ +""" +Module to calculate Photon super cluster eta. +Source: https://twiki.cern.ch/twiki/bin/view/CMS/EgammaNanoAOD#How_to_get_photon_supercluster_e +""" + +import law +import functools + +from columnflow.production import producer +from columnflow.util import maybe_import +from columnflow.columnar_util import set_ak_column + +np = maybe_import("numpy") +ak = maybe_import("awkward") + +logger = law.logger.get_logger(__name__) + +set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +@producer( + uses={"Electron.{pt,phi,eta,deltaEtaSC}"}, + produces={"Electron.superclusterEta"}, +) +def electron_sceta(self, events: ak.Array, **kwargs) -> ak.Array: + """ + Returns the electron super cluster eta. + """ + + events = set_ak_column_f32( + events, "Electron.superclusterEta", + events.Electron.eta + events.Electron.deltaEtaSC, + ) + return events diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index 9386a47f6..a168d193f 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -7,19 +7,23 @@ from collections import OrderedDict, defaultdict import law +import order as od from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory from columnflow.tasks.framework.mixins import ( CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, InferenceModelMixin, + HistHookMixin, WeightProducerMixin, ) from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms -from columnflow.util import dev_sandbox +from columnflow.util import dev_sandbox, DotDict from columnflow.config_util import get_datasets_from_process class CreateDatacards( + HistHookMixin, InferenceModelMixin, + WeightProducerMixin, MLModelsMixin, ProducersMixin, SelectorStepsMixin, @@ -57,7 +61,12 @@ def get_mc_datasets(self, proc_obj: dict) -> list[str]: ) ] - # if not, check the config + # if the proc object is dynamic, it is calculated and the fly (e.g. 
via a hist hook) + # and doesn't have any additional requirements + if proc_obj.is_dynamic: + return [] + + # otherwise, check the config return [ dataset_inst.name for dataset_inst in get_datasets_from_process(self.config_inst, proc_obj.config_process) @@ -145,6 +154,7 @@ def requires(self): for dataset in self.get_mc_datasets(proc_obj) } for proc_obj in cat_obj.processes + if not proc_obj.is_dynamic } if cat_obj.config_data_datasets: reqs["data"] = { @@ -161,8 +171,14 @@ def requires(self): return reqs def output(self): + hooks_repr = self.hist_hooks_repr cat_obj = self.branch_data - basename = lambda name, ext: f"{name}__cat_{cat_obj.name}__var_{cat_obj.config_variable}.{ext}" + + def basename(name: str, ext: str) -> str: + parts = [name, cat_obj.name, cat_obj.config_variable] + if hooks_repr: + parts.append(f"hooks_{hooks_repr}") + return f"{'__'.join(map(str, parts))}.{ext}" return { "card": self.target(basename("datacard", "txt")), @@ -185,11 +201,13 @@ def run(self): leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] # histogram data per process - hists = OrderedDict() + hists: dict[od.Process, hist.Hist] = dict() with self.publish_step(f"extracting {variable_inst.name} in {category_inst.name} ..."): + # loop over processes and forward them to any possible hist hooks for proc_obj_name, inp in inputs.items(): if proc_obj_name == "data": + # there is not process object for data proc_obj = None process_inst = self.config_inst.get_process("data") else: @@ -220,15 +238,10 @@ def run(self): for p in sub_process_insts if p.id in h.axes["process"] ], - "category": [ - hist.loc(c.id) - for c in leaf_category_insts - if c.id in h.axes["category"] - ], }] # axis reductions - h = h[{"process": sum, "category": sum}] + h = h[{"process": sum}] # add the histogram for this dataset if h_proc is None: @@ -240,30 +253,68 @@ def run(self): if h_proc is None: raise Exception(f"no histograms found for process '{process_inst.name}'") + # save histograms in hist_hook format + hists[process_inst] = h_proc + + # apply hist hooks + hists = self.invoke_hist_hooks(hists) + + # define datacard processes to loop over + cat_processes = list(cat_obj.processes) + if cat_obj.config_data_datasets and not cat_obj.data_from_processes: + cat_processes.append(DotDict({"name": "data"})) + + # after application of hist hooks, we can proceed with the datacard creation + datacard_hists: OrderedDict[str, OrderedDict[str, hist.Hist]] = OrderedDict() + for proc_obj in cat_processes: + # obtain process information from inference model and config again + proc_name = "data" if proc_obj.name == "data" else proc_obj.config_process + process_inst = self.config_inst.get_process(proc_name) + + h_proc = hists.get(process_inst, None) + if h_proc is None: + self.logger.warning( + f"found no histogram for process '{proc_obj.name}', please check your " + f"inference model '{self.inference_model}'", + ) + continue + + # select relevant category + h_proc = h_proc[{ + "category": [ + hist.loc(c.id) + for c in leaf_category_insts + if c.id in h_proc.axes["category"] + ], + }][{"category": sum}] + # create the nominal hist - hists[proc_obj_name] = OrderedDict() + datacard_hists[proc_obj.name] = OrderedDict() nominal_shift_inst = self.config_inst.get_shift("nominal") - hists[proc_obj_name]["nominal"] = h_proc[ + datacard_hists[proc_obj.name]["nominal"] = h_proc[ {"shift": hist.loc(nominal_shift_inst.id)} ] - # per shift - if proc_obj: - for param_obj in proc_obj.parameters: - # skip the parameter when varied hists are 
not needed - if not self.inference_model_inst.require_shapes_for_parameter(param_obj): - continue - # store the varied hists - hists[proc_obj_name][param_obj.name] = {} - for d in ["up", "down"]: - shift_inst = self.config_inst.get_shift(f"{param_obj.config_shift_source}_{d}") - hists[proc_obj_name][param_obj.name][d] = h_proc[ - {"shift": hist.loc(shift_inst.id)} - ] + # stop here for data + if proc_obj.name == "data": + continue + + # create histograms per shift + for param_obj in proc_obj.parameters: + # skip the parameter when varied hists are not needed + if not self.inference_model_inst.require_shapes_for_parameter(param_obj): + continue + # store the varied hists + datacard_hists[proc_obj_name][param_obj.name] = {} + for d in ["up", "down"]: + shift_inst = self.config_inst.get_shift(f"{param_obj.config_shift_source}_{d}") + datacard_hists[proc_obj_name][param_obj.name][d] = h_proc[ + {"shift": hist.loc(shift_inst.id)} + ] # forward objects to the datacard writer outputs = self.output() - writer = DatacardWriter(self.inference_model_inst, {cat_obj.name: hists}) + writer = DatacardWriter(self.inference_model_inst, {cat_obj.name: datacard_hists}) with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 2e4cf3bde..f5d352ea7 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -98,7 +98,7 @@ class AnalysisTask(BaseTask, law.SandboxTask): exclude_params_index = {"user"} exclude_params_req = {"user", "notify_slack", "notify_mattermost", "notify_custom"} exclude_params_repr = {"user", "notify_slack", "notify_mattermost", "notify_custom"} - exclude_params_branch = {"user", "notify_slack", "notify_mattermost", "notify_custom"} + exclude_params_branch = {"user"} exclude_params_workflow = {"user", "notify_slack", "notify_mattermost", "notify_custom"} # cached and parsed sections of the law config for faster lookup diff --git a/columnflow/tasks/framework/plotting.py b/columnflow/tasks/framework/plotting.py index ee4695914..a63fef391 100644 --- a/columnflow/tasks/framework/plotting.py +++ b/columnflow/tasks/framework/plotting.py @@ -271,7 +271,7 @@ def update_plot_kwargs(self, kwargs: dict) -> dict: # update style_config style_config = kwargs.get("style_config", {}) if isinstance(custom_style_config, dict) and isinstance(style_config, dict): - style_config = law.util.merge_dicts(custom_style_config, style_config) + style_config = law.util.merge_dicts(style_config, custom_style_config) kwargs["style_config"] = style_config # update other defaults diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index bfc316e9e..2124adb02 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -573,7 +573,7 @@ class MergeHistograms( remove_previous = luigi.BoolParameter( default=False, significant=False, - description="when True, remove particlar input histograms after merging; default: False", + description="when True, remove particular input histograms after merging; default: False", ) sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -602,7 +602,7 @@ def _get_variables(self): # optional dynamic behavior: determine not yet created variables and require only those if self.only_missing: - missing = self.output().count(existing=False, keys=True)[1] + missing = 
self.output()["hists"].count(existing=False, keys=True)[1] variables = sorted(missing, key=variables.index) return variables diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 5b0779987..0d2f9ce74 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -301,13 +301,9 @@ class MergeMLStats( SelectorMixin, CalibratorsMixin, DatasetTask, - law.tasks.ForestMerge, + law.LocalWorkflow, + RemoteWorkflow, ): - # recursively merge 20 files into one - merge_factor = 20 - - # skip receiving some parameters via req - exclude_params_req_get = {"workflow"} # upstream requirements reqs = Requirements( @@ -315,40 +311,35 @@ class MergeMLStats( ) def create_branch_map(self): - # DatasetTask implements a custom branch map, but we want to use the one in ForestMerge - return law.tasks.ForestMerge.create_branch_map(self) + # dummy branch map + return {0: None} - def merge_workflow_requires(self): - return self.reqs.PrepareMLEvents.req(self, _exclude={"branches"}) + def workflow_requires(self): + reqs = super().workflow_requires() + reqs["events"] = self.reqs.PrepareMLEvents.req_different_branching(self) + return reqs - def merge_requires(self, start_branch, end_branch): - return self.reqs.PrepareMLEvents.req( + def requires(self): + return self.reqs.PrepareMLEvents.req_different_branching( self, - branches=((start_branch, end_branch),), + branch=-1, workflow="local", - _exclude={"branch"}, ) - def merge_output(self): + def output(self): return {"stats": self.target("stats.json")} - def trace_merge_inputs(self, inputs): - return super().trace_merge_inputs(inputs["collection"].targets.values()) - @law.decorator.notify @law.decorator.log def run(self): - return super().run() - - def merge(self, inputs, output): # merge input stats merged_stats = defaultdict(float) - for inp in inputs: + for inp in self.input().collection.targets.values(): stats = inp["stats"].load(formatter="json", cache=False) self.merge_counts(merged_stats, stats) # write the output - output["stats"].dump(merged_stats, indent=4, formatter="json", cache=False) + self.output()["stats"].dump(merged_stats, indent=4, formatter="json", cache=False) @classmethod def merge_counts(cls, dst: dict, src: dict) -> dict: @@ -532,7 +523,7 @@ def workflow_requires(self): calibrators=_calibrators, selector=_selector, producers=_producers, - tree_index=-1) + ) for dataset_inst in dataset_insts } for (config_inst, dataset_insts), _calibrators, _selector, _producers in zip( diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index a67540089..0ae088788 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -455,7 +455,7 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() reqs["stats"] = self.reqs.MergeReductionStats.req_different_branching(self) - reqs["events"] = self.reqs.ReduceEvents.req_different_branching(self) + reqs["events"] = self.reqs.ReduceEvents.req_different_branching(self, branches=((0, -1),)) return reqs def requires(self): diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index 8da1bc052..2b5c719f3 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -24,9 +24,9 @@ logger = law.logger.get_logger(__name__) -default_selection_hists_optional = law.config.get_expanded_bool( +default_create_selection_hists = law.config.get_expanded_bool( "analysis", - "default_selection_hists_optional", + "default_create_selection_hists", True, ) @@ -39,9 +39,6 @@ class SelectEvents( 
law.LocalWorkflow, RemoteWorkflow, ): - # flag that sets the *hists* output to optional if True - selection_hists_optional = default_selection_hists_optional - # default sandbox, might be overwritten by selector function sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -59,6 +56,9 @@ class SelectEvents( # strategy for handling missing source columns when adding aliases on event chunks missing_column_alias_strategy = "original" + # whether histogram outputs should be created + create_selection_hists = default_create_selection_hists + def workflow_requires(self): reqs = super().workflow_requires() @@ -99,9 +99,12 @@ def output(self): outputs = { "results": self.target(f"results_{self.branch}.parquet"), "stats": self.target(f"stats_{self.branch}.json"), - "hists": self.target(f"hists_{self.branch}.pickle", optional=self.selection_hists_optional), } + # add histograms if requested + if self.create_selection_hists: + outputs["hists"] = self.target(f"hists_{self.branch}.pickle") + # add additional columns in case the selector produces some if self.selector_inst.produced_columns: outputs["columns"] = self.target(f"columns_{self.branch}.parquet") @@ -246,7 +249,8 @@ def run(self): # save stats outputs["stats"].dump(stats, formatter="json") - outputs["hists"].dump(hists, formatter="pickle") + if self.create_selection_hists: + outputs["hists"].dump(hists, formatter="pickle") # print some stats eff = safe_div(stats["num_events_selected"], stats["num_events"]) @@ -289,12 +293,12 @@ class MergeSelectionStats( law.LocalWorkflow, RemoteWorkflow, ): - # flag that sets the *hists* output to optional if True - selection_hists_optional = default_selection_hists_optional - # default sandbox, might be overwritten by selector function (needed to load hist objects) sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + # whether histogram outputs should be created + create_selection_hists = default_create_selection_hists + # upstream requirements reqs = Requirements( RemoteWorkflow.reqs, @@ -314,41 +318,29 @@ def requires(self): return self.reqs.SelectEvents.req_different_branching(self, workflow="local", branch=-1) def output(self): - return { - "stats": self.target("stats.json"), - "hists": self.target("hists.pickle", optional=self.selection_hists_optional), - } + outputs = {"stats": self.target("stats.json")} + if self.create_selection_hists: + outputs["hists"] = self.target("hists.pickle") + return outputs @law.decorator.notify @law.decorator.log def run(self): - # check that hists are present for all inputs - inputs = list(self.input().collection.targets.values()) - hist_inputs_exist = [inp["hists"].exists() for inp in inputs] - if any(hist_inputs_exist) and not all(hist_inputs_exist): - logger.warning( - f"For dataset {self.dataset_inst.name}, cf.SelectEvents has produced hists for " - "some but not all files. 
Histograms will not be merged and an empty pickle file will be stored.", - ) - # merge input stats merged_stats = defaultdict(float) merged_hists = {} - - for inp in inputs: + for inp in self.input().collection.targets.values(): stats = inp["stats"].load(formatter="json", cache=False) self.merge_counts(merged_stats, stats) - - # merge hists only if all hists are present - if all(hist_inputs_exist): - for inp in inputs: + if self.create_selection_hists: hists = inp["hists"].load(formatter="pickle", cache=False) self.merge_counts(merged_hists, hists) - # write the outputs + # write outputs outputs = self.output() outputs["stats"].dump(merged_stats, formatter="json", cache=False) - outputs["hists"].dump(merged_hists, formatter="pickle", cache=False) + if self.create_selection_hists: + outputs["hists"].dump(merged_hists, formatter="pickle", cache=False) @classmethod def merge_counts(cls, dst: dict, src: dict) -> dict: @@ -406,9 +398,7 @@ def create_branch_map(self): return law.tasks.ForestMerge.create_branch_map(self) def merge_workflow_requires(self): - reqs = { - "selection": self.reqs.SelectEvents.req(self, _exclude={"branches"}), - } + reqs = {"selection": self.reqs.SelectEvents.req_different_branching(self)} if self.dataset_inst.is_mc: reqs["normalization"] = self.norm_weight_producer.run_requires() @@ -418,7 +408,7 @@ def merge_workflow_requires(self): def merge_requires(self, start_branch, end_branch): reqs = { "selection": [ - self.reqs.SelectEvents.req(self, branch=b) + self.reqs.SelectEvents.req_different_branching(self, branch=b) for b in range(start_branch, end_branch) ], } @@ -484,8 +474,8 @@ def zip_results_and_columns(self, inputs, tmp_dir): route_filter = RouteFilter(write_columns) for inp in inputs: - events = inp["columns"].load(formatter="awkward") - steps = inp["results"].load(formatter="awkward").steps + events = inp["columns"].load(formatter="awkward", cache=False) + steps = inp["results"].load(formatter="awkward", cache=False).steps # add normalization weight if self.dataset_inst.is_mc: diff --git a/columnflow/util.py b/columnflow/util.py index ddefadb3f..1df09623d 100644 --- a/columnflow/util.py +++ b/columnflow/util.py @@ -30,7 +30,6 @@ from columnflow import env_is_dev, env_is_remote from columnflow.types import Callable, Any, Sequence, Union, ModuleType - #: Placeholder for an unset value. UNSET = object() @@ -455,6 +454,15 @@ def try_int(i: Any) -> bool: return False +def maybe_int(i: Any) -> Any: + """ + Returns *i* as an integer if it is a whole number, and as a float otherwise. + """ + if isinstance(i, (int, bool)) or (isinstance(i, float) and i.is_integer()): + return int(i) + return i + + def is_pattern(s: str) -> bool: """ Returns *True* if a string *s* contains pattern characters such as "*" or "?", and *False* @@ -934,3 +942,20 @@ def __init__(self, *args, key, value, **kwargs): def __str__(self) -> str: return str(self.value) + + +def load_correction_set(target: law.FileSystemFileTarget) -> Any: + """ + Loads a correction set using the correctionlib from a file *target*. 
+    """
+ """ + import correctionlib + + # extend the Correction object + correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate + + # use the path when the input file is a normal json + if target.ext() == "json": + return correctionlib.CorrectionSet.from_file(target.abspath) + + # otherwise, assume the input file is compressed + return correctionlib.CorrectionSet.from_string(target.load(formatter="gzip").decode("utf-8")) diff --git a/docs/api/types.rst b/docs/api/types.rst index e2c680ef4..4bf7f71f9 100644 --- a/docs/api/types.rst +++ b/docs/api/types.rst @@ -6,4 +6,4 @@ :members: :undoc-members: :show-inheritance: - + :imported-members: diff --git a/docs/api/util.rst b/docs/api/util.rst index 790d510b3..ff1be7a49 100644 --- a/docs/api/util.rst +++ b/docs/api/util.rst @@ -1,218 +1,9 @@ ``columnflow.util`` =================== -.. automodule:: columnflow.util - .. currentmodule:: columnflow.util - -Summary -------- - -.. autosummary:: - - UNSET - env_is_remote - env_is_dev - primes - maybe_import - import_plt - import_ROOT - import_file - create_random_name - expand_path - real_path - ensure_dir - wget - call_thread - call_proc - ensure_proxy - dev_sandbox - safe_div - try_float - is_pattern - is_regex - pattern_matcher - dict_add_strict - get_source_code - DotDict - MockModule - FunctionArgs - ClassPropertyDescriptor - classproperty - DerivableMeta - Derivable - - -Attributes ----------- - -``UNSET`` -+++++++++ - -.. autoattribute:: columnflow.util.UNSET - -``env_is_remote`` -+++++++++++++++++ - -.. autoattribute:: columnflow.util.env_is_remote - -``env_is_dev`` -++++++++++++++ - -.. autoattribute:: columnflow.util.env_is_dev - -``primes`` -++++++++++ - -.. autoattribute:: columnflow.util.primes - -Functions ---------- - -``maybe_import`` -++++++++++++++++ - -.. autofunction:: maybe_import - -``import_plt`` -++++++++++++++ - -.. autofunction:: import_plt - -``import_ROOT`` -+++++++++++++++ - -.. autofunction:: import_ROOT - -``import_file`` -+++++++++++++++ - -.. autofunction:: import_file - -``create_random_name`` -++++++++++++++++++++++ - -.. autofunction:: create_random_name - -``expand_path`` -+++++++++++++++ - -.. autofunction:: expand_path - -``real_path`` -+++++++++++++ - -.. autofunction:: real_path - -``ensure_dir`` -++++++++++++++ - -.. autofunction:: ensure_dir - -``wget`` -++++++++ - -.. autofunction:: wget - -``call_thread`` -+++++++++++++++ - -.. autofunction:: call_thread - -``call_proc`` -+++++++++++++ - -.. autofunction:: call_proc - -``ensure_proxy`` -++++++++++++++++ - -.. autofunction:: ensure_proxy - -``dev_sandbox`` -+++++++++++++++ - -.. autofunction:: dev_sandbox - -``safe_div`` -++++++++++++ - -.. autofunction:: safe_div - -``try_float`` -++++++++++++++ - -.. autofunction:: try_float - -``is_pattern`` -++++++++++++++ - -.. autofunction:: is_pattern - -``is_regex`` -++++++++++++ - -.. autofunction:: is_regex - -``pattern_matcher`` -+++++++++++++++++++ - -.. autofunction:: pattern_matcher - -``dict_add_strict`` -+++++++++++++++++++ - -.. autofunction:: dict_add_strict - -``get_source_code`` -+++++++++++++++++++ - -.. autofunction:: get_source_code - -``classproperty`` -+++++++++++++++++ - -.. autofunction:: classproperty - -Classes -------- - -``DotDict`` -+++++++++++ - -.. autoclass:: DotDict - :members: - :special-members: - -``MockModule`` -++++++++++++++ - -.. autoclass:: MockModule - :members: - -``FunctionArgs`` -++++++++++++++++ - -.. 
autoclass:: FunctionArgs - :members: - -``ClassPropertyDescriptor`` -+++++++++++++++++++++++++++ - -.. autoclass:: ClassPropertyDescriptor - :members: - :special-members: - -``DerivableMeta`` -+++++++++++++++++ - -.. autoclass:: DerivableMeta - :members: - :special-members: - -``Derivable`` -+++++++++++++ - -.. autoclass:: Derivable +.. automodule:: columnflow.util + :autosummary: :members: + :undoc-members: :special-members: diff --git a/docs/conf.py b/docs/conf.py index 99d139078..b958286ba 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -87,6 +87,7 @@ autodoc_default_options = { "member-order": "bysource", "show-inheritance": True, + "ignore-module-all": True, } autosectionlabel_prefix_document = True @@ -99,7 +100,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "coffea": ("https://coffeateam.github.io/coffea", None), + "coffea": ("https://coffea-hep.readthedocs.io/en/latest/", None), "law": ("https://law.readthedocs.io/en/latest/", None), "order": ("https://python-order.readthedocs.io/en/latest/", None), "ak": ("https://awkward-array.org/doc/main", None), diff --git a/law.cfg b/law.cfg index 0d6ae338f..f0d7e7b82 100644 --- a/law.cfg +++ b/law.cfg @@ -44,8 +44,8 @@ default_keep_reduced_events: True # slightly to the left to avoid them being excluded from the last bin; None leads to automatic mode default_histogram_last_edge_inclusive: None -# boolean flag that, if True, sets the *hists* output of cf.SelectEvents and cf.MergeSelectionStats to optional -default_selection_hists_optional: True +# boolean flag that, if True, configures cf.SelectEvents to create statistics histograms +default_create_selection_hists: True # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False diff --git a/sandboxes/_setup_cmssw.sh b/sandboxes/_setup_cmssw.sh index 3e6411d03..373c67c0e 100644 --- a/sandboxes/_setup_cmssw.sh +++ b/sandboxes/_setup_cmssw.sh @@ -334,7 +334,7 @@ setup_cmssw() { # prepend persistent path fragments again to ensure priority for local packages and # remove the conda based python fragments since there are too many overlaps between packages - export PYTHONPATH="${CF_PERSISTENT_PYTHONPATH}:$( echo ${PYTHONPATH} | sed "s|${CF_CONDA_PYTHONPATH}||g" )" + export PYTHONPATH="${CF_INITIAL_PYTHONPATH}:${CF_PERSISTENT_PYTHONPATH}:$( echo ${PYTHONPATH} | sed "s|${CF_CONDA_PYTHONPATH}||g" )" export PATH="${CF_PERSISTENT_PATH}:${PATH}" # mark this as a bash sandbox for law diff --git a/sandboxes/_setup_venv.sh b/sandboxes/_setup_venv.sh index 1d0830f6c..f746551ed 100644 --- a/sandboxes/_setup_venv.sh +++ b/sandboxes/_setup_venv.sh @@ -148,7 +148,7 @@ setup_venv() { # prepend persistent path fragments to priotize packages in the outer env export CF_VENV_PYTHONPATH="${install_path}/lib/python${pyv}/site-packages" - export PYTHONPATH="${CF_PERSISTENT_PYTHONPATH}:${CF_VENV_PYTHONPATH}:${PYTHONPATH}" + export PYTHONPATH="${CF_INITIAL_PYTHONPATH}:${CF_PERSISTENT_PYTHONPATH}:${CF_VENV_PYTHONPATH}:${PYTHONPATH}" export PATH="${CF_PERSISTENT_PATH}:${PATH}" diff --git a/tests/run_tests b/tests/run_tests index 943ffe526..963f3398d 100755 --- a/tests/run_tests +++ b/tests/run_tests @@ -32,6 +32,9 @@ action() { ret="$?" 
[ "${gret}" = "0" ] && gret="${ret}" + # test_inference + echo + bash "${this_dir}/run_test" test_inference "${cf_dir}/sandboxes/venv_columnar${dev}.sh" # test_hist_util echo bash "${this_dir}/run_test" test_hist_util "${cf_dir}/sandboxes/venv_columnar${dev}.sh" diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 000000000..f4c3c016a --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,240 @@ +import unittest +from columnflow.inference import ( + InferenceModel, ParameterType, ParameterTransformation, ParameterTransformations, + FlowStrategy, +) +from columnflow.util import DotDict + + +class TestInferenceModel(unittest.TestCase): + + def test_process_spec(self): + # Test data + name = "test_process" + config_process = "test_config_process" + is_signal = True + config_mc_datasets = ["dataset1", "dataset2"] + scale = 2.0 + is_dynamic = True + + # Expected result + expected_result = DotDict([ + ("name", "test_process"), + ("is_signal", True), + ("config_process", "test_config_process"), + ("config_mc_datasets", ["dataset1", "dataset2"]), + ("scale", 2.0), + ("parameters", []), + ("is_dynamic", True), + ]) + + # Call the method + result = InferenceModel.process_spec( + name=name, + config_process=config_process, + is_signal=is_signal, + config_mc_datasets=config_mc_datasets, + scale=scale, + is_dynamic=is_dynamic, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_category_spec(self): + # Test data + name = "test_category" + config_category = "test_config_category" + config_variable = "test_config_variable" + config_data_datasets = ["dataset1", "dataset2"] + data_from_processes = ["process1", "process2"] + mc_stats = (10, 0.1) + empty_bin_value = 1e-4 + + # Expected result + expected_result = DotDict([ + ("name", "test_category"), + ("config_category", "test_config_category"), + ("config_variable", "test_config_variable"), + ("config_data_datasets", ["dataset1", "dataset2"]), + ("data_from_processes", ["process1", "process2"]), + ("flow_strategy", FlowStrategy.warn), + ("mc_stats", (10, 0.1)), + ("empty_bin_value", 1e-4), + ("processes", []), + ]) + + # Call the method + result = InferenceModel.category_spec( + name=name, + config_category=config_category, + config_variable=config_variable, + config_data_datasets=config_data_datasets, + data_from_processes=data_from_processes, + mc_stats=mc_stats, + empty_bin_value=empty_bin_value, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_parameter_spec(self): + # Test data + name = "test_parameter" + type = ParameterType.rate_gauss + transformations = [ParameterTransformation.centralize, ParameterTransformation.symmetrize] + config_shift_source = "test_shift_source" + effect = 1.5 + + # Expected result + expected_result = DotDict([ + ("name", "test_parameter"), + ("type", ParameterType.rate_gauss), + ("transformations", ParameterTransformations(transformations)), + ("config_shift_source", "test_shift_source"), + ("effect", 1.5), + ]) + + # Call the method + result = InferenceModel.parameter_spec( + name=name, + type=type, + transformations=transformations, + config_shift_source=config_shift_source, + effect=effect, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_parameter_spec_with_default_transformations(self): + # Test data + name = "test_parameter" + type = ParameterType.rate_gauss + config_shift_source = "test_shift_source" + effect = 1.5 + + # Expected result + expected_result = DotDict([ + ("name", 
"test_parameter"), + ("type", ParameterType.rate_gauss), + ("transformations", ParameterTransformations([ParameterTransformation.none])), + ("config_shift_source", "test_shift_source"), + ("effect", 1.5), + ]) + + # Call the method + result = InferenceModel.parameter_spec( + name=name, + type=type, + config_shift_source=config_shift_source, + effect=effect, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_parameter_spec_with_string_type_and_transformations(self): + # Test data + name = "test_parameter" + type = "rate_gauss" + transformations = ["centralize", "symmetrize"] + config_shift_source = "test_shift_source" + effect = 1.5 + + # Expected result + expected_result = DotDict([ + ("name", "test_parameter"), + ("type", ParameterType.rate_gauss), + ("transformations", ParameterTransformations([ + ParameterTransformation.centralize, + ParameterTransformation.symmetrize, + ])), + ("config_shift_source", "test_shift_source"), + ("effect", 1.5), + ]) + + # Call the method + result = InferenceModel.parameter_spec( + name=name, + type=type, + transformations=transformations, + config_shift_source=config_shift_source, + effect=effect, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_parameter_group_spec(self): + # Test data + name = "test_group" + parameter_names = ["param1", "param2", "param3"] + + # Expected result + expected_result = DotDict([ + ("name", "test_group"), + ("parameter_names", ["param1", "param2", "param3"]), + ]) + + # Call the method + result = InferenceModel.parameter_group_spec( + name=name, + parameter_names=parameter_names, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_parameter_group_spec_with_no_parameter_names(self): + # Test data + name = "test_group" + + # Expected result + expected_result = DotDict([ + ("name", "test_group"), + ("parameter_names", []), + ]) + + # Call the method + result = InferenceModel.parameter_group_spec( + name=name, + ) + + # Assert the result + self.assertEqual(result, expected_result) + + def test_require_shapes_for_parameter_shape(self): + # No shape is required if the parameter type is a rate + types = [ParameterType.rate_gauss, ParameterType.rate_uniform, ParameterType.rate_unconstrained] + for t in types: + with self.subTest(t=t): + param_obj = DotDict.wrap({ + "type": t, + "transformations": ParameterTransformations([ParameterTransformation.effect_from_rate]), + "name": "test_param", + }) + result = InferenceModel.require_shapes_for_parameter(param_obj) + self.assertFalse(result) + + # if the transformation is shape-based expect True + param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) + result = InferenceModel.require_shapes_for_parameter(param_obj) + self.assertTrue(result) + + # No shape is required if the transformation is from a rate + param_obj = DotDict.wrap({ + "type": ParameterType.shape, + "transformations": ParameterTransformations([ParameterTransformation.effect_from_rate]), + "name": "test_param", + }) + result = InferenceModel.require_shapes_for_parameter(param_obj) + self.assertFalse(result) + + param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) + result = InferenceModel.require_shapes_for_parameter(param_obj) + self.assertTrue(result) + + +if __name__ == "__main__": + unittest.main()