Skip to content

Commit

Permalink
Merge pull request #821 from gchq/bugfix/David-MarkerSize-Bug
Browse files Browse the repository at this point in the history
:fix: correct plot scaling calculation and leaf parameters
  • Loading branch information
rg936672 authored Nov 7, 2024
2 parents 567cf7d + 4463e60 commit 272a30c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 22 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `RPCholesky.reduce` in `coreax.solvers.coresubset` now computes the iteration step
correctly. (https://github.com/gchq/coreax/pull/825)
- `RPCholesky.reduce` in `coreax.solvers.coresubset` now does not produce duplicate
points in the coreset. (https://github.com/gchq/coreax/pull/836)
points in the coreset. (https://github.com/gchq/coreax/pull/836)
- Fixed the example `examples.david_map_reduce_weighted` to prevent errors when
downsampling is enabled, and to make it run faster. (https://github.com/gchq/coreax/pull/821)
- Build includes sub-packages. (https://github.com/gchq/coreax/pull/845)


### Changed

-
Expand Down
93 changes: 72 additions & 21 deletions examples/david_map_reduce_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from flax import linen
from jax import random

from coreax import (
Expand Down Expand Up @@ -104,32 +103,51 @@ def main(
if out_path is not None and not out_path.is_absolute():
out_path = Path(__file__).parent.joinpath(out_path)

def downsample_opencv(image_path: str, downsampling_factor: int) -> np.ndarray:
    """
    Downsample an image using :func:`cv2.resize` and convert it to grayscale.

    :param image_path: Path to the input image file.
    :param downsampling_factor: Factor by which to downsample the image; must be
        positive.
    :raises ValueError: If ``downsampling_factor`` is not positive.
    :raises OSError: If no image could be read from ``image_path``.
    :return: Grayscale image after downsampling.
    """
    if downsampling_factor <= 0:
        raise ValueError(
            f"downsampling_factor must be positive, got {downsampling_factor}"
        )

    # cv2.imread reports failure by returning None instead of raising, so check
    # explicitly to avoid an opaque AttributeError on ``img.shape`` below
    img = cv2.imread(image_path)
    if img is None:
        raise OSError(f"Could not read image from {image_path!r}")

    # Floor-divide rather than multiplying by ``1 / downsampling_factor``: the
    # reciprocal is inexact in floating point, so truncating the product can land
    # one pixel short (e.g. ``int(98 * (1 / 49)) == 1``). Clamp to one pixel so
    # that cv2.resize never receives an empty target size.
    width = max(1, int(img.shape[1] // downsampling_factor))
    height = max(1, int(img.shape[0] // downsampling_factor))

    # Resize using INTER_AREA for better downsampling
    resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    # Convert to grayscale after resizing
    return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)

# Path to original image
original_data = cv2.imread(str(in_path))
image_data = np.asarray(cv2.cvtColor(original_data, cv2.COLOR_BGR2GRAY))
# Pool/downsample the image
window_shape = (downsampling_factor, downsampling_factor)
pooled_image_data = linen.avg_pool(
image_data[..., None], window_shape, strides=window_shape
)[..., 0]
block_size = 1_000 // downsampling_factor

print(f"Image dimensions: {pooled_image_data.shape}")
pre_coreset_data = np.column_stack(np.nonzero(pooled_image_data < MAX_8BIT))
pixel_values = pooled_image_data[pooled_image_data < MAX_8BIT]
original_data = downsample_opencv(str(in_path), downsampling_factor)

block_size = 1_000 // (downsampling_factor**2)

print(f"Image dimensions: {original_data.shape}")
pre_coreset_data = np.column_stack(np.nonzero(original_data < MAX_8BIT))
pixel_values = original_data[original_data < MAX_8BIT]
pre_coreset_data = np.column_stack((pre_coreset_data, pixel_values)).astype(
np.float32
)
num_data_points = pre_coreset_data.shape[0]

# Request coreset points
coreset_size = 8_000 // downsampling_factor
coreset_size = 8_000 // (downsampling_factor**2)

# Setup the original data object
data = Data(pre_coreset_data)

# Set the length_scale parameter of the kernel from at most 1000 samples
num_samples_length_scale = min(num_data_points, 1000 // downsampling_factor)
num_samples_length_scale = min(num_data_points, 1000 // (downsampling_factor**2))
random_seed = 1_989
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
Expand Down Expand Up @@ -159,10 +177,10 @@ def main(
herding_solver = KernelHerding(
coreset_size,
kernel=herding_kernel,
block_size=1_000 // downsampling_factor,
block_size=block_size,
)
mapped_herding_solver = MapReduce(
herding_solver, leaf_size=10_000 // downsampling_factor
herding_solver, leaf_size=16_000 // (downsampling_factor**2)
)
herding_coreset, _ = eqx.filter_jit(mapped_herding_solver.reduce)(data)
herding_weights = weights_optimiser.solve(data, herding_coreset.coreset)
Expand Down Expand Up @@ -192,11 +210,44 @@ def main(
print(f"Random sampling coreset MMD: {random_mmd}")
print(f"Herding coreset MMD: {herding_mmd}")

def transform_marker_size(
    weights: np.ndarray,
    scale_factor: int = 15,
    min_size: int = 4 * downsampling_factor,
) -> np.ndarray:
    """
    Transform coreset weights to marker sizes for plotting.

    The weights are clipped to their 1st-99th percentile range, rescaled to
    ``[0, 1]`` and then mapped exponentially so that the resulting sizes span
    ``min_size`` to ``scale_factor * min_size``.

    :param weights: Array of coreset weights to be transformed.
    :param scale_factor: Ratio of the largest and the smallest marker sizes.
    :param min_size: Smallest marker size.
    :return: Array of transformed marker sizes for plotting; every entry equals
        ``min_size`` when the clipped weights are all identical.
    """
    weights = np.asarray(weights)
    if weights.size == 0:
        # Percentiles are undefined for empty input; there is nothing to size
        return np.zeros(weights.shape)

    # Define threshold percentiles
    lower_percentile, upper_percentile = 1, 99

    # Clip weights to reduce the effect of outliers
    clipped_weights = np.clip(
        weights,
        np.percentile(weights, lower_percentile),
        np.percentile(weights, upper_percentile),
    )

    # Normalize weights to a [0, 1] range. When every clipped weight is equal the
    # range is zero, and dividing by it would give NaN sizes, so map all weights
    # to 0 (i.e. the smallest marker) instead.
    lowest = clipped_weights.min()
    weight_range = clipped_weights.max() - lowest
    if weight_range > 0:
        normalized_weights = (clipped_weights - lowest) / weight_range
    else:
        normalized_weights = np.zeros_like(clipped_weights)

    # Apply exponential scaling to get the desired spread
    transformed_sizes = min_size + (scale_factor**normalized_weights - 1) * min_size

    return transformed_sizes

print("Plotting")
# Plot the pre-coreset image
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.imshow(pooled_image_data, cmap="gray")
plt.imshow(original_data, cmap="gray")
plt.title("Pre-Coreset")
plt.axis("off")

Expand All @@ -208,7 +259,7 @@ def main(
-herding_coreset.coreset.data[:, 0],
c=herding_coreset.coreset.data[:, 2],
cmap="gray",
s=np.exp(2.0 * coreset_size * herding_weights).reshape(1, -1),
s=(transform_marker_size(herding_weights)).reshape(1, -1),
marker="h",
alpha=0.8,
)
Expand All @@ -222,7 +273,7 @@ def main(
random_coreset.coreset.data[:, 1],
-random_coreset.coreset.data[:, 0],
c=random_coreset.coreset.data[:, 2],
s=1.0,
s=25 * downsampling_factor,
cmap="gray",
marker="h",
alpha=0.8,
Expand All @@ -248,4 +299,4 @@ def main(


if __name__ == "__main__":
main()
main(out_path=Path("data/david_coreset_2.png"))

0 comments on commit 272a30c

Please sign in to comment.