|
33 | 33 | format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
|
34 | 34 | datefmt="%Y-%m-%d %H:%M:%S",
|
35 | 35 | )
|
36 |
| -LOGGER = logging.getLogger() |
| 36 | +_LOGGER = logging.getLogger() |
37 | 37 | mute_jax_warnings()
|
38 |
| - |
39 | 38 | yaml.add_representer(defaultdict, Representer.represent_dict)
|
40 | 39 |
|
41 | 40 | BENCHMARK_CASES = [
|
@@ -78,7 +77,7 @@ def main() -> int:
|
78 | 77 | existing_benchmark = t.get(n)
|
79 | 78 | if existing_benchmark is not None:
|
80 | 79 | if all(len(v) == NUMBER_OF_RUNS for v in existing_benchmark.values()):
|
81 |
| - logging.warning(f"Benchmark for {n:,} events already exists") |
| 80 | + _LOGGER.warning(f"Benchmark for {n:,} events already exists") |
82 | 81 | progress_bar.update(NUMBER_OF_RUNS)
|
83 | 82 | continue
|
84 | 83 | t[n] = defaultdict(list)
|
@@ -118,33 +117,30 @@ def create_amplitude_model() -> AmplitudeModel:
|
118 | 117 | def prepare_functions(
|
119 | 118 | model: AmplitudeModel,
|
120 | 119 | ) -> tuple[ParametrizedBackendFunction, PositionalArgumentFunction]:
|
121 |
| - original_log_level = LOGGER.getEffectiveLevel() |
122 |
| - LOGGER.setLevel(logging.INFO) |
123 |
| - logging.info("Unfolding intensity expression") |
| 120 | + _LOGGER.info("Unfolding intensity expression") |
124 | 121 | unfolded_intensity_expr = perform_cached_doit(model.full_expression)
|
125 |
| - logging.info("Substituting parameters") |
| 122 | + _LOGGER.info("Substituting parameters") |
126 | 123 | substituted_expr = unfolded_intensity_expr.xreplace(model.parameter_defaults)
|
127 |
| - logging.info("Lambdifying full intensity expression") |
| 124 | + _LOGGER.info("Lambdifying full intensity expression") |
128 | 125 | parametrized_func = create_parametrized_function(
|
129 | 126 | unfolded_intensity_expr,
|
130 | 127 | parameters=model.parameter_defaults,
|
131 | 128 | backend="jax",
|
132 | 129 | )
|
133 |
| - logging.info("Lambdifying substituted intensity expression") |
| 130 | + _LOGGER.info("Lambdifying substituted intensity expression") |
134 | 131 | substituted_func = create_function(substituted_expr, backend="jax")
|
135 |
| - logging.info("Finished function lambdification") |
136 |
| - LOGGER.setLevel(original_log_level) |
| 132 | + _LOGGER.info("Finished function lambdification") |
137 | 133 | return parametrized_func, substituted_func
|
138 | 134 |
|
139 | 135 |
|
140 | 136 | def generate_sample(
|
141 | 137 | model: AmplitudeModel, n_events: int, seed: int | None = None
|
142 | 138 | ) -> DataSample:
|
143 | 139 | transformer = create_data_transformer(model)
|
144 |
| - original_log_level = LOGGER.getEffectiveLevel() |
145 |
| - LOGGER.setLevel(logging.ERROR) |
| 140 | + original_log_level = _LOGGER.getEffectiveLevel() |
| 141 | + _LOGGER.setLevel(logging.ERROR) |
146 | 142 | phsp_sample = generate_phasespace_sample(model.decay, n_events, seed)
|
147 |
| - LOGGER.setLevel(original_log_level) |
| 143 | + _LOGGER.setLevel(original_log_level) |
148 | 144 | return transformer(phsp_sample)
|
149 | 145 |
|
150 | 146 |
|
|
0 commit comments