Skip to content

Commit

Permalink
Merge pull request #43 from aai-institute/main-hotfixes
Browse files Browse the repository at this point in the history
Make first argument to `@nnbench.parametrize` a single iterable
  • Loading branch information
nicholasjng authored Feb 1, 2024
2 parents 17020b9 + 2c044e6 commit d839f62
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/guides/benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ import nnbench
from training import prepare_data, train_rf, accuracy

@nnbench.parametrize(
{"n_estimators": 10, "max_depth": 2},
({"n_estimators": 10, "max_depth": 2},
{"n_estimators": 50, "max_depth": 5},
{"n_estimators": 100, "max_depth": 10}
{"n_estimators": 100, "max_depth": 10})
)
def benchmark_accuracy(n_estimators: int, max_depth: int, random_state: int) -> float:
X_train, X_test, y_train, y_test = prepare_data()
Expand Down
18 changes: 8 additions & 10 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def decorator(fun: Callable) -> Benchmark:


def parametrize(
*parameters: dict[str, Any],
parameters: Iterable[dict[str, Any]],
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> list[Benchmark] | Callable[[Callable], list[Benchmark]]:
) -> Callable[[Callable], list[Benchmark]]:
"""
Define a family of benchmarks over a function with varying parameters.
Expand All @@ -116,7 +116,7 @@ def parametrize(
Parameters
----------
*parameters: dict[str, Any]
parameters: Iterable[dict[str, Any]]
The different sets of parameters defining the benchmark family.
setUp: Callable[..., None]
A setup hook to run before each of the benchmarks.
Expand All @@ -127,9 +127,8 @@ def parametrize(
Returns
-------
list[Benchmark] | Callable[[Callable], list[Benchmark]]
The resulting benchmark family (if no arguments were given), or a parametrized decorator
returning the benchmark family.
Callable[[Callable], list[Benchmark]]
A parametrized decorator returning the benchmark family.
"""

def decorator(fn: Callable) -> list[Benchmark]:
Expand All @@ -150,7 +149,7 @@ def product(
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
**iterables: Iterable,
) -> list[Benchmark] | Callable[[Callable], list[Benchmark]]:
) -> Callable[[Callable], list[Benchmark]]:
"""
Define a family of benchmarks over a cartesian product of one or more iterables.
Expand All @@ -171,9 +170,8 @@ def product(
Returns
-------
list[Benchmark] | Callable[[Callable], list[Benchmark]]
The resulting benchmark family (if no arguments were given), or a parametrized decorator
returning the benchmark family.
Callable[[Callable], list[Benchmark]]
A parametrized decorator returning the benchmark family.
"""

def decorator(fn: Callable) -> list[Benchmark]:
Expand Down

0 comments on commit d839f62

Please sign in to comment.