Optimize memory consumption of semi-hard mining (#186)
* Optimize memory consumption of semi-hard mining

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bump version

* Update dependencies

* Update dependencies

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
monatis and pre-commit-ci[bot] authored Dec 8, 2022
1 parent 60ea692 commit 9537389
Showing 4 changed files with 26 additions and 69 deletions.
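
The core of the change below swaps the `torch.repeat`-based tiling of the distance matrix for masked selection plus broadcasting (see quaterion/loss/triplet_loss.py). As a rough, self-contained sketch of the difference, not part of the commit, with made-up sizes, labels, and margin:

import torch

# Illustrative sizes only: a small batch mined against a larger set of embeddings.
batch_a, batch_b = 8, 32
dists = torch.rand(batch_a, batch_b)

# Previous strategy: tile the full distance matrix batch_b times.
# `repeat` copies the underlying data, so it allocates batch_b * batch_a * batch_b elements.
dists_tile = dists.repeat([batch_b, 1])
print(dists_tile.shape)  # torch.Size([256, 32])

# New strategy: select positive and negative pair distances once and
# combine them by broadcasting instead of tiling the whole matrix.
groups_a = torch.randint(0, 4, (batch_a,))
groups_b = torch.randint(0, 4, (batch_b,))
positive_indices = groups_a[:, None] == groups_b[None, :]
pos_distance = torch.masked_select(dists, positive_indices)
neg_distance = torch.masked_select(dists, ~positive_indices)
basic_loss = pos_distance[:, None] - neg_distance[None, :] + 0.5  # 0.5 is an illustrative margin
print(basic_loss.shape)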
17 changes: 7 additions & 10 deletions examples/train_cifar100_with_xbm.py
@@ -1,9 +1,7 @@
"""TODO
The XBM feature is still experimental and should not be considered as production ready.
Currently, its usage of memory grows exponentially,
and this limits its ability to be used with larger batch and/or buffer sizes.
This will be fixed in an update soon.
Take this example as a demonstration of the API for now.
"""NOTE:
This sample script is design to fit in the memory of a 4GB GPU.
You can increase the XBM buffer size that can fit your GPU's memory
to get most out of the XBM feature.
"""

import argparse
@@ -53,7 +51,7 @@ def get_dataloader():
     dataset = SimilarityGroupDataset(
         datasets.CIFAR100(root=path, download=True, transform=transform)
     )
-    dataloader = GroupSimilarityDataLoader(dataset, batch_size=128, shuffle=True)
+    dataloader = GroupSimilarityDataLoader(dataset, batch_size=64, shuffle=True)
     return dataloader


@@ -78,10 +76,9 @@ def forward(self, images):


 class Model(TrainableModel):
-    def __init__(self, embedding_size: int, lr: float, mining: str):
+    def __init__(self, embedding_size: int, lr: float):
         self._embedding_size = embedding_size
         self._lr = lr
-        self._mining = mining
         super().__init__()

     def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
@@ -94,7 +91,7 @@ def configure_loss(self) -> SimilarityLoss:
         return TripletLoss(mining="semi_hard")

     def configure_xbm(self) -> XbmConfig:
-        return XbmConfig(buffer_size=1024)
+        return XbmConfig(buffer_size=2048)

     def configure_metrics(self) -> Union[AttachedMetric, List[AttachedMetric]]:
         return AttachedMetric(
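
As the updated docstring above suggests, the XBM buffer can be enlarged on GPUs with more memory. A minimal sketch of such an override, reusing the Model class from this example; the import path and the 8192 value are assumptions for illustration, not taken from the commit:

from quaterion.train.xbm import XbmConfig  # import path assumed for illustration


class BigBufferModel(Model):  # `Model` is the TrainableModel subclass defined in the example above
    def configure_xbm(self) -> XbmConfig:
        # A larger buffer stores more embeddings and therefore uses more GPU memory.
        return XbmConfig(buffer_size=8192)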
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quaterion"
-version = "0.1.32"
+version = "0.1.33"
 description = "Similarity Learning fine-tuning framework"
 authors = ["Quaterion Authors <[email protected]>"]
 packages = [
74 changes: 17 additions & 57 deletions quaterion/loss/triplet_loss.py
@@ -138,73 +138,33 @@ def _semi_hard_triplet_loss(
         Returns:
             Tensor: zero-size tensor, XBM loss value.
         """
-        # compute the distance matrix
-        # shape: (batch_size_a, batch_size_b)
-        dists = self.distance_metric.distance_matrix(embeddings_a, embeddings_b)
-
-        # compute masks to express the positive and negative pairs
+        # Calculate the pairwise distances between all embeddings
         # shape: (batch_size_a, batch_size_b)
-        groups_a = groups_a.unsqueeze(1)
-        anchor_positive_pairs = groups_a == groups_b.unsqueeze(1).t()
-        anchor_negative_pairs = ~anchor_positive_pairs
-
-        batch_size_a = torch.numel(groups_a)
-        batch_size_b = torch.numel(groups_b)
-
-        # compute the mask to express the semi-hard negatives
-        # WARNING: `torch.repeat()` copies the underlying data
-        # so it consumes more memory
-        dists_tile = dists.repeat([batch_size_b, 1])
-        mask = anchor_negative_pairs.repeat([batch_size_b, 1]) & (
-            dists_tile > torch.reshape(dists.t(), [-1, 1])
-        )
+        distances = self.distance_metric.distance_matrix(embeddings_a, embeddings_b)

-        mask_final = torch.reshape(
-            torch.sum(mask, 1, keepdims=True) > 0.0, [batch_size_b, batch_size_a]
-        )
-        mask_final = mask_final.t()
+        # Find the indices of all positive and negative pairs
+        positive_indices = groups_a[:, None] == groups_b[None, :]
+        negative_indices = groups_a[:, None] != groups_b[None, :]

-        # negatives_outside: smallest D(a, n) where D(a, n) > D(a, p).
-        negatives_outside = torch.reshape(
-            get_masked_minimum(dists_tile, mask), [batch_size_b, batch_size_a]
-        )
-        negatives_outside = negatives_outside.t()
+        # Calculate the distance between the anchor and positive examples
+        pos_distance = torch.masked_select(distances, positive_indices)

-        # negatives_inside: largest D(a, n).
-        negatives_inside = get_masked_maximum(dists, anchor_negative_pairs)
-        negatives_inside = negatives_inside.repeat([1, batch_size_b])
+        # Calculate the distance between the anchor and negative examples
+        neg_distance = torch.masked_select(distances, negative_indices)

-        # select either the semi-hard negative or the largest negative
-        # based on the condition of the previously computed mask
-        semi_hard_negatives = torch.where(
-            mask_final, negatives_outside, negatives_inside
-        )
+        # Calculate the basic triplet loss
+        basic_loss = pos_distance[:, None] - neg_distance[None, :] + self._margin

-        loss_matrix = (dists - semi_hard_negatives) + self._margin
+        # Zero out the loss for negative distances larger than the positive distance
+        zero_loss = torch.clamp(basic_loss, min=0.0)

-        # the paper takes all the positives except the diagonal
-        # this is only relevant when it runs as the regular loss
-        mask_positives = (
-            anchor_positive_pairs.float()
-            - torch.eye(batch_size_a, device=groups_a.device)
-            if torch.allclose(groups_a, groups_b)
-            else anchor_positive_pairs
-        )
+        # Zero out the loss for distances larger than the margin
+        semi_hard_loss = torch.clamp(zero_loss, max=self._margin)

-        # average by the number of positives
-        num_positives = torch.sum(mask_positives)
+        loss = torch.mean(semi_hard_loss)

-        triplet_loss = (
-            torch.sum(
-                torch.max(
-                    loss_matrix * mask_positives,
-                    torch.tensor([0.0], device=groups_a.device),
-                )
-            )
-            / num_positives
-        )
-
-        return triplet_loss
+        return loss

     def forward(
         self,
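
For readers following the rewritten _semi_hard_triplet_loss, here is the same sequence of steps on toy tensors. Everything here is a stand-in for illustration: torch.cdist replaces self.distance_metric.distance_matrix, and the group labels and margin are made up:

import torch

margin = 0.5  # stands in for self._margin

# Toy inputs: 4 anchor embeddings mined against 6 buffer embeddings.
embeddings_a = torch.randn(4, 8)
embeddings_b = torch.randn(6, 8)
groups_a = torch.tensor([0, 0, 1, 2])
groups_b = torch.tensor([0, 1, 1, 2, 2, 3])

# Pairwise distances; Euclidean here, while the loss uses its configured distance metric.
distances = torch.cdist(embeddings_a, embeddings_b)

# Boolean masks of positive (same group) and negative (different group) pairs.
positive_indices = groups_a[:, None] == groups_b[None, :]
negative_indices = groups_a[:, None] != groups_b[None, :]

# Flatten the distances of positive and negative pairs.
pos_distance = torch.masked_select(distances, positive_indices)
neg_distance = torch.masked_select(distances, negative_indices)

# Broadcasted triplet loss over every (positive, negative) combination.
basic_loss = pos_distance[:, None] - neg_distance[None, :] + margin
zero_loss = torch.clamp(basic_loss, min=0.0)         # drop easy triplets
semi_hard_loss = torch.clamp(zero_loss, max=margin)  # cap hard triplets at the margin

loss = torch.mean(semi_hard_loss)
print(loss)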
