diff --git a/examples/train_cifar100_with_xbm.py b/examples/train_cifar100_with_xbm.py
index 74009081..344e859e 100644
--- a/examples/train_cifar100_with_xbm.py
+++ b/examples/train_cifar100_with_xbm.py
@@ -1,9 +1,7 @@
-"""TODO
-The XBM feature is still experimental and should not be considered as production ready.
-Currently, its usage of memory grows exponentially,
-and this limits its ability to be used with larger batch and/or buffer sizes.
-This will be fixed in an update soon.
-Take this example as a demonstration of the API for now.
+"""NOTE:
+This sample script is designed to fit in the memory of a 4GB GPU.
+You can increase the XBM buffer size to fill your GPU's memory
+and get the most out of the XBM feature.
 """
 
 import argparse
@@ -53,7 +51,7 @@ def get_dataloader():
     dataset = SimilarityGroupDataset(
         datasets.CIFAR100(root=path, download=True, transform=transform)
     )
-    dataloader = GroupSimilarityDataLoader(dataset, batch_size=128, shuffle=True)
+    dataloader = GroupSimilarityDataLoader(dataset, batch_size=64, shuffle=True)
     return dataloader
 
 
@@ -78,10 +76,9 @@ def forward(self, images):
 
 
 class Model(TrainableModel):
-    def __init__(self, embedding_size: int, lr: float, mining: str):
+    def __init__(self, embedding_size: int, lr: float):
         self._embedding_size = embedding_size
         self._lr = lr
-        self._mining = mining
         super().__init__()
 
     def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
@@ -94,7 +91,7 @@ def configure_loss(self) -> SimilarityLoss:
         return TripletLoss(mining="semi_hard")
 
     def configure_xbm(self) -> XbmConfig:
-        return XbmConfig(buffer_size=1024)
+        return XbmConfig(buffer_size=2048)
 
     def configure_metrics(self) -> Union[AttachedMetric, List[AttachedMetric]]:
         return AttachedMetric(
diff --git a/poetry.lock b/poetry.lock
index 5956ca82..d1cd5d56 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -824,7 +824,7 @@ resolved_reference = "a90cdd5925783c2b0ed3b8d39897cd4eaf942e2a"
 
 [[package]]
 name = "quaterion-models"
-version = "0.1.17"
+version = "0.1.19"
 description = "The collection of building blocks to build fine-tunable similarity learning models"
 category = "main"
 optional = false
diff --git a/pyproject.toml b/pyproject.toml
index 69e52f23..cc84698f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quaterion"
-version = "0.1.32"
+version = "0.1.33"
 description = "Similarity Learning fine-tuning framework"
 authors = ["Quaterion Authors "]
 packages = [
diff --git a/quaterion/loss/triplet_loss.py b/quaterion/loss/triplet_loss.py
index 0aaa709f..13b70319 100644
--- a/quaterion/loss/triplet_loss.py
+++ b/quaterion/loss/triplet_loss.py
@@ -138,73 +138,33 @@ def _semi_hard_triplet_loss(
         Returns:
             Tensor: zero-size tensor, XBM loss value.

""" - # compute the distance matrix - # shape: (batch_size_a, batch_size_b) - dists = self.distance_metric.distance_matrix(embeddings_a, embeddings_b) - # compute masks to express the positive and negative pairs + # Calculate the pairwise distances between all embeddings # shape: (batch_size_a, batch_size_b) - groups_a = groups_a.unsqueeze(1) - anchor_positive_pairs = groups_a == groups_b.unsqueeze(1).t() - anchor_negative_pairs = ~anchor_positive_pairs - - batch_size_a = torch.numel(groups_a) - batch_size_b = torch.numel(groups_b) - - # compute the mask to express the semi-hard-negatives - # WARNING: `torch.repeat()` copies the underlying data - # so it consumes more memory - dists_tile = dists.repeat([batch_size_b, 1]) - mask = anchor_negative_pairs.repeat([batch_size_b, 1]) & ( - dists_tile > torch.reshape(dists.t(), [-1, 1]) - ) + distances = self.distance_metric.distance_matrix(embeddings_a, embeddings_b) - mask_final = torch.reshape( - torch.sum(mask, 1, keepdims=True) > 0.0, [batch_size_b, batch_size_a] - ) - mask_final = mask_final.t() + # Find the indices of all positive and negative pairs + positive_indices = groups_a[:, None] == groups_b[None, :] + negative_indices = groups_a[:, None] != groups_b[None, :] - # negatives_outside: smallest D(a, n) where D(a, n) > D(a, p). - negatives_outside = torch.reshape( - get_masked_minimum(dists_tile, mask), [batch_size_b, batch_size_a] - ) - negatives_outside = negatives_outside.t() + # Calculate the distance between the anchor and positive examples + pos_distance = torch.masked_select(distances, positive_indices) - # negatives_inside: largest D(a, n). - negatives_inside = get_masked_maximum(dists, anchor_negative_pairs) - negatives_inside = negatives_inside.repeat([1, batch_size_b]) + # Calculate the distance between the anchor and negative examples + neg_distance = torch.masked_select(distances, negative_indices) - # select either semi-hard negative or the largest negative - # based on the condition the mask previously computed - semi_hard_negatives = torch.where( - mask_final, negatives_outside, negatives_inside - ) + # Calculate the basic triplet loss + basic_loss = pos_distance[:, None] - neg_distance[None, :] + self._margin - loss_matrix = (dists - semi_hard_negatives) + self._margin + # Zero out the loss for negative distances larger than the positive distance + zero_loss = torch.clamp(basic_loss, min=0.0) - # the paper takes all the positives accept the diagonal - # this is only relevant where it's running for the regular loss - mask_positives = ( - anchor_positive_pairs.float() - - torch.eye(batch_size_a, device=groups_a.device) - if torch.allclose(groups_a, groups_b) - else anchor_positive_pairs - ) + # Zero out the loss for distances larger than the margin + semi_hard_loss = torch.clamp(zero_loss, max=self._margin) - # average by the number of positives - num_positives = torch.sum(mask_positives) + loss = torch.mean(semi_hard_loss) - triplet_loss = ( - torch.sum( - torch.max( - loss_matrix * mask_positives, - torch.tensor([0.0], device=groups_a.device), - ) - ) - / num_positives - ) - - return triplet_loss + return loss def forward( self,