feat: balanced clustering #2999

Draft · wants to merge 1 commit into base: main
15 changes: 13 additions & 2 deletions python/python/lance/cuvs/kmeans.py
@@ -56,6 +56,8 @@ def __init__(
seed: Optional[int] = None,
device: Optional[str] = None,
itopk_size: int = 10,
balance_factor: Optional[float] = None,
cluster_counts: Optional[torch.Tensor] = None,
):
if metric == "dot":
raise ValueError(
@@ -70,6 +72,8 @@
centroids=centroids,
seed=seed,
device=device,
balance_factor=balance_factor,
cluster_counts=cluster_counts,
)

if self.device.type != "cuda" or not torch.cuda.is_available():
@@ -95,9 +99,13 @@ def fit(
logging.info("Total rebuild time: %s", self.time_rebuild)

def rebuild_index(self):
centroids = self.centroids
if self.balance_factor is not None:
self.pad_centroids()

rebuild_time_start = time.time()
cagra_metric = "sqeuclidean"
dim = self.centroids.shape[1]
dim = centroids.shape[1]
graph_degree = max(dim // 4, 32)
nn_descent_degree = graph_degree * 2
index_params = cagra.IndexParams(
@@ -107,7 +115,7 @@
build_algo="nn_descent",
compression=None,
)
self.index = cagra.build(index_params, self.centroids)
self.index = cagra.build(index_params, centroids)
rebuild_time_end = time.time()
self.time_rebuild += rebuild_time_end - rebuild_time_start

@@ -121,6 +129,9 @@ def _transform(
if self.metric == "cosine":
data = torch.nn.functional.normalize(data)

if self.padded_centroids is not None:
data = self.pad_data(data)

search_time_start = time.time()
device = torch.device("cuda")
out_idx = raft_common.device_ndarray.empty((data.shape[0], 1), dtype="uint32")
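With balancing enabled, this path pads both the centroid matrix (to shape (k, d+1)) and the queries, so a plain sqeuclidean top-1 search over the padded vectors is exactly a balance-penalized assignment. A minimal sketch of that equivalence in plain torch (not the cuVS/CAGRA API; all shapes and values below are placeholders):

```python
import torch

# Conceptual sketch: searching padded centroids under ordinary squared L2
# is the same as searching under "squared L2 + balance_factor * cluster size".
k, d, bf = 8, 16, 1.0
centroids = torch.randn(k, d)
counts = torch.randint(1, 100, (k,)).float()  # current cluster sizes
padded_centroids = torch.cat(
    [centroids, torch.sqrt(bf * counts).unsqueeze(1)], dim=1
)

queries = torch.randn(32, d)
padded_queries = torch.cat([queries, torch.zeros(32, 1)], dim=1)

padded_d2 = torch.cdist(padded_queries, padded_centroids) ** 2
plain_d2 = torch.cdist(queries, centroids) ** 2 + bf * counts
assert torch.allclose(padded_d2, plain_d2, atol=1e-4)  # same search results
```
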
6 changes: 6 additions & 0 deletions python/python/lance/dataset.py
@@ -1472,6 +1472,7 @@ def create_index(
storage_options: Optional[Dict[str, str]] = None,
filter_nan: bool = True,
one_pass_ivfpq: bool = False,
balance_factor: Optional[float] = None,
**kwargs,
) -> LanceDataset:
"""Create index on column.
@@ -1534,6 +1535,9 @@
for nullable columns. Obtains a small speed boost.
one_pass_ivfpq: bool
Defaults to False. If enabled, index type must be "IVF_PQ". Reduces disk IO.
balance_factor: float, optional
Penalty factor used to balance cluster sizes during IVF training. Each
candidate centroid's distance is penalized in proportion to its current
cluster size (by this factor times the size), discouraging assignments to
clusters that are already large. No balancing by default; 1 is often a
good value. Only applied when an accelerator is used.
kwargs :
Parameters passed to the index building process.

@@ -1683,6 +1687,7 @@
num_sub_vectors=num_sub_vectors,
batch_size=20480,
filter_nan=filter_nan,
balance_factor=balance_factor,
)
)
timers["ivf+pq_train:end"] = time.time()
@@ -1783,6 +1788,7 @@
metric,
accelerator,
filter_nan=filter_nan,
balance_factor=balance_factor,
)
timers["ivf_train:end"] = time.time()
ivf_train_time = timers["ivf_train:end"] - timers["ivf_train:start"]
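A usage sketch for the new parameter (the dataset path, column name, and index sizes below are placeholders, not taken from this PR):

```python
import lance

ds = lance.dataset("/tmp/vectors.lance")  # placeholder dataset
ds = ds.create_index(
    "vector",                # placeholder vector column
    index_type="IVF_PQ",
    num_partitions=256,
    num_sub_vectors=16,
    accelerator="cuda",      # balancing only takes effect with an accelerator
    balance_factor=1.0,      # None (the default) disables balancing
)
```
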
77 changes: 63 additions & 14 deletions python/python/lance/torch/kmeans.py
@@ -60,6 +60,8 @@ def __init__(
centroids: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
device: Optional[str] = None,
balance_factor: Optional[float] = None,
cluster_counts: Optional[torch.Tensor] = None,
):
self.k = k
self.max_iters = max_iters
@@ -82,6 +84,13 @@
self.device = preferred_device(device)
self.tolerance = tolerance
self.seed = seed
self.balance_factor = balance_factor
self.padded_centroids = None

if cluster_counts is None:
self.counts = torch.zeros(k, device=self.device)
else:
self.counts = cluster_counts

self.y2 = None

@@ -169,14 +178,12 @@
logging.debug("Total distance: %s, iter: %s", self.total_distance, i)
logging.info("Finish KMean training in %s", time.time() - start)

def _updated_centroids(
self, centroids: torch.Tensor, counts: torch.Tensor
) -> torch.Tensor:
centroids = centroids / counts[:, None]
zero_counts = counts == 0
def _updated_centroids(self, centroids: torch.Tensor) -> torch.Tensor:
centroids = centroids / self.counts[:, None]
zero_counts = self.counts == 0
for idx in zero_counts.nonzero(as_tuple=False):
# split the largest cluster and remove empty cluster
max_idx = torch.argmax(counts).item()
max_idx = torch.argmax(self.counts).item()
# add 1% gaussian noise to the largest centroid
# do this twice so we effectively split the largest cluster into 2
# rand_like returns on [0, 1) so we need to shift it to [-0.5, 0.5)
@@ -229,9 +236,9 @@ def _fit_once(
new_centroids = torch.zeros_like(
self.centroids, device=self.device, dtype=torch.float32
)
counts_per_part = torch.zeros(self.centroids.shape[0], device=self.device)
ones = torch.ones(1024 * 16, device=self.device)
self.rebuild_index()
self.counts = torch.zeros(self.k, device=self.device)
ones = torch.ones(1024 * 16, device=self.device)
for idx, chunk in enumerate(data):
if idx % 50 == 0:
logging.info("Kmeans::train: epoch %s, chunk %s", epoch, idx)
@@ -253,7 +260,7 @@
ones = torch.ones(len(ids), out=ones, device=self.device)

new_centroids.index_add_(0, ids, chunk.type(torch.float32))
counts_per_part.index_add_(0, ids, ones[: ids.shape[0]])
self.counts.index_add_(0, ids, ones[: ids.shape[0]])
del ids
del dists
del chunk
@@ -274,13 +281,50 @@ def _fit_once(
raise StopIteration("kmeans: converged")

# cast to the type we get the data in
self.centroids = self._updated_centroids(new_centroids, counts_per_part).type(
dtype
)
self.centroids = self._updated_centroids(new_centroids).type(dtype)
return total_dist

def pad_centroids(self):
if self.metric == "dot":
self.padded_centroids = torch.cat(
[
self.centroids,
-(self.balance_factor * self.counts).unsqueeze(1),
],
dim=1,
)
else:
self.padded_centroids = torch.cat(
[
self.centroids,
torch.sqrt(self.balance_factor * self.counts).unsqueeze(1),
],
dim=1,
)
self.y2 = (self.padded_centroids * self.padded_centroids).sum(dim=1)

def rebuild_index(self):
self.y2 = (self.centroids * self.centroids).sum(dim=1)
if self.balance_factor is not None:
self.pad_centroids()

def pad_data(self, data):
if self.metric == "dot":
return torch.cat(
[
data,
torch.ones(data.size(0), 1, device=data.device, dtype=data.dtype),
],
dim=1,
)
else:
return torch.cat(
[
data,
torch.zeros(data.size(0), 1, device=data.device, dtype=data.dtype),
],
dim=1,
)

def _transform(
self,
@@ -290,10 +334,15 @@
if self.metric == "cosine":
data = torch.nn.functional.normalize(data)

centroids = self.centroids
if self.padded_centroids is not None:
centroids = self.padded_centroids
data = self.pad_data(data)

if self.metric in ["l2", "cosine"]:
return self.dist_func(data, self.centroids, y2=y2)
return self.dist_func(data, centroids, y2=y2)
else:
return self.dist_func(data, self.centroids)
return self.dist_func(data, centroids)

def transform(
self, data: Union[pa.Array, np.ndarray, torch.Tensor]
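The padding scheme works because a single extra coordinate can carry the per-cluster penalty through the existing distance kernels unchanged. A minimal sketch of the two identities behind `pad_centroids`/`pad_data` (values are made up):

```python
import torch

bf = 1.0                               # balance_factor
x, c = torch.randn(8), torch.randn(8)  # a data point and a centroid
n_c = torch.tensor(500.0)              # current size of c's cluster

# L2/cosine path: centroid gets sqrt(bf * n_c), data gets 0, so
# ||x' - c'||^2 == ||x - c||^2 + bf * n_c  (large clusters look farther away).
x_l2 = torch.cat([x, torch.zeros(1)])
c_l2 = torch.cat([c, torch.sqrt(bf * n_c).view(1)])
assert torch.allclose(((x_l2 - c_l2) ** 2).sum(), ((x - c) ** 2).sum() + bf * n_c)

# Dot path: centroid gets -(bf * n_c), data gets 1, so
# <x', c'> == <x, c> - bf * n_c  (large clusters look less similar).
x_dot = torch.cat([x, torch.ones(1)])
c_dot = torch.cat([c, -(bf * n_c).view(1)])
assert torch.allclose(x_dot @ c_dot, x @ c - bf * n_c)
```
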
11 changes: 11 additions & 0 deletions python/python/lance/vector.py
@@ -210,6 +210,7 @@ def train_ivf_centroids_on_accelerator(
sample_rate: int = 256,
max_iters: int = 50,
filter_nan: bool = True,
balance_factor: Optional[float] = None,
) -> (np.ndarray, Any):
"""Use accelerator (GPU or MPS) to train kmeans."""

@@ -277,6 +278,7 @@
metric=metric_type,
device="cuda",
centroids=init_centroids,
balance_factor=balance_factor,
)
else:
logging.info("Training IVF partitions using GPU(%s)", accelerator)
@@ -286,15 +288,22 @@
metric=metric_type,
device=accelerator,
centroids=init_centroids,
balance_factor=balance_factor,
)
kmeans.fit(ds)

centroids = kmeans.centroids.cpu().numpy()
counts = kmeans.counts.cpu().numpy()

with tempfile.NamedTemporaryFile(delete=False) as f:
np.save(f, centroids)
logging.info("Saved centroids to %s", f.name)

if balance_factor is not None:
with tempfile.NamedTemporaryFile(delete=False) as f:
np.save(f, counts)
logging.info("Saved cluster counts to %s", f.name)

return centroids, kmeans


@@ -598,6 +607,7 @@ def one_pass_train_ivf_pq_on_accelerator(
sample_rate: int = 256,
max_iters: int = 50,
filter_nan: bool = True,
balance_factor: Optional[float] = None,
):
centroids, kmeans = train_ivf_centroids_on_accelerator(
dataset,
@@ -609,6 +619,7 @@
sample_rate=sample_rate,
max_iters=max_iters,
filter_nan=filter_nan,
balance_factor=balance_factor,
)
dataset_residuals = compute_partitions(
dataset,
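A sketch of calling the accelerated trainer with balancing enabled (the dataset, column name, and k are placeholders, and the positional argument order is assumed from the surrounding code, so treat it as illustrative):

```python
from lance.vector import train_ivf_centroids_on_accelerator

centroids, kmeans = train_ivf_centroids_on_accelerator(
    ds,            # a lance dataset (placeholder)
    "vector",      # placeholder column name
    256,           # k: number of IVF partitions
    "l2",          # metric_type
    "cuda",        # accelerator
    balance_factor=1.0,
)
# With balance_factor set, the per-cluster sizes survive training on the
# kmeans object and are also np.save()-ed to a temp file, like the centroids.
counts = kmeans.counts.cpu().numpy()
```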