Skip to content

Commit

Permalink
wip Byol update (#190)
Browse files Browse the repository at this point in the history
* lower weight decay

* lower weight decay

* normalization dim

* new loss

* new loss

* load checkpoint

* lower weight decay

* lower weight decay

* normalization dim

* new loss

* new loss

* load checkpoint

* load checkpoint

* load checkpoint
  • Loading branch information
annikabrundyn authored Sep 13, 2020
1 parent 0978172 commit 42f9fa6
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions pl_bolts/models/self_supervised/byol/byol_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class BYOL(pl.LightningModule):
def __init__(self,
num_classes,
learning_rate: float = 0.2,
weight_decay: float = 15e-6,
weight_decay: float = 1.5e-6,
input_height: int = 32,
batch_size: int = 32,
num_workers: int = 0,
Expand Down Expand Up @@ -104,26 +104,27 @@ def forward(self, x):
y, _, _ = self.online_network(x)
return y

def cosine_similarity(self, a, b):
    """Return the mean cosine similarity between paired vectors of ``a`` and ``b``.

    Each input is rescaled to unit L2 norm along its last dimension, the
    element-wise products are summed over that dimension to give one
    similarity per vector pair, and those similarities are averaged into a
    single scalar tensor.
    """
    unit_a, unit_b = (F.normalize(t, dim=-1) for t in (a, b))
    per_pair = (unit_a * unit_b).sum(dim=-1)
    return per_pair.mean()

def shared_step(self, batch, batch_idx):
(img_1, img_2), y = batch

# Image 1 to image 2 loss
y1, z1, h1 = self.online_network(img_1)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_2)
# L2 normalize
h1_norm = F.normalize(h1, p=2, dim=1)
z2_norm = F.normalize(z2, p=2, dim=1)
loss_a = F.mse_loss(h1_norm, z2_norm)
loss_a = - 2 * self.cosine_similarity(h1, z2)

# Image 2 to image 1 loss
y1, z1, h1 = self.online_network(img_2)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_1)
# L2 normalize
h1_norm = F.normalize(h1, p=2, dim=1)
z2_norm = F.normalize(z2, p=2, dim=1)
loss_b = F.mse_loss(h1_norm, z2_norm)
loss_b = - 2 * self.cosine_similarity(h1, z2)

# Final loss
total_loss = loss_a + loss_b
Expand Down Expand Up @@ -173,7 +174,7 @@ def add_model_specific_args(parent_parser):
# optim
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=15e-6)
parser.add_argument('--weight_decay', type=float, default=1.5e-6)
parser.add_argument('--warmup_epochs', type=float, default=10)

# Model
Expand All @@ -186,6 +187,8 @@ def cli_main():
from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
from pl_bolts.models.self_supervised.simclr import simclr_transforms
from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
from pytorch_lightning import seed_everything
seed_everything(1234)

parser = ArgumentParser()

Expand Down Expand Up @@ -235,7 +238,8 @@ def to_device(batch, device):
online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
online_eval.to_device = to_device

trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])

trainer.fit(model, dm)


Expand Down

0 comments on commit 42f9fa6

Please sign in to comment.