diff --git a/gluefactory/datasets/base_dataset.py b/gluefactory/datasets/base_dataset.py index ef622cbc..b3114c99 100644 --- a/gluefactory/datasets/base_dataset.py +++ b/gluefactory/datasets/base_dataset.py @@ -161,9 +161,12 @@ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False): except omegaconf.MissingMandatoryValue: batch_size = self.conf.batch_size num_workers = self.conf.get("num_workers", batch_size) + drop_last = True if split == "train" else False if distributed: shuffle = False - sampler = torch.utils.data.distributed.DistributedSampler(dataset) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, drop_last=drop_last + ) else: sampler = None if shuffle is None: @@ -178,7 +181,7 @@ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False): num_workers=num_workers, worker_init_fn=worker_init_fn, prefetch_factor=self.conf.prefetch_factor, - drop_last=True if split == "train" else False, + drop_last=drop_last, ) def get_overfit_loader(self, split): diff --git a/gluefactory/models/utils/losses.py b/gluefactory/models/utils/losses.py index cca17636..06c7958b 100644 --- a/gluefactory/models/utils/losses.py +++ b/gluefactory/models/utils/losses.py @@ -69,5 +69,5 @@ def nll_loss(self, log_assignment, data): weights[:, :m, :n] = positive weights[:, :m, -1] = neg0 - weights[:, -1, :m] = neg1 + weights[:, -1, :n] = neg1 return weights