We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ac9c333 commit b5fd704Copy full SHA for b5fd704
pytorch_classification/vision_transformer/train.py
@@ -52,14 +52,14 @@ def main(args):
52
shuffle=True,
53
pin_memory=True,
54
num_workers=nw,
55
- collate_fn=train_data_set.collate_fn)
+ collate_fn=train_dataset.collate_fn)
56
57
val_loader = torch.utils.data.DataLoader(val_dataset,
58
batch_size=batch_size,
59
shuffle=False,
60
61
62
- collate_fn=val_data_set.collate_fn)
+ collate_fn=val_dataset.collate_fn)
63
64
model = create_model(num_classes=5, has_logits=False).to(device)
65
0 commit comments