This repository was archived by the owner on Sep 3, 2021. It is now read-only.

Commit 5e8bee2 (committed Jul 20, 2020)
lwf class moved to module and cleanup for consistency
1 parent ceaa93d

1 file changed: baselines.ipynb (+149 -474 lines)
@@ -111,6 +111,7 @@
 "from data.cifar100 import Cifar100\n",
 "from model.resnet_cifar import resnet32\n",
 "from model.manager import Manager\n",
+"from model.lwf import LWF\n",
 "from model.icarl import Exemplars\n",
 "from model.icarl import iCaRL\n",
 "from utils import plot"
@@ -301,7 +302,7 @@
 " criterion = nn.BCEWithLogitsLoss()\n",
 " \n",
 " for split_i in range(10):\n",
-" print(f\"# Split {split_i} of run {run_i}\")\n",
+" print(f\"## Split {split_i} of run {run_i} ##\")\n",
 "\n",
 " parameters_to_optimize = net.parameters()\n",
 " optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)\n",
@@ -326,30 +327,14 @@
 " test_accuracy, all_targets, all_preds = manager.test()\n",
 "\n",
 " logs[run_i][split_i]['test_accuracy'] = test_accuracy\n",
+" logs[run_i][split_i]['conf_mat'] = confusion_matrix(all_targets.to('cpu'), all_preds.to('cpu'))\n",
 "\n",
 " # Add 10 nodes to last FC layer\n",
 " manager.increment_classes(n=10)"
 ],
 "execution_count": null,
 "outputs": []
 },
