Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 3, 2023
1 parent d77fbe6 commit f6814a1
Show file tree
Hide file tree
Showing 7 changed files with 0 additions and 9 deletions.
1 change: 0 additions & 1 deletion examples/cars/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def train(
shuffle: bool,
save_dir: str,
):

model = Model(
lr=lr,
mining=mining,
Expand Down
1 change: 0 additions & 1 deletion quaterion/dataset/similarity_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class SimilarityDataLoader(DataLoader, Generic[T_co]):
"""

def __init__(self, dataset: Dataset, **kwargs):

if "collate_fn" not in kwargs:
kwargs["collate_fn"] = self.__class__.pre_collate_fn
self._original_dataset = dataset
Expand Down
2 changes: 0 additions & 2 deletions quaterion/loss/cos_face_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
margin: Optional[float] = 0.35,
scale: Optional[float] = 64.0,
):

super(GroupLoss, self).__init__()

self.kernel = nn.Parameter(torch.FloatTensor(embedding_size, num_groups))
Expand All @@ -36,7 +35,6 @@ def __init__(
self.margin = margin

def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor:

"""Compute loss value
Args:
embeddings: shape: (batch_size, vector_length) - Output embeddings from the
Expand Down
2 changes: 0 additions & 2 deletions quaterion/loss/online_contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def forward(
)

if self._mining == "all":

num_positive_pairs = anchor_positive_mask.sum()
positive_loss = anchor_positive_dists.sum() / torch.max(
num_positive_pairs, torch.tensor(1e-16)
Expand All @@ -106,7 +105,6 @@ def forward(
).sum() / torch.max(num_negative_pairs, torch.tensor(1e-16))

else: # batch-hard pair mining

# get the hardest positive for each anchor
# shape: (batch_size,)
hardest_positive_dists = anchor_positive_dists.max(dim=1)[0]
Expand Down
1 change: 0 additions & 1 deletion quaterion/train/cache/cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(
self,
encoders: Dict[str, CacheEncoder],
):

super().__init__()
self.encoders = encoders
for key, encoder in self.encoders.items():
Expand Down
1 change: 0 additions & 1 deletion quaterion/train/trainable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,6 @@ def _maybe_compute_xbm_loss(
):
loss_obj = self.loss # Assign to tmp variable for better type inference
if isinstance(loss_obj, GroupLoss):

memory_embeddings, memory_groups = self._xbm_buffer.get()
memory_loss = loss_obj.xbm_loss(
embeddings, targets["groups"], memory_embeddings, memory_groups
Expand Down
1 change: 0 additions & 1 deletion tests/test_import.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
def test_import_main_classes():

from quaterion import Quaterion, TrainableModel

0 comments on commit f6814a1

Please sign in to comment.