diff --git a/benchmark/blobs_benchmark.py b/benchmark/blobs_benchmark.py index e7ab76c6..68504d3a 100644 --- a/benchmark/blobs_benchmark.py +++ b/benchmark/blobs_benchmark.py @@ -39,7 +39,6 @@ from coreax import Data, SlicedScoreMatching from coreax.kernels import ( - ScalarValuedKernel, SquaredExponentialKernel, SteinKernel, median_heuristic, @@ -56,7 +55,7 @@ from coreax.weights import MMDWeightsOptimiser -def setup_kernel(x: jax.Array, random_seed: int = 45) -> ScalarValuedKernel: +def setup_kernel(x: jax.Array, random_seed: int = 45) -> SquaredExponentialKernel: """ Set up a squared exponential kernel using the median heuristic. @@ -73,7 +72,7 @@ def setup_kernel(x: jax.Array, random_seed: int = 45) -> ScalarValuedKernel: def setup_stein_kernel( - sq_exp_kernel: ScalarValuedKernel, dataset: Data, random_seed: int = 45 + sq_exp_kernel: SquaredExponentialKernel, dataset: Data, random_seed: int = 45 ) -> SteinKernel: """ Set up a Stein Kernel for Stein Thinning. @@ -99,7 +98,7 @@ def setup_stein_kernel( def setup_solvers( coreset_size: int, - sq_exp_kernel: ScalarValuedKernel, + sq_exp_kernel: SquaredExponentialKernel, stein_kernel: SteinKernel, delta: float, random_seed: int = 45, diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index 1c3af0cd..fc565395 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -52,7 +52,7 @@ import umap from flax import linen as nn from flax.training import train_state -from jaxtyping import Array, Float, Int +from jaxtyping import Array, Float from torch.utils.data import DataLoader, Dataset from torchvision import transforms @@ -430,7 +430,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax -def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]: +def calculate_delta(n: int) -> Float[Array, "1"]: """ Calculate the delta parameter for kernel thinning. @@ -451,7 +451,7 @@ def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]: if log_log_n > 0: return 1 / (n * log_log_n) return 1 / (n * log_n) - return 1 / n + return jnp.array(1 / n) def initialise_solvers( @@ -496,7 +496,7 @@ def _get_thinning_solver(_size: int) -> MapReduce: coreset_size=_size, kernel=kernel, random_key=key, - delta=calculate_delta(num_data_points), + delta=calculate_delta(num_data_points).item(), sqrt_kernel=sqrt_kernel, )