added functionality for generation of test case for BatchParallelKmeans

Hoppe committed Dec 12, 2023
1 parent 8e8390b commit 2a2c8a9
Showing 2 changed files with 191 additions and 0 deletions.
107 changes: 107 additions & 0 deletions heat/utils/data/spherical.py
@@ -1,5 +1,6 @@
"""Create a sperical dataset."""
import heat as ht
import torch


def create_spherical_dataset(
@@ -50,3 +51,109 @@ def create_spherical_dataset(
    data = ht.concatenate((cluster1, cluster2, cluster3, cluster4), axis=0)
    # Note: enhance when shuffle is available
    return data


def create_clusters(
    n_samples, n_features, n_clusters, cluster_mean, cluster_std, cluster_weight=None, device=None
):
"""
Creates a DNDarray of shape (n_samples, n_features), split=0, and dtype=ht.float32, that is balanced (i.e. roughly same size of samples on each process).
The data set consists of n_clusters clusters, each of which is sampled from a multivariate normal distribution with mean cluster_mean[k,:] and covariance matrix cluster_std[k,:,:].
The clusters are of the same size (quantitatively) and distributed evenly over the processes, unless cluster_weight is specified.
Parameters
------------
n_samples: int
Number of overall samples
n_features: int
Number of features
n_clusters: int
Number of clusters
cluster_mean: torch.Tensor of shape (n_clusters, n_features)
featurewise mean (center) of each cluster; of course not the true mean, but rather the mean according to which the elements of the cluster are sampled.
cluster_std: torch.Tensor of shape (n_clusters, n_features, n_features), or (n_clusters,)
featurewise standard deviation of each cluster from the mean value; of course not the true std, but rather the std according to which the elements of the cluster are sampled.
If shape is (n_clusters,), std is assumed to be the same in each direction for each cluster
cluster_weight: torch.Tensor of shape (n_clusters,), optional
On each process, cluster_weight is assumed to be a torch.Tensor whose entries add up to 1. The i-th entry of cluster_weight on process p specified which amount of the samples on process p
is sampled according to the distribution of cluster i. Thus, this parameter allows to distribute the n_cluster clusters unevenly over the processes.
If None, each cluster is distributed evenly over all processes.
device: Optional[str] = None,
The device on which the data is stored. If None, the default device is used.
"""
    device = ht.devices.sanitize_device(device)

    if cluster_weight is None:
        cluster_weight = torch.ones(n_clusters) / n_clusters
    else:
        if not isinstance(cluster_weight, torch.Tensor):
            raise TypeError(
                "cluster_weight must be None or a torch.Tensor, but is {}".format(
                    type(cluster_weight)
                )
            )
        elif not cluster_weight.shape == (n_clusters,):
            raise ValueError(
                "If a torch.Tensor, cluster_weight must be of shape (n_clusters,), but is {}".format(
                    cluster_weight.shape
                )
            )
        elif not torch.allclose(torch.sum(cluster_weight), torch.tensor(1.0)):
            raise ValueError(
                "If a torch.Tensor, cluster_weight must add up to 1, but adds up to {}".format(
                    torch.sum(cluster_weight)
                )
            )
    if not isinstance(cluster_mean, torch.Tensor):
        raise TypeError("cluster_mean must be a torch.Tensor, but is {}".format(type(cluster_mean)))
    elif not cluster_mean.shape == (n_clusters, n_features):
        raise ValueError(
            "cluster_mean must be of shape (n_clusters, n_features), but is {}".format(
                cluster_mean.shape
            )
        )
    if not isinstance(cluster_std, torch.Tensor):
        raise TypeError("cluster_std must be a torch.Tensor, but is {}".format(type(cluster_std)))
    elif not cluster_std.shape == (
        n_clusters,
        n_features,
        n_features,
    ) and not cluster_std.shape == (n_clusters,):
        raise ValueError(
            "cluster_std must be of shape (n_clusters, n_features, n_features) or (n_clusters,), but is {}".format(
                cluster_std.shape
            )
        )
    if cluster_std.shape == (n_clusters,):
        # expand per-cluster standard deviations into isotropic covariance matrices (std ** 2 on the diagonal)
        cluster_std = torch.stack(
            [torch.eye(n_features) * cluster_std[k] ** 2 for k in range(n_clusters)], dim=0
        )

    # chunk the global shape along the split axis and determine how many local samples are drawn
    # from each cluster; any rounding remainder is assigned to the first cluster
    global_shape = (n_samples, n_features)
    local_shape = ht.MPI_WORLD.chunk(global_shape, 0)[1]
    local_size_of_clusters = [int(local_shape[0] * cluster_weight[k]) for k in range(n_clusters)]
    if sum(local_size_of_clusters) != local_shape[0]:
        local_size_of_clusters[0] += local_shape[0] - sum(local_size_of_clusters)
    # set up one multivariate normal distribution per cluster and draw the local samples
    distributions = [
        torch.distributions.multivariate_normal.MultivariateNormal(
            cluster_mean[k, :], cluster_std[k]
        )
        for k in range(n_clusters)
    ]
    local_data = [
        distributions[k].sample((local_size_of_clusters[k],)).to(device.torch_device)
        for k in range(n_clusters)
    ]
    local_data = torch.cat(local_data, dim=0)
    # shuffle locally so that samples from different clusters are interleaved within each process
    rand_perm = torch.randperm(local_shape[0])
    local_data = local_data[rand_perm, :]
    data = ht.DNDarray(
        local_data,
        global_shape,
        dtype=ht.float32,
        split=0,
        device=device,
        comm=ht.MPI_WORLD,
        balanced=True,
    )
    return data
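
For reference, a minimal usage sketch of the new function (not part of the commit; the means and standard deviations are illustrative and mirror the test below):

import torch
import heat as ht
from heat.utils.data.spherical import create_clusters

n_clusters, n_features, n_samples = 4, 3, 1000
# cluster k is centered at (k, k, ..., k)
cluster_mean = torch.arange(n_clusters, dtype=torch.float32).repeat(n_features, 1).T
# isotropic clusters: one standard deviation per cluster, shape (n_clusters,)
cluster_std = 0.1 * torch.ones(n_clusters)

data = create_clusters(n_samples, n_features, n_clusters, cluster_mean, cluster_std)
print(data.shape, data.split, data.dtype)  # (1000, 3), 0, ht.float32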
84 changes: 84 additions & 0 deletions heat/utils/data/tests/test_spherical.py
@@ -0,0 +1,84 @@
import heat as ht
import unittest
import torch
from heat.core.tests.test_suites.basic_test import TestCase


class TestCreateClusters(TestCase):
    def test_create_cluster(self):
        n_samples = ht.MPI_WORLD.size * 10 + 3
        n_features = 3
        n_clusters = ht.MPI_WORLD.size
        cluster_mean = torch.arange(n_clusters, dtype=torch.float32).repeat(n_features, 1).T

        # test case with uneven distribution of clusters over processes and standard deviations given as a vector
        cluster_weight = torch.zeros(n_clusters)
        cluster_weight[ht.MPI_WORLD.rank] += 0.5
        cluster_weight[0] += 0.5
        cluster_std = 0.01 * torch.ones(n_clusters)
        data = ht.utils.data.spherical.create_clusters(
            n_samples, n_features, n_clusters, cluster_mean, cluster_std, cluster_weight
        )
        self.assertEqual(data.shape, (n_samples, n_features))
        self.assertEqual(data.dtype, ht.float32)

        # test case with even distribution of clusters over processes and covariance given as a matrix
        cluster_weight = None
        cluster_std = 0.01 * torch.rand(n_clusters, n_features, n_features)
        cluster_std = torch.transpose(cluster_std, 1, 2) @ cluster_std
        data = ht.utils.data.spherical.create_clusters(
            n_samples, n_features, n_clusters, cluster_mean, cluster_std, cluster_weight
        )
        self.assertEqual(data.shape, (n_samples, n_features))
        self.assertEqual(data.dtype, ht.float32)

    def test_if_errors_are_caught(self):
        n_samples = ht.MPI_WORLD.size * 10 + 3
        n_features = 3
        n_clusters = ht.MPI_WORLD.size
        cluster_mean = torch.arange(n_clusters, dtype=torch.float32).repeat(n_features, 1).T
        cluster_std = 0.01 * torch.ones(n_clusters)

        # wrong type for cluster_mean
        with self.assertRaises(TypeError):
            ht.utils.data.spherical.create_clusters(
                n_samples, n_features, n_clusters, "abc", cluster_std
            )
        # wrong shape for cluster_mean
        with self.assertRaises(ValueError):
            ht.utils.data.spherical.create_clusters(
                n_samples, n_features, n_clusters, torch.zeros(2, 2), cluster_std
            )
        # wrong type for cluster_std
        with self.assertRaises(TypeError):
            ht.utils.data.spherical.create_clusters(
                n_samples, n_features, n_clusters, cluster_mean, "abc"
            )
        # wrong shape for cluster_std
        with self.assertRaises(ValueError):
            ht.utils.data.spherical.create_clusters(
                n_samples, n_features, n_clusters, cluster_mean, torch.zeros(2, 2)
            )
        # wrong type for cluster_weight
        with self.assertRaises(TypeError):
            ht.utils.data.spherical.create_clusters(
                n_samples, n_features, n_clusters, cluster_mean, cluster_std, "abc"
            )
        # wrong shape for cluster_weight
        with self.assertRaises(ValueError):
            ht.utils.data.spherical.create_clusters(
                n_samples,
                n_features,
                n_clusters,
                cluster_mean,
                cluster_std,
                torch.ones(n_clusters + 1),
            )
        # cluster_weight entries that do not add up to 1
        with self.assertRaises(ValueError):
            ht.utils.data.spherical.create_clusters(
                n_samples,
                n_features,
                n_clusters,
                cluster_mean,
                cluster_std,
                2 * torch.ones(n_clusters),
            )
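
These tests only exercise the distributed code paths when run on several MPI processes. Something along these lines should work (assuming a working MPI installation; the exact invocation used by Heat's CI may differ):

mpirun -n 4 python -m pytest heat/utils/data/tests/test_spherical.py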
