diff --git a/CHANGELOG.md b/CHANGELOG.md index 43fbf11a0..1cd8b507e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,9 +19,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `RPCholesky.reduce` in `coreax.solvers.coresubset` now computes the iteration step correctly. (https://github.com/gchq/coreax/pull/825) - `RPCholesky.reduce` in `coreax.solvers.coresubset` now does not produce duplicate points in the coreset. (https://github.com/gchq/coreax/pull/836) +- Fixed the example `examples.david_map_reduce_weighted` to prevent errors when + downsampling is enabled, and to make it run faster. (https://github.com/gchq/coreax/pull/821) - Build includes sub-packages. (https://github.com/gchq/coreax/pull/845) ### Changed - diff --git a/examples/david_map_reduce_weighted.py b/examples/david_map_reduce_weighted.py index 44cdbe151..f6893adba 100644 --- a/examples/david_map_reduce_weighted.py +++ b/examples/david_map_reduce_weighted.py @@ -44,7 +44,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np -from flax import linen from jax import random from coreax import ( @@ -104,32 +103,51 @@ def main( if out_path is not None and not out_path.is_absolute(): out_path = Path(__file__).parent.joinpath(out_path) + def downsample_opencv(image_path: str, downsampling_factor: int) -> np.ndarray: + """ + Downsample an image using :func:`cv2.resize` and convert it to grayscale. + + :param image_path: Path to the input image file. + :param downsampling_factor: Factor by which to downsample the image. + :return: Grayscale image after downsampling. + :raises OSError: If the image file cannot be read.
+ """ + img = cv2.imread(image_path) + if img is None:  # cv2.imread returns None instead of raising + raise OSError(f"Unable to read image file: {image_path}") + + # Calculate new dimensions; integer division avoids float rounding errors + width = max(1, img.shape[1] // downsampling_factor) + height = max(1, img.shape[0] // downsampling_factor) + dim = (width, height) + + # Resize using INTER_AREA for better downsampling + resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA) + + # Convert to grayscale after resizing + return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY) + # Path to original image - original_data = cv2.imread(str(in_path)) - image_data = np.asarray(cv2.cvtColor(original_data, cv2.COLOR_BGR2GRAY)) - # Pool/downsample the image - window_shape = (downsampling_factor, downsampling_factor) - pooled_image_data = linen.avg_pool( - image_data[..., None], window_shape, strides=window_shape - )[..., 0] - block_size = 1_000 // downsampling_factor - - print(f"Image dimensions: {pooled_image_data.shape}") - pre_coreset_data = np.column_stack(np.nonzero(pooled_image_data < MAX_8BIT)) - pixel_values = pooled_image_data[pooled_image_data < MAX_8BIT] + original_data = downsample_opencv(str(in_path), downsampling_factor) + + block_size = 1_000 // (downsampling_factor**2) + + print(f"Image dimensions: {original_data.shape}") + pre_coreset_data = np.column_stack(np.nonzero(original_data < MAX_8BIT)) + pixel_values = original_data[original_data < MAX_8BIT] pre_coreset_data = np.column_stack((pre_coreset_data, pixel_values)).astype( np.float32 ) num_data_points = pre_coreset_data.shape[0] # Request coreset points - coreset_size = 8_000 // downsampling_factor + coreset_size = 8_000 // (downsampling_factor**2) # Setup the original data object data = Data(pre_coreset_data) # Set the length_scale parameter of the kernel from at most 1000 samples - num_samples_length_scale = min(num_data_points, 1000 // downsampling_factor) + num_samples_length_scale = min(num_data_points, 1000 // (downsampling_factor**2)) random_seed = 1_989 generator = 
np.random.default_rng(random_seed) idx = generator.choice(num_data_points, num_samples_length_scale, replace=False) @@ -159,10 +177,10 @@ def main( herding_solver = KernelHerding( coreset_size, kernel=herding_kernel, - block_size=1_000 // downsampling_factor, + block_size=block_size, ) mapped_herding_solver = MapReduce( - herding_solver, leaf_size=10_000 // downsampling_factor + herding_solver, leaf_size=16_000 // (downsampling_factor**2) ) herding_coreset, _ = eqx.filter_jit(mapped_herding_solver.reduce)(data) herding_weights = weights_optimiser.solve(data, herding_coreset.coreset) @@ -192,11 +210,44 @@ def main( print(f"Random sampling coreset MMD: {random_mmd}") print(f"Herding coreset MMD: {herding_mmd}") + def transform_marker_size( + weights: np.ndarray, + scale_factor: int = 15, + min_size: int = 4 * downsampling_factor, + ) -> np.ndarray: + """ + Transform coreset weights to marker sizes for plotting. + + :param weights: Array of coreset weights to be transformed. + :param scale_factor: Ratio of the largest to the smallest marker size. + :param min_size: Smallest marker size. + :return: Array of transformed marker sizes for plotting. 
+ """ + # Define threshold percentiles + lower_percentile, upper_percentile = 1, 99 + + # Clip weights to reduce the effect of outliers + clipped_weights = np.clip( + weights, + np.percentile(weights, lower_percentile), + np.percentile(weights, upper_percentile), + ) + + # Normalize weights to [0, 1]; equal weights (zero range) all map to 0 + weight_min = clipped_weights.min() + weight_range = clipped_weights.max() - weight_min + normalized_weights = (clipped_weights - weight_min) / (weight_range or 1.0) + + # Apply exponential scaling to get the desired spread + transformed_sizes = min_size + (scale_factor**normalized_weights - 1) * min_size + + return transformed_sizes + print("Plotting") # Plot the pre-coreset image plt.figure(figsize=(10, 5)) plt.subplot(1, 3, 1) - plt.imshow(pooled_image_data, cmap="gray") + plt.imshow(original_data, cmap="gray") plt.title("Pre-Coreset") plt.axis("off") @@ -208,7 +259,7 @@ def main( -herding_coreset.coreset.data[:, 0], c=herding_coreset.coreset.data[:, 2], cmap="gray", - s=np.exp(2.0 * coreset_size * herding_weights).reshape(1, -1), + s=transform_marker_size(herding_weights).reshape(1, -1), marker="h", alpha=0.8, ) @@ -222,7 +273,7 @@ def main( random_coreset.coreset.data[:, 1], -random_coreset.coreset.data[:, 0], c=random_coreset.coreset.data[:, 2], - s=1.0, + s=25 * downsampling_factor, cmap="gray", marker="h", alpha=0.8, @@ -248,4 +299,4 @@ def main( if __name__ == "__main__": - main() + main(out_path=Path("data/david_coreset_2.png"))