Skip to content

Commit

Permalink
MAINT: address Ruff LOG015 error
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Jan 7, 2025
1 parent e151210 commit f8623dc
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions scripts/execute_jax_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
LOGGER = logging.getLogger()
_LOGGER = logging.getLogger()
mute_jax_warnings()

yaml.add_representer(defaultdict, Representer.represent_dict)

BENCHMARK_CASES = [
Expand Down Expand Up @@ -78,7 +77,7 @@ def main() -> int:
existing_benchmark = t.get(n)
if existing_benchmark is not None:
if all(len(v) == NUMBER_OF_RUNS for v in existing_benchmark.values()):
logging.warning(f"Benchmark for {n:,} events already exists")
_LOGGER.warning(f"Benchmark for {n:,} events already exists")
progress_bar.update(NUMBER_OF_RUNS)
continue
t[n] = defaultdict(list)
Expand Down Expand Up @@ -118,33 +117,30 @@ def create_amplitude_model() -> AmplitudeModel:
def prepare_functions(
model: AmplitudeModel,
) -> tuple[ParametrizedBackendFunction, PositionalArgumentFunction]:
original_log_level = LOGGER.getEffectiveLevel()
LOGGER.setLevel(logging.INFO)
logging.info("Unfolding intensity expression")
_LOGGER.info("Unfolding intensity expression")
unfolded_intensity_expr = perform_cached_doit(model.full_expression)
logging.info("Substituting parameters")
_LOGGER.info("Substituting parameters")
substituted_expr = unfolded_intensity_expr.xreplace(model.parameter_defaults)
logging.info("Lambdifying full intensity expression")
_LOGGER.info("Lambdifying full intensity expression")
parametrized_func = create_parametrized_function(
unfolded_intensity_expr,
parameters=model.parameter_defaults,
backend="jax",
)
logging.info("Lambdifying substituted intensity expression")
_LOGGER.info("Lambdifying substituted intensity expression")
substituted_func = create_function(substituted_expr, backend="jax")
logging.info("Finished function lambdification")
LOGGER.setLevel(original_log_level)
_LOGGER.info("Finished function lambdification")
return parametrized_func, substituted_func


def generate_sample(
model: AmplitudeModel, n_events: int, seed: int | None = None
) -> DataSample:
transformer = create_data_transformer(model)
original_log_level = LOGGER.getEffectiveLevel()
LOGGER.setLevel(logging.ERROR)
original_log_level = _LOGGER.getEffectiveLevel()
_LOGGER.setLevel(logging.ERROR)
phsp_sample = generate_phasespace_sample(model.decay, n_events, seed)
LOGGER.setLevel(original_log_level)
_LOGGER.setLevel(original_log_level)
return transformer(phsp_sample)


Expand Down

0 comments on commit f8623dc

Please sign in to comment.