111
111
" from data.cifar100 import Cifar100\n " ,
112
112
" from model.resnet_cifar import resnet32\n " ,
113
113
" from model.manager import Manager\n " ,
114
+ " from model.lwf import LWF\n " ,
114
115
" from model.icarl import Exemplars\n " ,
115
116
" from model.icarl import iCaRL\n " ,
116
117
" from utils import plot"
301
302
" criterion = nn.BCEWithLogitsLoss()\n " ,
302
303
" \n " ,
303
304
" for split_i in range(10):\n " ,
304
- " print(f\" # Split {split_i} of run {run_i}\" )\n " ,
305
+ " print(f\" ## Split {split_i} of run {run_i} ## \" )\n " ,
305
306
" \n " ,
306
307
" parameters_to_optimize = net.parameters()\n " ,
307
308
" optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)\n " ,
326
327
" test_accuracy, all_targets, all_preds = manager.test()\n " ,
327
328
" \n " ,
328
329
" logs[run_i][split_i]['test_accuracy'] = test_accuracy\n " ,
330
+ " logs[run_i][split_i]['conf_mat'] = confusion_matrix(all_targets.to('cpu'), all_preds.to('cpu'))\n " ,
329
331
" \n " ,
330
332
" # Add 10 nodes to last FC layer\n " ,
331
333
" manager.increment_classes(n=10)"
332
334
],
333
335
"execution_count" : null ,
334
336
"outputs" : []
335
337
},
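For reference, the head expansion performed by `increment_classes` (its implementation appears in the `LWF` class removed further down) replaces the final fully connected layer with a wider one and copies the learned weights back in. A minimal standalone sketch, assuming a ResNet that exposes its classifier as `net.fc`:

```python
import torch.nn as nn

def increment_classes(net, n=10):
    """Widen the final FC layer by n outputs, preserving learned weights."""
    in_features = net.fc.in_features
    out_features = net.fc.out_features
    weight = net.fc.weight.data

    net.fc = nn.Linear(in_features, out_features + n)
    net.fc.weight.data[:out_features] = weight  # old classes keep their weights
    # The bias for the old classes could be carried over the same way;
    # the original code copies only the weight matrix.
```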
336
- {
337
- "cell_type" : " code" ,
338
- "metadata" : {
339
- "colab_type" : " code" ,
340
- "id" : " JWdj0wvu996S" ,
341
- "colab" : {}
342
- },
343
- "source" : [
344
- " # Confusion matrix over last run test predictions\n " ,
345
- " targets = test_dataset.targets\n " ,
346
- " preds = all_preds.to('cpu').numpy()\n " ,
347
- " \n " ,
348
- " plot.heatmap_cm(targets, preds)"
349
- ],
350
- "execution_count" : null ,
351
- "outputs" : []
352
- },
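The removed cell relied on `plot.heatmap_cm`, a project helper whose source is not part of this diff. A plausible minimal implementation, assuming matplotlib and scikit-learn (an illustrative sketch, not the project's actual code):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def heatmap_cm(targets, preds):
    """Render the confusion matrix of targets vs. predictions as a heatmap."""
    cm = confusion_matrix(targets, preds)
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')
    fig.colorbar(im)
    plt.show()
```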
353
338
{
354
339
"cell_type" : " code" ,
355
340
"metadata" : {
415
400
" ## Learning Without Forgetting"
416
401
]
417
402
},
418
- {
419
- "cell_type" : " markdown" ,
420
- "metadata" : {
421
- "id" : " ERL_PF-cm1N_" ,
422
- "colab_type" : " text"
423
- },
424
- "source" : [
425
- " ### Arguments"
426
- ]
427
- },
428
- {
429
- "cell_type" : " code" ,
430
- "metadata" : {
431
- "colab_type" : " code" ,
432
- "id" : " JHBfXPTXm16d" ,
433
- "colab" : {}
434
- },
435
- "source" : [
436
- " # Training settings for Learning Without Forgetting\n " ,
437
- " RANDOM_STATES = [658, 423, 422] \n " ,
438
- " BATCH_SIZE = 128\n " ,
439
- " LR = 2"
440
- ],
441
- "execution_count" : null ,
442
- "outputs" : []
443
- },
444
403
{
445
404
"cell_type" : " markdown" ,
446
405
"metadata" : {
486
445
" test_dataloaders = [[] for i in range(NUM_RUNS)]\n " ,
487
446
" \n " ,
488
447
" for run_i in range(NUM_RUNS):\n " ,
448
+ " test_subsets = []\n " ,
449
+ " random_state = RANDOM_STATES[run_i]\n " ,
489
450
" \n " ,
490
- " test_subsets = []\n " ,
491
- " random_state = RANDOM_STATES[run_i]\n " ,
451
+ " for split_i in range(10):\n " ,
452
+ " # Download dataset only at first instantiation\n " ,
453
+ " if run_i+split_i == 0:\n " ,
454
+ " download = True\n " ,
455
+ " else:\n " ,
456
+ " download = False\n " ,
492
457
" \n " ,
493
- " for split_i in range(CLASS_BATCH_SIZE):\n " ,
458
+ " # Create CIFAR100 dataset\n " ,
459
+ " train_dataset = Cifar100(DATA_DIR, train=True, download=download, random_state=random_state, transform=train_transform)\n " ,
460
+ " test_dataset = Cifar100(DATA_DIR, train=False, download=False, random_state=random_state, transform=test_transform)\n " ,
494
461
" \n " ,
495
- " # Download dataset only at first instantiation\n " ,
496
- " if(run_i+split_i == 0):\n " ,
497
- " download = True\n " ,
498
- " else:\n " ,
499
- " download = False\n " ,
462
+ " # Subspace of CIFAR100 of 10 classes\n " ,
463
+ " train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])\n " ,
464
+ " test_dataset.set_classes_batch([test_dataset.batch_splits[i] for i in range(0, split_i+1)])\n " ,
500
465
" \n " ,
501
- " # Create CIFAR100 dataset\n " ,
502
- " train_dataset = Cifar100(DATA_DIR, train = True, download = download, random_state = random_state, transform=train_transform)\n " ,
503
- " test_dataset = Cifar100(DATA_DIR, train = False, download = False, random_state = random_state, transform=test_transform)\n " ,
504
- " \n " ,
505
- " # Subspace of CIFAR100 of 10 classes\n " ,
506
- " train_dataset.set_classes_batch(train_dataset.batch_splits[split_i]) \n " ,
507
- " test_dataset.set_classes_batch([test_dataset.batch_splits[i] for i in range(0, split_i+1)])\n " ,
466
+ " # Define train and validation indices\n " ,
467
+ " train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, random_state)\n " ,
508
468
" \n " ,
509
- " # Define train and validation indices\n " ,
510
- " train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, random_state)\n " ,
511
- " \n " ,
512
- " train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices), \n " ,
513
- " batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n " ,
514
- " \n " ,
515
- " val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices), \n " ,
516
- " batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n " ,
517
- " \n " ,
518
- " # Dataset with all seen class\n " ,
519
- " test_dataloaders[run_i].append(DataLoader(test_dataset, \n " ,
520
- " batch_size=BATCH_SIZE, shuffle=True, num_workers=4)) "
469
+ " train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices),\n " ,
470
+ " batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n " ,
471
+ " \n " ,
472
+ " val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices),\n " ,
473
+ " batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n " ,
474
+ " \n " ,
475
+ " # Dataset with all seen class\n " ,
476
+ " test_dataloaders[run_i].append(DataLoader(test_dataset,\n " ,
477
+ " batch_size=BATCH_SIZE, shuffle=True, num_workers=4))"
521
478
],
522
479
"execution_count" : null ,
523
480
"outputs" : []
534
491
" dataiter = iter(test_dataloaders[0][5])\n " ,
535
492
" images, labels = dataiter.next()\n " ,
536
493
" \n " ,
537
- " plot.image_grid(images, one_channel=False)\n " ,
538
- " unique_labels = np.unique(labels, return_counts=True)\n " ,
539
- " unique_labels"
494
+ " plot.image_grid(images, one_channel=False)"
540
495
],
541
496
"execution_count" : null ,
542
497
"outputs" : []
543
498
},
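`plot.image_grid` is another `utils/plot.py` helper outside this diff; a typical version built on `torchvision.utils.make_grid` could look like this (an assumption, not the repository's code):

```python
import matplotlib.pyplot as plt
import numpy as np
import torchvision

def image_grid(images, one_channel=False):
    """Display a batch of image tensors as a single grid."""
    grid = torchvision.utils.make_grid(images)
    npimg = grid.numpy()
    if one_channel:
        plt.imshow(npimg.mean(axis=0), cmap='Greys')
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()
```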
499
+ {
500
+ "cell_type" : " markdown" ,
501
+ "metadata" : {
502
+ "id" : " iYwMtMJuLyYe" ,
503
+ "colab_type" : " text"
504
+ },
505
+ "source" : [
506
+ " ### Execution"
507
+ ]
508
+ },
544
509
{
545
510
"cell_type" : " code" ,
546
511
"metadata" : {
547
- "id" : " cw6a_xAumXQW " ,
512
+ "id" : " JpGuC_hSL0jN " ,
548
513
"colab_type" : " code" ,
549
514
"colab" : {}
550
515
},
551
516
"source" : [
552
- " from torch.nn import BCEWithLogitsLoss\n " ,
553
- " from copy import deepcopy\n " ,
554
- " \n " ,
555
- " '''BCE formulation:\n " ,
556
- " let x = logits, z = labels. The logistic loss is\n " ,
557
- " \n " ,
558
- " z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))\n " ,
559
- " '''\n " ,
560
- " \n " ,
561
- " \n " ,
562
- " CLASS_BATCH_SIZE = 10\n " ,
563
- " \n " ,
564
- " \n " ,
565
- " class LWF():\n " ,
566
- " def __init__(self, device, net, old_net, criterion, optimizer, scheduler,\n " ,
567
- " train_dataloader, val_dataloader, test_dataloader, num_classes=10):\n " ,
568
- " \n " ,
569
- " self.device = device\n " ,
570
- " \n " ,
571
- " self.net = net\n " ,
572
- " self.best_net = self.net\n " ,
573
- " self.old_net = old_net # None for first ten classes\n " ,
574
- " \n " ,
575
- " self.criterion = BCEWithLogitsLoss() # Classifier criterion \n " ,
576
- " self.optimizer = optimizer\n " ,
577
- " self.scheduler = scheduler\n " ,
578
- " \n " ,
579
- " self.train_dataloader = train_dataloader\n " ,
580
- " self.val_dataloader = val_dataloader\n " ,
581
- " self.test_dataloader = test_dataloader\n " ,
582
- " \n " ,
583
- " self.num_classes = num_classes # can be incremented ouitside methods in the main, or inside methods\n " ,
584
- " self.order = np.arange(100)\n " ,
585
- " \n " ,
586
- " self.sigmoid = nn.Sigmoid()\n " ,
587
- " \n " ,
588
- " \n " ,
589
- " def warm_up():\n " ,
590
- " pass\n " ,
591
- " \n " ,
592
- " def increment_classes(self, n=10):\n " ,
593
- " \"\"\" Add n classes in the final fully connected layer.\"\"\"\n " ,
594
- " \n " ,
595
- " in_features = self.net.fc.in_features # size of each input sample\n " ,
596
- " out_features = self.net.fc.out_features # size of each output sample\n " ,
597
- " weight = self.net.fc.weight.data\n " ,
598
- " \n " ,
599
- " self.net.fc = nn.Linear(in_features, out_features+n)\n " ,
600
- " self.net.fc.weight.data[:out_features] = weight\n " ,
601
- " \n " ,
602
- " def to_onehot(self, targets): \n " ,
603
- " '''\n " ,
604
- " Args:\n " ,
605
- " targets : dataloader.dataset.targets of the new task images\n " ,
606
- " '''\n " ,
607
- " one_hot_targets = torch.eye(self.num_classes)[targets]\n " ,
608
- " \n " ,
609
- " return one_hot_targets.to(self.device)\n " ,
610
- " \n " ,
611
- " def do_first_batch(self, batch, labels):\n " ,
612
- " \n " ,
613
- " batch = batch.to(self.device)\n " ,
614
- " labels = labels.to(self.device) # new classes labels\n " ,
615
- " \n " ,
616
- " # Zero-ing the gradients\n " ,
617
- " self.optimizer.zero_grad()\n " ,
618
- " \n " ,
619
- " # One hot encoding of new task labels \n " ,
620
- " one_hot_labels = self.to_onehot(labels) # Size = [128, 10]\n " ,
621
- " \n " ,
622
- " # New net forward pass\n " ,
623
- " outputs = self.net(batch) \n " ,
624
- " \n " ,
625
- " loss = self.criterion(outputs, one_hot_labels) # BCE Loss with sigmoids over outputs\n " ,
626
- " \n " ,
627
- " # Get predictions\n " ,
628
- " _, preds = torch.max(outputs.data, 1)\n " ,
629
- " \n " ,
630
- " # Accuracy over NEW IMAGES, not over all images\n " ,
631
- " running_corrects = \\\n " ,
632
- " torch.sum(preds == labels.data).data.item() # Può essere che debba usare targets e non labels\n " ,
633
- " \n " ,
634
- " # Backward pass: computes gradients\n " ,
635
- " loss.backward()\n " ,
636
- " \n " ,
637
- " self.optimizer.step()\n " ,
638
- " \n " ,
639
- " return loss, running_corrects\n " ,
640
- " \n " ,
641
- " \n " ,
642
- " def do_batch(self, batch, labels):\n " ,
643
- " \n " ,
644
- " batch = batch.to(self.device)\n " ,
645
- " labels = labels.to(self.device) # new classes labels\n " ,
646
- " \n " ,
647
- " # Zero-ing the gradients\n " ,
648
- " self.optimizer.zero_grad()\n " ,
649
- " \n " ,
650
- " # One hot encoding of new task labels \n " ,
651
- " one_hot_labels = self.to_onehot(labels) # Size = [128, n_classes] will be sliced as [:, :self.num_classes-10]\n " ,
652
- " new_classes = (self.order[range(self.num_classes-10, self.num_classes)]).astype(np.int32)\n " ,
653
- " one_hot_labels = torch.stack([one_hot_labels[:, i] for i in new_classes], axis=1)\n " ,
654
- " \n " ,
655
- " # Old net forward pass\n " ,
656
- " old_outputs = self.sigmoid(self.old_net(batch)) # Size = [128, 100]\n " ,
657
- " old_classes = (self.order[range(self.num_classes-10)]).astype(np.int32)\n " ,
658
- " old_outputs = torch.stack([old_outputs[:, i] for i in old_classes], axis =1)\n " ,
659
- " \n " ,
660
- " # Combine new and old class targets\n " ,
661
- " targets = torch.cat((old_outputs, one_hot_labels), 1)\n " ,
662
- " \n " ,
663
- " # New net forward pass\n " ,
664
- " outputs = self.net(batch) # Size = [128, 100] comparable with the define targets\n " ,
665
- " out_classes = (self.order[range(self.num_classes)]).astype(np.int32)\n " ,
666
- " outputs = torch.stack([outputs[:, i] for i in out_classes], axis=1)\n " ,
667
- " \n " ,
668
- " \n " ,
669
- " loss = self.criterion(outputs, targets) # BCE Loss with sigmoids over outputs (over targets must be done manually)\n " ,
670
- " \n " ,
671
- " # Get predictions\n " ,
672
- " _, preds = torch.max(outputs.data, 1)\n " ,
673
- " \n " ,
674
- " # Accuracy over NEW IMAGES, not over all images\n " ,
675
- " running_corrects = \\\n " ,
676
- " torch.sum(preds == labels.data).data.item() \n " ,
677
- " \n " ,
678
- " # Backward pass: computes gradients\n " ,
679
- " loss.backward()\n " ,
680
- " \n " ,
681
- " self.optimizer.step()\n " ,
682
- " \n " ,
683
- " return loss, running_corrects\n " ,
684
- " \n " ,
685
- " \n " ,
686
- " def do_epoch(self, current_epoch):\n " ,
687
- " \n " ,
688
- " self.net.train()\n " ,
689
- " \n " ,
690
- " running_train_loss = 0\n " ,
691
- " running_corrects = 0\n " ,
692
- " total = 0\n " ,
693
- " batch_idx = 0\n " ,
694
- " \n " ,
695
- " print(f\" Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}\" )\n " ,
696
- " \n " ,
697
- " for images, labels in self.train_dataloader:\n " ,
698
- " \n " ,
699
- " if self.num_classes == CLASS_BATCH_SIZE:\n " ,
700
- " loss, corrects = self.do_first_batch(images, labels)\n " ,
701
- " else:\n " ,
702
- " loss, corrects = self.do_batch(images, labels)\n " ,
703
- " \n " ,
704
- " running_train_loss += loss.item()\n " ,
705
- " running_corrects += corrects\n " ,
706
- " total += labels.size(0)\n " ,
707
- " batch_idx += 1\n " ,
708
- " \n " ,
709
- " self.scheduler.step()\n " ,
710
- " \n " ,
711
- " # Calculate average scores\n " ,
712
- " train_loss = running_train_loss / batch_idx # Average over all batches\n " ,
713
- " train_accuracy = running_corrects / float(total) # Average over all samples\n " ,
714
- " \n " ,
715
- " print(f\" Train loss: {train_loss}, Train accuracy: {train_accuracy}\" )\n " ,
716
- " \n " ,
717
- " return (train_loss, train_accuracy)\n " ,
718
- " \n " ,
517
+ " # Arguments for Learning without Forgetting\n " ,
518
+ " BATCH_SIZE = 128\n " ,
519
+ " LR = 2"
520
+ ],
521
+ "execution_count" : null ,
522
+ "outputs" : []
523
+ },
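`LR = 2` is unusually large for SGD, but it is the value used by the iCaRL training protocol this notebook follows, and it pairs with a sum of binary cross-entropies rather than softmax cross-entropy. `nn.BCEWithLogitsLoss` is the numerically stable fused form of the formula quoted in the removed `LWF` docstring, `z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))`; a quick self-contained check:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5]])   # x: raw scores for 3 classes
targets = torch.tensor([[1.0, 0.0, 1.0]])   # z: binary targets

fused = nn.BCEWithLogitsLoss()(logits, targets)
manual = nn.BCELoss()(torch.sigmoid(logits), targets)  # sigmoid, then plain BCE

print(torch.allclose(fused, manual))        # True, up to floating-point error
```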
524
+ {
525
+ "cell_type" : " code" ,
526
+ "metadata" : {
527
+ "id" : " MlThDLCvXJwS" ,
528
+ "colab_type" : " code" ,
529
+ "colab" : {}
530
+ },
531
+ "source" : [
532
+ " logs = [[] for _ in range(NUM_RUNS)]\n " ,
719
533
" \n " ,
720
- " def train(self, num_epochs):\n " ,
721
- " \"\"\" Train the network for a specified number of epochs, and save\n " ,
722
- " the best performing model on the validation set.\n " ,
723
- " \n " ,
724
- " Args:\n " ,
725
- " num_epochs (int): number of epochs for training the network.\n " ,
726
- " Returns:\n " ,
727
- " train_loss: loss computed on the last epoch\n " ,
728
- " train_accuracy: accuracy computed on the last epoch\n " ,
729
- " val_loss: average loss on the validation set of the last epoch\n " ,
730
- " val_accuracy: accuracy on the validation set of the last epoch\n " ,
731
- " \"\"\"\n " ,
732
- " \n " ,
733
- " # @todo: is the return behaviour intended? (scores of the last epoch)\n " ,
734
- " \n " ,
735
- " self.net = self.net.to(self.device)\n " ,
736
- " if self.old_net != None:\n " ,
737
- " self.old_net = self.old_net.to(self.device)\n " ,
738
- " self.old_net.train(False)\n " ,
739
- " \n " ,
740
- " cudnn.benchmark # Calling this optimizes runtime\n " ,
741
- " \n " ,
742
- " self.best_loss = float(\" inf\" )\n " ,
743
- " self.best_epoch = 0\n " ,
744
- " \n " ,
745
- " for epoch in range(num_epochs):\n " ,
746
- " # Run an epoch (start counting form 1)\n " ,
747
- " train_loss, train_accuracy = self.do_epoch(epoch+1)\n " ,
534
+ " # Iterate over runs\n " ,
535
+ " for run_i in range(NUM_RUNS):\n " ,
536
+ " net = resnet32()\n " ,
748
537
" \n " ,
749
- " # Validate after each epoch \n " ,
750
- " val_loss, val_accuracy = self.validate() \n " ,
751
- " \n " ,
752
- " # Best validation model\n " ,
753
- " if val_loss < self.best_loss:\n " ,
754
- " self.best_loss = val_loss\n " ,
755
- " self.best_net = deepcopy(self.net)\n " ,
756
- " self.best_epoch = epoch\n " ,
757
- " print(\" Best model updated\" )\n " ,
758
- " \n " ,
759
- " print(\"\" )\n " ,
760
- " \n " ,
761
- " return (train_loss, train_accuracy,\n " ,
762
- " val_loss, val_accuracy)\n " ,
763
- " \n " ,
764
- " \n " ,
765
- " def validate(self):\n " ,
766
- " \"\"\" Validate the model.\n " ,
538
+ " criterion = nn.BCEWithLogitsLoss()\n " ,
767
539
" \n " ,
768
- " Returns:\n " ,
769
- " val_loss: average loss function computed on the network outputs\n " ,
770
- " of the validation set (val_dataloader).\n " ,
771
- " val_accuracy: accuracy computed on the validation set.\n " ,
772
- " \"\"\"\n " ,
773
- " \n " ,
774
- " self.net.train(False)\n " ,
775
- " \n " ,
776
- " running_val_loss = 0\n " ,
777
- " running_corrects = 0\n " ,
778
- " total = 0\n " ,
779
- " batch_idx = 0\n " ,
780
- " \n " ,
781
- " \n " ,
782
- " for batch, labels in self.val_dataloader:\n " ,
783
- " batch = batch.to(self.device)\n " ,
784
- " labels = labels.to(self.device)\n " ,
785
- " total += labels.size(0)\n " ,
786
- " \n " ,
787
- " # One hot encoding of new task labels \n " ,
788
- " one_hot_labels = self.to_onehot(labels) # Size = [128, 100] will be sliced as [:, :self.num_classes-10]\n " ,
789
- " new_classes = (self.order[range(self.num_classes-10, self.num_classes)]).astype(np.int32)\n " ,
790
- " one_hot_labels = torch.stack([one_hot_labels[:, i] for i in new_classes], axis=1)\n " ,
791
- " \n " ,
792
- " if self.num_classes > 10:\n " ,
793
- " # Old net forward pass\n " ,
794
- " old_outputs = self.sigmoid(self.old_net(batch)) # Size = [128, 100]\n " ,
795
- " old_classes = (self.order[range(self.num_classes-10)]).astype(np.int32)\n " ,
796
- " old_outputs = torch.stack([old_outputs[:, i] for i in old_classes], axis =1)\n " ,
797
- " \n " ,
798
- " # Combine new and old class targets\n " ,
799
- " targets = torch.cat((old_outputs, one_hot_labels), 1)\n " ,
800
- " \n " ,
801
- " else:\n " ,
802
- " targets = one_hot_labels\n " ,
803
- " \n " ,
804
- " # New net forward pass\n " ,
805
- " outputs = self.net(batch) # Size = [128, 100] comparable with the define targets\n " ,
806
- " out_classes = (self.order[range(self.num_classes)]).astype(np.int32)\n " ,
807
- " outputs = torch.stack([outputs[:, i] for i in out_classes], axis=1)\n " ,
808
- " \n " ,
809
- " \n " ,
810
- " loss = self.criterion(outputs, targets) # BCE Loss with sigmoids over outputs (over targets must be done manually)\n " ,
811
- " \n " ,
812
- " # Get predictions\n " ,
813
- " _, preds = torch.max(outputs.data, 1)\n " ,
814
- " \n " ,
815
- " # Update the number of correctly classified validation samples\n " ,
816
- " running_corrects += torch.sum(preds == labels.data).data.item()\n " ,
817
- " running_val_loss += loss.item()\n " ,
818
- " \n " ,
819
- " batch_idx += 1\n " ,
820
- " \n " ,
821
- " # Calcuate scores\n " ,
822
- " val_loss = running_val_loss / batch_idx\n " ,
823
- " val_accuracy = running_corrects / float(total)\n " ,
824
- " \n " ,
825
- " print(f\" Validation loss: {val_loss}, Validation accuracy: {val_accuracy}\" )\n " ,
826
- " \n " ,
827
- " return (val_loss, val_accuracy)\n " ,
828
- " \n " ,
829
- " \n " ,
830
- " def test(self):\n " ,
831
- " \"\"\" Test the model.\n " ,
832
- " Returns:\n " ,
833
- " accuracy (float): accuracy of the model on the test set\n " ,
834
- " \"\"\"\n " ,
835
- " \n " ,
836
- " self.best_net.train(False) # Set Network to evaluation mode\n " ,
837
- " \n " ,
838
- " running_corrects = 0\n " ,
839
- " total = 0\n " ,
540
+ " for split_i in range(10):\n " ,
541
+ " print(f\" ## Split {split_i} of run {run_i} ##\" )\n " ,
840
542
" \n " ,
841
- " all_preds = torch.tensor([]) # to store all predictions\n " ,
842
- " all_preds = all_preds.type(torch.LongTensor)\n " ,
843
- " \n " ,
844
- " for images, labels in self.test_dataloader:\n " ,
845
- " images = images.to(self.device)\n " ,
846
- " labels = labels.to(self.device)\n " ,
847
- " total += labels.size(0)\n " ,
543
+ " # Redefine optimizer at each split (pass by reference issue)\n " ,
544
+ " parameters_to_optimize = net.parameters()\n " ,
545
+ " optimizer = optim.SGD(parameters_to_optimize, lr=LR,\n " ,
546
+ " momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)\n " ,
547
+ " scheduler = optim.lr_scheduler.MultiStepLR(optimizer, \n " ,
548
+ " milestones=MILESTONES, gamma=GAMMA)\n " ,
549
+ " \n " ,
550
+ " num_classes = 10*(split_i+1)\n " ,
551
+ " \n " ,
552
+ " if num_classes == 10: # old network == None\n " ,
553
+ " lwf = LWF(DEVICE, net, None, criterion, optimizer, scheduler,\n " ,
554
+ " train_dataloaders[run_i][split_i],\n " ,
555
+ " val_dataloaders[run_i][split_i],\n " ,
556
+ " test_dataloaders[run_i][split_i],\n " ,
557
+ " num_classes)\n " ,
558
+ " else:\n " ,
559
+ " lwf = LWF(DEVICE, net, old_net, criterion, optimizer, scheduler,\n " ,
560
+ " train_dataloaders[run_i][split_i],\n " ,
561
+ " val_dataloaders[run_i][split_i],\n " ,
562
+ " test_dataloaders[run_i][split_i],\n " ,
563
+ " num_classes)\n " ,
848
564
" \n " ,
849
- " # Forward Pass\n " ,
850
- " outputs = self.best_net(images)\n " ,
565
+ " scores = lwf.train(NUM_EPOCHS) # train the model\n " ,
851
566
" \n " ,
852
- " # Get predictions\n " ,
853
- " _, preds = torch.max(outputs.data, 1)\n " ,
567
+ " logs[run_i].append({})\n " ,
854
568
" \n " ,
855
- " # Update Corrects\n " ,
856
- " running_corrects += torch.sum(preds == labels.data).data.item()\n " ,
569
+ " # score[i] = dictionary with key:epoch, value: score\n " ,
570
+ " logs[run_i][split_i]['train_loss'] = scores[0]\n " ,
571
+ " logs[run_i][split_i]['train_accuracy'] = scores[1]\n " ,
572
+ " logs[run_i][split_i]['val_loss'] = scores[2]\n " ,
573
+ " logs[run_i][split_i]['val_accuracy'] = scores[3]\n " ,
857
574
" \n " ,
858
- " # Append batch predictions\n " ,
859
- " all_preds = torch.cat(\n " ,
860
- " (all_preds.to(self.device), preds.to(self.device)), dim=0\n " ,
861
- " )\n " ,
575
+ " # Test the model on classes seen until now\n " ,
576
+ " test_accuracy, all_targets, all_preds = lwf.test()\n " ,
862
577
" \n " ,
863
- " # Calculate accuracy \n " ,
864
- " accuracy = running_corrects / float(total) \n " ,
578
+ " logs[run_i][split_i]['test_accuracy'] = test_accuracy \n " ,
579
+ " logs[run_i][split_i]['conf_mat'] = confusion_matrix(all_targets.to('cpu'), all_preds.to('cpu')) \n " ,
865
580
" \n " ,
866
- " print(f \" Test accuracy: {accuracy} \" )\n " ,
581
+ " old_net = deepcopy(lwf.net )\n " ,
867
582
" \n " ,
868
- " return (accuracy, all_preds )"
583
+ " lwf.increment_classes( )"
869
584
],
870
585
"execution_count" : null ,
871
586
"outputs" : []
872
587
},
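The heart of the removed `LWF.do_batch` (now refactored into `model/lwf.py`) is how it builds the BCE targets: the frozen old network's sigmoid outputs serve as soft targets for the previously seen classes, concatenated with one-hot labels for the 10 new ones. Since `self.order` was `np.arange(100)` (the identity), the per-column gathering reduces to plain slicing. A condensed sketch (`torch.no_grad()` is added here for clarity; the original relied on `old_net.train(False)`):

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def lwf_batch_loss(net, old_net, batch, one_hot_labels, num_classes):
    """LwF loss: distillation on old classes + classification on new ones."""
    outputs = net(batch)[:, :num_classes]
    with torch.no_grad():
        # Soft targets from the previous model for the already-seen classes
        old_outputs = torch.sigmoid(old_net(batch))[:, :num_classes - 10]
    new_targets = one_hot_labels[:, num_classes - 10:num_classes]
    targets = torch.cat((old_outputs, new_targets), dim=1)
    return criterion(outputs, targets)
```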
588
+ {
589
+ "cell_type" : " markdown" ,
590
+ "metadata" : {
591
+ "id" : " 2xZbK6EGSaZN" ,
592
+ "colab_type" : " text"
593
+ },
594
+ "source" : [
595
+ " ### Plots"
596
+ ]
597
+ },
873
598
{
874
599
"cell_type" : " code" ,
875
600
"metadata" : {
876
- "id" : " MlThDLCvXJwS " ,
601
+ "id" : " bUfgSq1xSbrD " ,
877
602
"colab_type" : " code" ,
878
603
"colab" : {}
879
604
},
880
605
"source" : [
881
- " train_loss_history = []\n " ,
882
- " train_accuracy_history = []\n " ,
883
- " val_loss_history = []\n " ,
884
- " val_accuracy_history = []\n " ,
885
- " test_accuracy_history = []\n " ,
886
- " \n " ,
887
- " \n " ,
888
- " \n " ,
889
- " # Iterate over runs\n " ,
890
- " for train_dataloader, val_dataloader, test_dataloader in zip(train_dataloaders,\n " ,
891
- " val_dataloaders, test_dataloaders):\n " ,
892
- " \n " ,
893
- " \n " ,
894
- " train_loss_history.append({})\n " ,
895
- " train_accuracy_history.append({})\n " ,
896
- " val_loss_history.append({})\n " ,
897
- " val_accuracy_history.append({})\n " ,
898
- " test_accuracy_history.append({})\n " ,
899
- " \n " ,
900
- " net = resnet32() # Define the net\n " ,
901
- " \n " ,
902
- " criterion = nn.BCEWithLogitsLoss() # Define the loss\n " ,
903
- " \n " ,
904
- " \n " ,
905
- " i = 0\n " ,
906
- " for train_split, val_split, test_split in zip(train_dataloader,\n " ,
907
- " val_dataloader, test_dataloader):\n " ,
908
- " \n " ,
909
- " # Redefine optimizer at each split (pass by reference issue)\n " ,
910
- " parameters_to_optimize = net.parameters()\n " ,
911
- " optimizer = optim.SGD(parameters_to_optimize, lr=LR,\n " ,
912
- " momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)\n " ,
913
- " scheduler = optim.lr_scheduler.MultiStepLR(optimizer, \n " ,
914
- " milestones=MILESTONES, gamma=GAMMA)\n " ,
915
- " \n " ,
916
- " current_split = \" Split %i\" %(i)\n " ,
917
- " print(current_split)\n " ,
918
- " \n " ,
919
- " num_classes = CLASS_BATCH_SIZE*(i+1)\n " ,
920
- " \n " ,
921
- " if num_classes == CLASS_BATCH_SIZE:\n " ,
922
- " # Old Network = None\n " ,
923
- " lwf = LWF(DEVICE, net, None, criterion, optimizer, scheduler,\n " ,
924
- " train_split, val_split, test_split, num_classes)\n " ,
925
- " else:\n " ,
926
- " lwf = LWF(DEVICE, net, old_net, criterion, optimizer, scheduler,\n " ,
927
- " train_split, val_split, test_split, num_classes)\n " ,
928
- " \n " ,
929
- " \n " ,
930
- " scores = lwf.train(NUM_EPOCHS) # train the model\n " ,
931
- " \n " ,
932
- " # score[i] = dictionary with key:epoch, value: score\n " ,
933
- " train_loss_history[-1][current_split] = scores[0]\n " ,
934
- " train_accuracy_history[-1][current_split] = scores[1]\n " ,
935
- " val_loss_history[-1][current_split] = scores[2]\n " ,
936
- " val_accuracy_history[-1][current_split] = scores[3]\n " ,
937
- " \n " ,
938
- " # Test the model on classes seen until now\n " ,
939
- " test_accuracy, all_preds = lwf.test()\n " ,
940
- " \n " ,
941
- " test_accuracy_history[-1][current_split] = test_accuracy\n " ,
606
+ " train_loss = [[logs[run_i][i]['train_loss'] for i in range(10)] for run_i in range(NUM_RUNS)]\n " ,
607
+ " train_accuracy = [[logs[run_i][i]['train_accuracy'] for i in range(10)] for run_i in range(NUM_RUNS)]\n " ,
608
+ " val_loss = [[logs[run_i][i]['val_loss'] for i in range(10)] for run_i in range(NUM_RUNS)]\n " ,
609
+ " val_accuracy = [[logs[run_i][i]['val_accuracy'] for i in range(10)] for run_i in range(NUM_RUNS)]\n " ,
610
+ " test_accuracy = [[logs[run_i][i]['test_accuracy'] for i in range(10)] for run_i in range(NUM_RUNS)]\n " ,
942
611
" \n " ,
943
- " # Uncomment if default resnet has 10 node at last FC layer\n " ,
944
- " old_net = deepcopy(lwf.net)\n " ,
945
- " lwf.increment_classes()\n " ,
612
+ " train_loss = np.array(train_loss)\n " ,
613
+ " train_accuracy = np.array(train_accuracy)\n " ,
614
+ " val_loss = np.array(val_loss)\n " ,
615
+ " val_accuracy = np.array(val_accuracy)\n " ,
616
+ " test_accuracy = np.array(test_accuracy)\n " ,
946
617
" \n " ,
947
- " i =i+1"
618
+ " train_loss_stats = np.array([train_loss.mean(0), train_loss.std(0)]).transpose()\n " ,
619
+ " train_accuracy_stats = np.array([train_accuracy.mean(0), train_accuracy.std(0)]).transpose()\n " ,
620
+ " val_loss_stats = np.array([val_loss.mean(0), val_loss.std(0)]).transpose()\n " ,
621
+ " val_accuracy_stats = np.array([val_accuracy.mean(0), val_accuracy.std(0)]).transpose()\n " ,
622
+ " test_accuracy_stats = np.array([test_accuracy.mean(0), test_accuracy.std(0)]).transpose()"
623
+ ],
624
+ "execution_count" : null ,
625
+ "outputs" : []
626
+ },
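`plot.train_val_scores` and `plot.test_scores` come from `utils/plot.py` and are not part of this diff. Given the `(mean, std)` pairs per split built above, `plot.test_scores` presumably draws something like the following (illustrative sketch):

```python
import matplotlib.pyplot as plt
import numpy as np

def test_scores(stats):
    """Plot mean test accuracy per split with std error bars across runs."""
    classes_seen = np.arange(10, 101, 10)   # 10, 20, ..., 100 classes
    mean, std = stats[:, 0], stats[:, 1]
    plt.errorbar(classes_seen, mean, yerr=std, marker='o', capsize=3)
    plt.xlabel('Number of classes seen')
    plt.ylabel('Test accuracy')
    plt.show()
```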
627
+ {
628
+ "cell_type" : " code" ,
629
+ "metadata" : {
630
+ "colab_type" : " code" ,
631
+ "id" : " 1w3_YPJCSeli" ,
632
+ "colab" : {}
633
+ },
634
+ "source" : [
635
+ " plot.train_val_scores(train_loss_stats, train_accuracy_stats, val_loss_stats, val_accuracy_stats)"
636
+ ],
637
+ "execution_count" : null ,
638
+ "outputs" : []
639
+ },
640
+ {
641
+ "cell_type" : " code" ,
642
+ "metadata" : {
643
+ "colab_type" : " code" ,
644
+ "id" : " ZSt6-FJbSelp" ,
645
+ "colab" : {}
646
+ },
647
+ "source" : [
648
+ " plot.test_scores(test_accuracy_stats)"
948
649
],
949
650
"execution_count" : null ,
950
651
"outputs" : []
1108
809
"execution_count" : null ,
1109
810
"outputs" : []
1110
811
},
1111
- {
1112
- "cell_type" : " code" ,
1113
- "metadata" : {
1114
- "id" : " by8c4Aaa-8ms" ,
1115
- "colab_type" : " code" ,
1116
- "colab" : {}
1117
- },
1118
- "source" : [
1119
- " print(logs_icarl)"
1120
- ],
1121
- "execution_count" : null ,
1122
- "outputs" : []
1123
- },
1124
- {
1125
- "cell_type" : " code" ,
1126
- "metadata" : {
1127
- "id" : " qOCjSFJy_ANm" ,
1128
- "colab_type" : " code" ,
1129
- "colab" : {}
1130
- },
1131
- "source" : [
1132
- " obj_save(logs_icarl, 'hybrid1_confmat')"
1133
- ],
1134
- "execution_count" : null ,
1135
- "outputs" : []
1136
- },
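The removed cells persist the iCaRL logs through `obj_save`, a utility not shown in the diff; a pickle-based implementation along these lines would match the call site (an assumption):

```python
import pickle

def obj_save(obj, name):
    """Serialize obj to <name>.pkl in the working directory."""
    with open(f'{name}.pkl', 'wb') as f:
        pickle.dump(obj, f)
```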
1137
812
{
1138
813
"cell_type" : " markdown" ,
1139
814
"metadata" : {
1200
875
"outputs" : []
1201
876
}
1202
877
]
1203
- }
878
+ }