-{
-"cell_type": "code",
-"metadata": {
-"colab_type": "code",
-"id": "JWdj0wvu996S",
-"colab": {}
-},
-"source": [
-"# Confusion matrix over last run test predictions\n",
-"targets = test_dataset.targets\n",
-"preds = all_preds.to('cpu').numpy()\n",
-"\n",
-"plot.heatmap_cm(targets, preds)"
-],
-"execution_count": null,
-"outputs": []
-},
 {
 "cell_type": "code",
 "metadata": {
@@ -415,32 +400,6 @@
 "## Learning Without Forgetting"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "ERL_PF-cm1N_",
-"colab_type": "text"
-},
-"source": [
-"### Arguments"
-]
-},
-{
-"cell_type": "code",
-"metadata": {
-"colab_type": "code",
-"id": "JHBfXPTXm16d",
-"colab": {}
-},
-"source": [
-"# Training settings for Learning Without Forgetting\n",
-"RANDOM_STATES = [658, 423, 422] \n",
-"BATCH_SIZE = 128\n",
-"LR = 2"
-],
-"execution_count": null,
-"outputs": []
-},
 {
 "cell_type": "markdown",
 "metadata": {
@@ -486,38 +445,36 @@
 "test_dataloaders = [[] for i in range(NUM_RUNS)]\n",
 "\n",
 "for run_i in range(NUM_RUNS):\n",
+" test_subsets = []\n",
+" random_state = RANDOM_STATES[run_i]\n",
 "\n",
-" test_subsets = []\n",
-" random_state = RANDOM_STATES[run_i]\n",
+" for split_i in range(10):\n",
+" # Download dataset only at first instantiation\n",
+" if run_i+split_i == 0:\n",
+" download = True\n",
+" else:\n",
+" download = False\n",
 "\n",
-" for split_i in range(CLASS_BATCH_SIZE):\n",
+" # Create CIFAR100 dataset\n",
+" train_dataset = Cifar100(DATA_DIR, train=True, download=download, random_state=random_state, transform=train_transform)\n",
+" test_dataset = Cifar100(DATA_DIR, train=False, download=False, random_state=random_state, transform=test_transform)\n",
 "\n",
-" # Download dataset only at first instantiation\n",
-" if(run_i+split_i == 0):\n",
-" download = True\n",
-" else:\n",
-" download = False\n",
+" # Subspace of CIFAR100 of 10 classes\n",
+" train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])\n",
+" test_dataset.set_classes_batch([test_dataset.batch_splits[i] for i in range(0, split_i+1)])\n",
 "\n",
-" # Create CIFAR100 dataset\n",
-" train_dataset = Cifar100(DATA_DIR, train = True, download = download, random_state = random_state, transform=train_transform)\n",
-" test_dataset = Cifar100(DATA_DIR, train = False, download = False, random_state = random_state, transform=test_transform)\n",
-" \n",
-" # Subspace of CIFAR100 of 10 classes\n",
-" train_dataset.set_classes_batch(train_dataset.batch_splits[split_i]) \n",
-" test_dataset.set_classes_batch([test_dataset.batch_splits[i] for i in range(0, split_i+1)])\n",
+" # Define train and validation indices\n",
+" train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, random_state)\n",
 "\n",
-" # Define train and validation indices\n",
-" train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, random_state)\n",
-" \n",
-" train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices), \n",
-" batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n",
-" \n",
-" val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices), \n",
-" batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n",
-" \n",
-" # Dataset with all seen classes\n",
-" test_dataloaders[run_i].append(DataLoader(test_dataset, \n",
-" batch_size=BATCH_SIZE, shuffle=True, num_workers=4)) "
+" train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices),\n",
+" batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n",
+"\n",
+" val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices),\n",
+" batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))\n",
+"\n",
+" # Dataset with all seen classes\n",
+" test_dataloaders[run_i].append(DataLoader(test_dataset,\n",
+" batch_size=BATCH_SIZE, shuffle=True, num_workers=4))"
 ],
 "execution_count": null,
 "outputs": []
@@ -534,417 +491,161 @@
 "dataiter = iter(test_dataloaders[0][5])\n",
 "images, labels = dataiter.next()\n",
 "\n",
-"plot.image_grid(images, one_channel=False)\n",
-"unique_labels = np.unique(labels, return_counts=True)\n",
-"unique_labels"
+"plot.image_grid(images, one_channel=False)"
 ],
 "execution_count": null,
 "outputs": []
 },
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "iYwMtMJuLyYe",
+"colab_type": "text"
+},
+"source": [
+"### Execution"
+]
+},
 {
 "cell_type": "code",
 "metadata": {
-"id": "cw6a_xAumXQW",
+"id": "JpGuC_hSL0jN",
 "colab_type": "code",
 "colab": {}
 },
 "source": [
-"from torch.nn import BCEWithLogitsLoss\n",
-"from copy import deepcopy\n",
-"\n",
-"'''BCE formulation:\n",
-" let x = logits, z = labels. The logistic loss is\n",
-"\n",
-" z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))\n",
-"'''\n",
-"\n",
-" \n",
-"CLASS_BATCH_SIZE = 10\n",
-"\n",
-"\n",
-"class LWF():\n",
-" def __init__(self, device, net, old_net, criterion, optimizer, scheduler,\n",
-" train_dataloader, val_dataloader, test_dataloader, num_classes=10):\n",
-" \n",
-" self.device = device\n",
-"\n",
-" self.net = net\n",
-" self.best_net = self.net\n",
-" self.old_net = old_net # None for first ten classes\n",
-"\n",
-" self.criterion = BCEWithLogitsLoss() # Classifier criterion \n",
-" self.optimizer = optimizer\n",
-" self.scheduler = scheduler\n",
-"\n",
-" self.train_dataloader = train_dataloader\n",
-" self.val_dataloader = val_dataloader\n",
-" self.test_dataloader = test_dataloader\n",
-"\n",
-" self.num_classes = num_classes # can be incremented outside the methods in the main, or inside methods\n",
-" self.order = np.arange(100)\n",
-"\n",
-" self.sigmoid = nn.Sigmoid()\n",
-"\n",
-"\n",
-" def warm_up():\n",
-" pass\n",
-"\n",
-" def increment_classes(self, n=10):\n",
-" \"\"\"Add n classes in the final fully connected layer.\"\"\"\n",
-"\n",
-" in_features = self.net.fc.in_features # size of each input sample\n",
-" out_features = self.net.fc.out_features # size of each output sample\n",
-" weight = self.net.fc.weight.data\n",
-"\n",
-" self.net.fc = nn.Linear(in_features, out_features+n)\n",
-" self.net.fc.weight.data[:out_features] = weight\n",
-"\n",
-" def to_onehot(self, targets): \n",
-" '''\n",
-" Args:\n",
-" targets : dataloader.dataset.targets of the new task images\n",
-" '''\n",
-" one_hot_targets = torch.eye(self.num_classes)[targets]\n",
-"\n",
-" return one_hot_targets.to(self.device)\n",
-"\n",
-" def do_first_batch(self, batch, labels):\n",
-"\n",
-" batch = batch.to(self.device)\n",
-" labels = labels.to(self.device) # new classes labels\n",
-"\n",
-" # Zero-ing the gradients\n",
-" self.optimizer.zero_grad()\n",
-"\n",
-" # One hot encoding of new task labels \n",
-" one_hot_labels = self.to_onehot(labels) # Size = [128, 10]\n",
-"\n",
-" # New net forward pass\n",
-" outputs = self.net(batch) \n",
-" \n",
-" loss = self.criterion(outputs, one_hot_labels) # BCE Loss with sigmoids over outputs\n",
-"\n",
-" # Get predictions\n",
-" _, preds = torch.max(outputs.data, 1)\n",
-"\n",
-" # Accuracy over NEW IMAGES, not over all images\n",
-" running_corrects = \\\n",
-" torch.sum(preds == labels.data).data.item() # It may be that targets should be used here instead of labels\n",
-"\n",
-" # Backward pass: computes gradients\n",
-" loss.backward()\n",
-"\n",
-" self.optimizer.step()\n",
-"\n",
-" return loss, running_corrects\n",
-"\n",
-"\n",
-" def do_batch(self, batch, labels):\n",
-"\n",
-" batch = batch.to(self.device)\n",
-" labels = labels.to(self.device) # new classes labels\n",
-"\n",
-" # Zero-ing the gradients\n",
-" self.optimizer.zero_grad()\n",
-"\n",
-" # One hot encoding of new task labels \n",
-" one_hot_labels = self.to_onehot(labels) # Size = [128, n_classes] will be sliced as [:, :self.num_classes-10]\n",
-" new_classes = (self.order[range(self.num_classes-10, self.num_classes)]).astype(np.int32)\n",
-" one_hot_labels = torch.stack([one_hot_labels[:, i] for i in new_classes], axis=1)\n",
-"\n",
-" # Old net forward pass\n",
-" old_outputs = self.sigmoid(self.old_net(batch)) # Size = [128, 100]\n",
-" old_classes = (self.order[range(self.num_classes-10)]).astype(np.int32)\n",
-" old_outputs = torch.stack([old_outputs[:, i] for i in old_classes], axis=1)\n",
-" \n",
-" # Combine new and old class targets\n",
-" targets = torch.cat((old_outputs, one_hot_labels), 1)\n",
-"\n",
-" # New net forward pass\n",
-" outputs = self.net(batch) # Size = [128, 100] comparable with the defined targets\n",
-" out_classes = (self.order[range(self.num_classes)]).astype(np.int32)\n",
-" outputs = torch.stack([outputs[:, i] for i in out_classes], axis=1)\n",
-" \n",
-" \n",
-" loss = self.criterion(outputs, targets) # BCE Loss with sigmoids over outputs (over targets must be done manually)\n",
-"\n",
-" # Get predictions\n",
-" _, preds = torch.max(outputs.data, 1)\n",
-"\n",
-" # Accuracy over NEW IMAGES, not over all images\n",
-" running_corrects = \\\n",
-" torch.sum(preds == labels.data).data.item() \n",
-"\n",
-" # Backward pass: computes gradients\n",
-" loss.backward()\n",
-"\n",
-" self.optimizer.step()\n",
-"\n",
-" return loss, running_corrects\n",
-"\n",
-"\n",
-" def do_epoch(self, current_epoch):\n",
-"\n",
-" self.net.train()\n",
-"\n",
-" running_train_loss = 0\n",
-" running_corrects = 0\n",
-" total = 0\n",
-" batch_idx = 0\n",
-"\n",
-" print(f\"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}\")\n",
-"\n",
-" for images, labels in self.train_dataloader:\n",
-"\n",
-" if self.num_classes == CLASS_BATCH_SIZE:\n",
-" loss, corrects = self.do_first_batch(images, labels)\n",
-" else:\n",
-" loss, corrects = self.do_batch(images, labels)\n",
-"\n",
-" running_train_loss += loss.item()\n",
-" running_corrects += corrects\n",
-" total += labels.size(0)\n",
-" batch_idx += 1\n",
-"\n",
-" self.scheduler.step()\n",
-"\n",
-" # Calculate average scores\n",
-" train_loss = running_train_loss / batch_idx # Average over all batches\n",
-" train_accuracy = running_corrects / float(total) # Average over all samples\n",
-"\n",
-" print(f\"Train loss: {train_loss}, Train accuracy: {train_accuracy}\")\n",
-"\n",
-" return (train_loss, train_accuracy)\n",
-"\n",
+"# Arguments for Learning without Forgetting\n",
+"BATCH_SIZE = 128\n",
+"LR = 2"
+],
+"execution_count": null,
+"outputs": []
+},
+{
+"cell_type": "code",
+"metadata": {
+"id": "MlThDLCvXJwS",
+"colab_type": "code",
+"colab": {}
+},
+"source": [
+"logs = [[] for _ in range(NUM_RUNS)]\n",
 "\n",
-" def train(self, num_epochs):\n",
-" \"\"\"Train the network for a specified number of epochs, and save\n",
-" the best performing model on the validation set.\n",
-" \n",
-" Args:\n",
-" num_epochs (int): number of epochs for training the network.\n",
-" Returns:\n",
-" train_loss: loss computed on the last epoch\n",
-" train_accuracy: accuracy computed on the last epoch\n",
-" val_loss: average loss on the validation set of the last epoch\n",
-" val_accuracy: accuracy on the validation set of the last epoch\n",
-" \"\"\"\n",
-"\n",
-" # @todo: is the return behaviour intended? (scores of the last epoch)\n",
-"\n",
-" self.net = self.net.to(self.device)\n",
-" if self.old_net != None:\n",
-" self.old_net = self.old_net.to(self.device)\n",
-" self.old_net.train(False)\n",
-"\n",
-" cudnn.benchmark # Calling this optimizes runtime\n",
-"\n",
-" self.best_loss = float(\"inf\")\n",
-" self.best_epoch = 0\n",
-"\n",
-" for epoch in range(num_epochs):\n",
-" # Run an epoch (start counting from 1)\n",
-" train_loss, train_accuracy = self.do_epoch(epoch+1)\n",
+"# Iterate over runs\n",
+"for run_i in range(NUM_RUNS):\n",
+" net = resnet32()\n",
 " \n",
-" # Validate after each epoch \n",
-" val_loss, val_accuracy = self.validate() \n",
-"\n",
-" # Best validation model\n",
-" if val_loss < self.best_loss:\n",
-" self.best_loss = val_loss\n",
-" self.best_net = deepcopy(self.net)\n",
-" self.best_epoch = epoch\n",
-" print(\"Best model updated\")\n",
-"\n",
-" print(\"\")\n",
-"\n",
-" return (train_loss, train_accuracy,\n",
-" val_loss, val_accuracy)\n",
-"\n",
-"\n",
-" def validate(self):\n",
-" \"\"\"Validate the model.\n",
+" criterion = nn.BCEWithLogitsLoss()\n",
 " \n",
-" Returns:\n",
-" val_loss: average loss function computed on the network outputs\n",
-" of the validation set (val_dataloader).\n",
-" val_accuracy: accuracy computed on the validation set.\n",
-" \"\"\"\n",
-"\n",
-" self.net.train(False)\n",
-"\n",
-" running_val_loss = 0\n",
-" running_corrects = 0\n",
-" total = 0\n",
-" batch_idx = 0\n",
-"\n",
-"\n",
-" for batch, labels in self.val_dataloader:\n",
-" batch = batch.to(self.device)\n",
-" labels = labels.to(self.device)\n",
-" total += labels.size(0)\n",
-"\n",
-" # One hot encoding of new task labels \n",
-" one_hot_labels = self.to_onehot(labels) # Size = [128, 100] will be sliced as [:, :self.num_classes-10]\n",
-" new_classes = (self.order[range(self.num_classes-10, self.num_classes)]).astype(np.int32)\n",
-" one_hot_labels = torch.stack([one_hot_labels[:, i] for i in new_classes], axis=1)\n",
-"\n",
-" if self.num_classes > 10:\n",
-" # Old net forward pass\n",
-" old_outputs = self.sigmoid(self.old_net(batch)) # Size = [128, 100]\n",
-" old_classes = (self.order[range(self.num_classes-10)]).astype(np.int32)\n",
-" old_outputs = torch.stack([old_outputs[:, i] for i in old_classes], axis=1)\n",
-"\n",
-" # Combine new and old class targets\n",
-" targets = torch.cat((old_outputs, one_hot_labels), 1)\n",
-"\n",
-" else:\n",
-" targets = one_hot_labels\n",
-"\n",
-" # New net forward pass\n",
-" outputs = self.net(batch) # Size = [128, 100] comparable with the defined targets\n",
-" out_classes = (self.order[range(self.num_classes)]).astype(np.int32)\n",
-" outputs = torch.stack([outputs[:, i] for i in out_classes], axis=1)\n",
-"\n",
-" \n",
-" loss = self.criterion(outputs, targets) # BCE Loss with sigmoids over outputs (over targets must be done manually)\n",
-"\n",
-" # Get predictions\n",
-" _, preds = torch.max(outputs.data, 1)\n",
-"\n",
-" # Update the number of correctly classified validation samples\n",
-" running_corrects += torch.sum(preds == labels.data).data.item()\n",
-" running_val_loss += loss.item()\n",
-"\n",
-" batch_idx += 1\n",
-"\n",
-" # Calculate scores\n",
-" val_loss = running_val_loss / batch_idx\n",
-" val_accuracy = running_corrects / float(total)\n",
-"\n",
-" print(f\"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}\")\n",
-"\n",
-" return (val_loss, val_accuracy)\n",
-"\n",
-"\n",
-" def test(self):\n",
-" \"\"\"Test the model.\n",
-" Returns:\n",
-" accuracy (float): accuracy of the model on the test set\n",
-" \"\"\"\n",
-"\n",
-" self.best_net.train(False) # Set Network to evaluation mode\n",
-"\n",
-" running_corrects = 0\n",
-" total = 0\n",
+" for split_i in range(10):\n",
+" print(f\"## Split {split_i} of run {run_i} ##\")\n",
 "\n",
-" all_preds = torch.tensor([]) # to store all predictions\n",
-" all_preds = all_preds.type(torch.LongTensor)\n",
-" \n",
-" for images, labels in self.test_dataloader:\n",
-" images = images.to(self.device)\n",
-" labels = labels.to(self.device)\n",
-" total += labels.size(0)\n",
+" # Redefine optimizer at each split (pass by reference issue)\n",
+" parameters_to_optimize = net.parameters()\n",
+" optimizer = optim.SGD(parameters_to_optimize, lr=LR,\n",
+" momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)\n",
+" scheduler = optim.lr_scheduler.MultiStepLR(optimizer, \n",
+" milestones=MILESTONES, gamma=GAMMA)\n",
+"\n",
+" num_classes = 10*(split_i+1)\n",
+"\n",
+" if num_classes == 10: # old network == None\n",
+" lwf = LWF(DEVICE, net, None, criterion, optimizer, scheduler,\n",
+" train_dataloaders[run_i][split_i],\n",
+" val_dataloaders[run_i][split_i],\n",
+" test_dataloaders[run_i][split_i],\n",
+" num_classes)\n",
+" else:\n",
+" lwf = LWF(DEVICE, net, old_net, criterion, optimizer, scheduler,\n",
+" train_dataloaders[run_i][split_i],\n",
+" val_dataloaders[run_i][split_i],\n",
+" test_dataloaders[run_i][split_i],\n",
+" num_classes)\n",
 "\n",
-" # Forward Pass\n",
-" outputs = self.best_net(images)\n",
+" scores = lwf.train(NUM_EPOCHS) # train the model\n",
 "\n",
-" # Get predictions\n",
-" _, preds = torch.max(outputs.data, 1)\n",
+" logs[run_i].append({})\n",
 "\n",
-" # Update Corrects\n",
-" running_corrects += torch.sum(preds == labels.data).data.item()\n",
+" # score[i] = dictionary with key:epoch, value: score\n",
+" logs[run_i][split_i]['train_loss'] = scores[0]\n",
+" logs[run_i][split_i]['train_accuracy'] = scores[1]\n",
+" logs[run_i][split_i]['val_loss'] = scores[2]\n",
+" logs[run_i][split_i]['val_accuracy'] = scores[3]\n",
 "\n",
-" # Append batch predictions\n",
-" all_preds = torch.cat(\n",
-" (all_preds.to(self.device), preds.to(self.device)), dim=0\n",
-" )\n",
+" # Test the model on classes seen until now\n",
+" test_accuracy, all_targets, all_preds = lwf.test()\n",
 "\n",
-" # Calculate accuracy\n",
-" accuracy = running_corrects / float(total) \n",
+" logs[run_i][split_i]['test_accuracy'] = test_accuracy\n",
+" logs[run_i][split_i]['conf_mat'] = confusion_matrix(all_targets.to('cpu'), all_preds.to('cpu'))\n",
 "\n",
-" print(f\"Test accuracy: {accuracy}\")\n",
+" old_net = deepcopy(lwf.net)\n",
 "\n",
-" return (accuracy, all_preds)"
+" lwf.increment_classes()"
 ],
 "execution_count": null,
 "outputs": []
 },
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "2xZbK6EGSaZN",
+"colab_type": "text"
+},
+"source": [
+"### Plots"
+]
+},
 {
 "cell_type": "code",
 "metadata": {
-"id": "MlThDLCvXJwS",
+"id": "bUfgSq1xSbrD",
 "colab_type": "code",
 "colab": {}
 },
 "source": [
-"train_loss_history = []\n",
-"train_accuracy_history = []\n",
-"val_loss_history = []\n",
-"val_accuracy_history = []\n",
-"test_accuracy_history = []\n",
-"\n",
-"\n",
-"\n",
-"# Iterate over runs\n",
-"for train_dataloader, val_dataloader, test_dataloader in zip(train_dataloaders,\n",
-" val_dataloaders, test_dataloaders):\n",
-" \n",
-" \n",
-" train_loss_history.append({})\n",
-" train_accuracy_history.append({})\n",
-" val_loss_history.append({})\n",
-" val_accuracy_history.append({})\n",
-" test_accuracy_history.append({})\n",
-"\n",
-" net = resnet32() # Define the net\n",
-" \n",
-" criterion = nn.BCEWithLogitsLoss() # Define the loss\n",
-" \n",
-" \n",
-" i = 0\n",
-" for train_split, val_split, test_split in zip(train_dataloader,\n",
-" val_dataloader, test_dataloader):\n",
-" \n",
-" # Redefine optimizer at each split (pass by reference issue)\n",
-" parameters_to_optimize = net.parameters()\n",
-" optimizer = optim.SGD(parameters_to_optimize, lr=LR,\n",
-" momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)\n",
-" scheduler = optim.lr_scheduler.MultiStepLR(optimizer, \n",
-" milestones=MILESTONES, gamma=GAMMA)\n",
-" \n",
-" current_split = \"Split %i\"%(i)\n",
-" print(current_split)\n",
-"\n",
-" num_classes = CLASS_BATCH_SIZE*(i+1)\n",
-"\n",
-" if num_classes == CLASS_BATCH_SIZE:\n",
-" # Old Network = None\n",
-" lwf = LWF(DEVICE, net, None, criterion, optimizer, scheduler,\n",
-" train_split, val_split, test_split, num_classes)\n",
-" else:\n",
-" lwf = LWF(DEVICE, net, old_net, criterion, optimizer, scheduler,\n",
-" train_split, val_split, test_split, num_classes)\n",
-" \n",
-"\n",
-" scores = lwf.train(NUM_EPOCHS) # train the model\n",
-"\n",
-" # score[i] = dictionary with key:epoch, value: score\n",
-" train_loss_history[-1][current_split] = scores[0]\n",
-" train_accuracy_history[-1][current_split] = scores[1]\n",
-" val_loss_history[-1][current_split] = scores[2]\n",
-" val_accuracy_history[-1][current_split] = scores[3]\n",
-"\n",
-" # Test the model on classes seen until now\n",
-" test_accuracy, all_preds = lwf.test()\n",
-"\n",
-" test_accuracy_history[-1][current_split] = test_accuracy\n",
+"train_loss = [[logs[run_i][i]['train_loss'] for i in range(10)] for run_i in range(NUM_RUNS)]\n",
+"train_accuracy = [[logs[run_i][i]['train_accuracy'] for i in range(10)] for run_i in range(NUM_RUNS)]\n",
+"val_loss = [[logs[run_i][i]['val_loss'] for i in range(10)] for run_i in range(NUM_RUNS)]\n",
+"val_accuracy = [[logs[run_i][i]['val_accuracy'] for i in range(10)] for run_i in range(NUM_RUNS)]\n",
+"test_accuracy = [[logs[run_i][i]['test_accuracy'] for i in range(10)] for run_i in range(NUM_RUNS)]\n",
 "\n",
-" # Uncomment if default resnet has 10 nodes at last FC layer\n",
-" old_net = deepcopy(lwf.net)\n",
-" lwf.increment_classes()\n",
+"train_loss = np.array(train_loss)\n",
+"train_accuracy = np.array(train_accuracy)\n",
+"val_loss = np.array(val_loss)\n",
+"val_accuracy = np.array(val_accuracy)\n",
+"test_accuracy = np.array(test_accuracy)\n",
 "\n",
-" i =i+1"
+"train_loss_stats = np.array([train_loss.mean(0), train_loss.std(0)]).transpose()\n",
+"train_accuracy_stats = np.array([train_accuracy.mean(0), train_accuracy.std(0)]).transpose()\n",
+"val_loss_stats = np.array([val_loss.mean(0), val_loss.std(0)]).transpose()\n",
+"val_accuracy_stats = np.array([val_accuracy.mean(0), val_accuracy.std(0)]).transpose()\n",
+"test_accuracy_stats = np.array([test_accuracy.mean(0), test_accuracy.std(0)]).transpose()"
+],
+"execution_count": null,
+"outputs": []
+},
+{
+"cell_type": "code",
+"metadata": {
+"colab_type": "code",
+"id": "1w3_YPJCSeli",
+"colab": {}
+},
+"source": [
+"plot.train_val_scores(train_loss_stats, train_accuracy_stats, val_loss_stats, val_accuracy_stats)"
+],
+"execution_count": null,
+"outputs": []
+},
+{
+"cell_type": "code",
+"metadata": {
+"colab_type": "code",
+"id": "ZSt6-FJbSelp",
+"colab": {}
+},
+"source": [
+"plot.test_scores(test_accuracy_stats)"
 ],
 "execution_count": null,
 "outputs": []
@@ -1108,32 +809,6 @@
 "execution_count": null,
 "outputs": []
 },
-{
-"cell_type": "code",
-"metadata": {
-"id": "by8c4Aaa-8ms",
-"colab_type": "code",
-"colab": {}
-},
-"source": [
-"print(logs_icarl)"
-],
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"metadata": {
-"id": "qOCjSFJy_ANm",
-"colab_type": "code",
-"colab": {}
-},
-"source": [
-"obj_save(logs_icarl, 'hybrid1_confmat')"
-],
-"execution_count": null,
-"outputs": []
-},
 {
 "cell_type": "markdown",
 "metadata": {
@@ -1200,4 +875,4 @@
 "outputs": []
 }
 ]
-}
+}
