Commit a0199db

cleaning + adding momentum_forward to be consistent

vturrisi authored Jan 8, 2024
1 parent 16ffb0e commit a0199db
Showing 1 changed file with 66 additions and 123 deletions.

189 changes: 66 additions & 123 deletions solo/methods/all4one.py
@@ -1,4 +1,4 @@
-# Copyright 2021 solo-learn development team.
+# Copyright 2024 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
@@ -17,35 +17,24 @@
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

-import argparse
-from audioop import cross
import pickle
from typing import Any, Dict, List, Sequence, Tuple


import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from solo.losses.nnclr import nnclr_loss_func
-from solo.methods.base import BaseMethod, BaseMomentumMethod
+from solo.methods.base import BaseMomentumMethod
from solo.utils.misc import gather, omegaconf_select
-import math
-from solo.utils.positional_encodings import PositionalEncodingPermute1D, PositionalEncoding2D, PositionalEncoding3D, Summer
from solo.utils.momentum import initialize_momentum_params
+from solo.utils.positional_encodings import PositionalEncodingPermute1D, Summer


-import pickle

class All4One(BaseMomentumMethod):
queue: torch.Tensor

-def __init__(
-self, cfg: omegaconf.DictConfig
-):
-
+def __init__(self, cfg: omegaconf.DictConfig):
super().__init__(cfg)

self.temperature: float = cfg.method_kwargs.temperature
@@ -67,7 +56,6 @@ def __init__(
nn.BatchNorm1d(proj_output_dim),
)


# momentum projector
self.momentum_projector = nn.Sequential(
nn.Linear(self.features_dim, proj_hidden_dim),
@@ -81,8 +69,6 @@ def __init__(
)
initialize_momentum_params(self.projector, self.momentum_projector)

# predictor
self.predictor = nn.Sequential(
nn.Linear(proj_output_dim, pred_hidden_dim),
@@ -100,22 +86,19 @@ def __init__(
)

# internal transformer
-encoder_layer = nn.TransformerEncoderLayer(d_model=proj_output_dim,
-nhead=8,
-dim_feedforward=proj_output_dim * 2,
-batch_first=True,
-dropout=0.1)
+encoder_layer = nn.TransformerEncoderLayer(
+d_model=proj_output_dim,
+nhead=8,
+dim_feedforward=proj_output_dim * 2,
+batch_first=True,
+dropout=0.1,
+)

self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

# positional encoder
self.pos_enc = Summer(PositionalEncodingPermute1D(5))
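
# [Editor's note, not part of the commit] Summer wraps a positional-encoding
# module and adds its output to its input (a sum, not a concatenation); here it
# decorates the stacks of 5 retrieved neighbors before they enter the
# transformer encoder.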
# queue
self.register_buffer("queue", torch.randn(self.queue_size, proj_output_dim))
self.register_buffer("queue_y", -torch.ones(self.queue_size, dtype=torch.long))
@@ -124,7 +107,6 @@ def __init__(
# NN index queue
self.register_buffer("queue_index", -torch.ones(self.queue_size, dtype=torch.long))


@staticmethod
def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
"""Adds method specific default values/checks for config.
@@ -159,9 +141,7 @@ def learnable_params(self) -> List[dict]:
{"params": self.projector.parameters()},
{"params": self.predictor.parameters()},
{"params": self.predictor2.parameters()},
{"params": self.transformer_encoder.parameters(), "lr": 0.1}


{"params": self.transformer_encoder.parameters(), "lr": 0.1},
]
return super().learnable_params + extra_learnable_params

@@ -191,7 +171,6 @@ def dequeue_and_enqueue(self, z: torch.Tensor, y: torch.Tensor, idx: torch.Tenso
y = gather(y)
idx = gather(idx)


batch_size = z.shape[0]

ptr = int(self.queue_ptr) # type: ignore
@@ -203,13 +182,10 @@ def dequeue_and_enqueue(self, z: torch.Tensor, y: torch.Tensor, idx: torch.Tenso
# NN indexes
self.queue_index[ptr : ptr + batch_size] = idx

ptr = (ptr + batch_size) % self.queue_size

self.queue_ptr[0] = ptr # type: ignore
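
# [Editor's note, not part of the commit] The pointer arithmetic above assumes
# the queue size is a multiple of the gathered batch size, e.g. with
# queue_size = 65536 and batch_size = 256: ptr = (65280 + 256) % 65536 = 0,
# so writes wrap around and overwrite the oldest entries first.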


@torch.no_grad()
def find_nn(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Finds the nearest neighbors of a sample.
@@ -230,20 +206,31 @@ def find_nn(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

idxx = (z @ self.queue.T).max(dim=1)[1]


_, idx = (z @ self.queue.T).topk(5, dim=1)

nn = self.queue[idx]
nn_idx = self.queue_index[idx]
nn_lb = self.queue_y[idx]

return idxx, nn, nn_idx, nn_lb
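
# [Editor's note, not part of the commit] Shapes, with batch size B and
# projector output dim D: z is (B, D) and the queue is (queue_size, D);
# idxx is (B,) (the single nearest neighbor), nn is (B, 5, D) (top-5 neighbor
# embeddings), and nn_idx / nn_lb are (B, 5) (their dataset indices and labels).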

+@torch.no_grad()
+def momentum_forward(self, X: torch.Tensor) -> Dict:
+"""Performs the forward pass of the momentum backbone and projector.
+Args:
+X (torch.Tensor): batch of images in tensor format.
+Returns:
+Dict[str, Any]: a dict containing the outputs of
+the parent and the momentum projected features.
+"""
+
+out = super().momentum_forward(X)
+z = self.momentum_projector(out["feats"])
+out.update({"z": z})
+return out
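
# [Editor's note, not part of the commit] Assuming solo-learn's
# BaseMomentumMethod conventions, the base training_step runs momentum_forward
# on each large crop and prefixes its output keys with "momentum_", so the "z"
# added here is consumed as out["momentum_z"] in training_step below.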

def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
"""Performs forward pass of the online backbone, projector and predictor.
@@ -261,7 +248,6 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
p = self.predictor(z)
return {**out, "z": z, "p": p}


def off_diagonal(self, x):
"""Extracts off-diagonal elements.
@@ -271,12 +257,11 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
Returns:
torch.Tensor:
flattened off-diagonal elements.
"""
"""
n, m = x.shape
assert n == m
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
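
# [Editor's note, not part of the commit] A quick worked example of the
# flatten trick above: for x = [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
# x.flatten()[:-1] -> [0..7]; .view(2, 4) -> [[0, 1, 2, 3], [4, 5, 6, 7]];
# [:, 1:] -> [[1, 2, 3], [5, 6, 7]]; so the method returns [1, 2, 3, 5, 6, 7],
# exactly the off-diagonal entries.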


def save_NN(self, img_indexes, nn1_idx, nn1_lb):
"""Auxiliar function to store the NNs.
@@ -285,18 +270,18 @@ def save_NN(self, img_indexes, nn1_idx, nn1_lb):
nn1_idx (torch.Tensor): batch of NN indexes in tensor format.
nn1_lb (torch.Tensor): batch of NN labels in tensor format.
"""

with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__NNS.pickle", 'wb') as f:
pickle.dump(nn1_idx.cpu().numpy(),f)

with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__IDX.pickle", 'wb') as f:
pickle.dump(img_indexes.cpu().numpy(),f)
"""

with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__Labels.pickle", 'wb') as f:
pickle.dump(nn1_lb.cpu().numpy(),f)
with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__NNS.pickle", "wb") as f:
pickle.dump(nn1_idx.cpu().numpy(), f)

with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__IDX.pickle", "wb") as f:
pickle.dump(img_indexes.cpu().numpy(), f)

with open(
f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__Labels.pickle", "wb"
) as f:
pickle.dump(nn1_lb.cpu().numpy(), f)

def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
"""Training step for All4One reusing BaseMomentumMethod training step.
@@ -313,11 +298,10 @@ def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
targets = batch[-1]
img_indexes = batch[0]


out = super().training_step(batch, batch_idx)
class_loss = out["loss"]
feats1, feats2 = out["feats"]
-momentum_feats1, momentum_feats2 = out["momentum_feats"]
+momentum_z1, momentum_z2 = out["momentum_z"]

z1 = self.projector(feats1)
z2 = self.projector(feats2)
@@ -328,18 +312,9 @@ def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
p1_2 = self.predictor2(z1)
p2_2 = self.predictor2(z2)

-# forward momentum backbone
-with torch.no_grad():
-z1_momentum = self.momentum_projector(momentum_feats1)
-z2_momentum = self.momentum_projector(momentum_feats2)
-
-z1_momentum = F.normalize(z1_momentum, dim=-1)
-z2_momentum = F.normalize(z2_momentum, dim=-1)

# find nn
-idx1, nn1, nn1_idx, nn1_lb = self.find_nn(z1_momentum)
-_, nn2, _, _ = self.find_nn(z2_momentum)
+idx1, nn1, *_ = self.find_nn(momentum_z1)
+_, nn2, _, _ = self.find_nn(momentum_z2)

trans_emb1 = self.pos_enc(nn1)
trans_emb2 = self.pos_enc(nn2)
@@ -348,101 +323,69 @@ def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
strange1 = self.pos_enc(torch.cat((p1_2.unsqueeze(1), nn1), 1)[:, :5, :])
strange2 = self.pos_enc(torch.cat((p2_2.unsqueeze(1), nn2), 1)[:, :5, :])

# Feature dimension task
-p1_norm_feat = torch.nn.functional.normalize(z1_momentum, dim=0)
-p2_norm_feat = torch.nn.functional.normalize(z2_momentum, dim=0)
+p1_norm_feat = torch.nn.functional.normalize(momentum_z1, dim=0)
+p2_norm_feat = torch.nn.functional.normalize(momentum_z2, dim=0)
z1_norm_feat = torch.nn.functional.normalize(z1, dim=0)
z2_norm_feat = torch.nn.functional.normalize(z2, dim=0)


corr_matrix_1_feat = p1_norm_feat.T @ z2_norm_feat
corr_matrix_2_feat = p2_norm_feat.T @ z1_norm_feat

-on_diag_feat = ((torch.diagonal(corr_matrix_1_feat).add(-1).pow(2).mean() + torch.diagonal(corr_matrix_2_feat).add(-1).pow(
-2).mean()) * 0.5).sqrt()
-off_diag_feat = ((self.off_diagonal(corr_matrix_1_feat).pow(2).mean() + self.off_diagonal(corr_matrix_2_feat).pow(
-2).mean()) * 0.5).sqrt()
+on_diag_feat = (
+(
+torch.diagonal(corr_matrix_1_feat).add(-1).pow(2).mean()
++ torch.diagonal(corr_matrix_2_feat).add(-1).pow(2).mean()
+)
+* 0.5
+).sqrt()
+off_diag_feat = (
+(
+self.off_diagonal(corr_matrix_1_feat).pow(2).mean()
++ self.off_diagonal(corr_matrix_2_feat).pow(2).mean()
+)
+* 0.5
+).sqrt()
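
# [Editor's note, not part of the commit] In equation form, with C the D x D
# cross-correlation of momentum/online projections (columns L2-normalized over
# the batch):
# on_diag_feat = sqrt(mean over views of mean_i (C_ii - 1)^2)
# off_diag_feat = sqrt(mean over views of mean_{i != j} C_ij^2)
# i.e. a Barlow Twins-style redundancy-reduction term on the feature dimension.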

rich_emb1 = self.transformer_encoder(trans_emb1)[:, 0, :]
rich_emb2 = self.transformer_encoder(trans_emb2)[:, 0, :]

strange_emb1 = self.transformer_encoder(strange1)[:, 0, :]
strange_emb2 = self.transformer_encoder(strange2)[:, 0, :]

# ------- contrastive loss -------
att_nnclr_loss = (
nnclr_loss_func(rich_emb1, strange_emb2) / 2
+ nnclr_loss_func(rich_emb2, strange_emb1) / 2
)


nnclr_loss = (
-nnclr_loss_func(nn1[:,0,:], p2, temperature=self.temperature) / 2
-+ nnclr_loss_func(nn2[:,0,:], p1, temperature=self.temperature) / 2
+nnclr_loss_func(nn1[:, 0, :], p2, temperature=self.temperature) / 2
++ nnclr_loss_func(nn2[:, 0, :], p1, temperature=self.temperature) / 2
)
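
# [Editor's note, not part of the commit] nn{1,2}[:, 0, :] is each sample's
# single nearest neighbor, so nnclr_loss is the standard NNCLR InfoNCE term,
# symmetrized over the two views; att_nnclr_loss above plays the same role for
# the transformer-pooled neighborhood embeddings.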


-feature_loss = (
-(0.5 * on_diag_feat + 0.5 * off_diag_feat)*10
-)
+feature_loss = (0.5 * on_diag_feat + 0.5 * off_diag_feat) * 10

b = targets.size(0)

-final_losss = (0.5*att_nnclr_loss + 0.5*nnclr_loss + 0.5*feature_loss)
+final_losss = 0.5 * att_nnclr_loss + 0.5 * nnclr_loss + 0.5 * feature_loss

nn_acc = (targets == self.queue_y[idx1]).sum() / b

-self.dequeue_and_enqueue(z1_momentum, targets, img_indexes)
+self.dequeue_and_enqueue(momentum_z1, targets, img_indexes)

z1_std = F.normalize(z1, dim=-1).std(dim=0).mean()
z2_std = F.normalize(z2, dim=-1).std(dim=0).mean()
z_std = (z1_std + z2_std) / 2


# Uncomment to save the NNs for analysis

#if (self.current_epoch == 1) or (self.current_epoch % 10 == 0) or (self.current_epoch == 30) or (self.current_epoch == 50) or (self.current_epoch == 70) or (self.current_epoch == 90):
# self.save_NN(img_indexes, nn1_idx, nn1_lb)

metrics = {
"train_comb_loss": final_losss,
"train_nnclr_loss": nnclr_loss,
"train_att_nnclr_loss": att_nnclr_loss,
"train_feature_loss": feature_loss,
"train_nn_acc": nn_acc,



"train_z_std": z_std,
}
self.log_dict(metrics, on_epoch=True, sync_dist=True)

return final_losss + class_loss
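
# [Editor's note, not part of the commit] The optimized objective is therefore
# L = 0.5 * att_nnclr_loss + 0.5 * nnclr_loss + 0.5 * feature_loss + class_loss,
# where class_loss is the online linear classifier's cross-entropy, computed on
# detached features in the base training_step.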
