Skip to content

Commit

Permalink
introduce resnet6 (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
tremblerz authored Oct 28, 2024
1 parent 28b8291 commit 3649350
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
40 changes: 24 additions & 16 deletions src/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import nn
import torch.nn.functional as F
import torch
from typing import List, Type, Optional, Tuple
from typing import List, Optional, Tuple


class BasicBlock(nn.Module):
Expand All @@ -23,7 +23,7 @@ class BasicBlock(nn.Module):

def __init__(self, in_planes: int, planes: int, stride: int = 1) -> None:

super(BasicBlock, self).__init__()
super(BasicBlock, self).__init__() # type: ignore
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
Expand Down Expand Up @@ -65,7 +65,7 @@ class Bottleneck(nn.Module):
expansion = 4

def __init__(self, in_planes: int, planes: int, stride: int = 1):
super().__init__()
super().__init__() # type: ignore
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
Expand Down Expand Up @@ -109,28 +109,29 @@ class ResNet(nn.Module):

def __init__(
self,
block: Type[nn.Module],
block: type[BasicBlock | Bottleneck],
num_blocks: List[int],
num_classes: int = 10,
num_channels: int = 3,
in_planes: int = 64,
) -> None:
super().__init__()
self.in_planes = 64
super().__init__() # type: ignore
self.in_planes = in_planes
self.conv1 = nn.Conv2d(
num_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
num_channels, in_planes, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
self.bn1 = nn.BatchNorm2d(in_planes)
self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2)
self.linear = nn.Linear(in_planes * 8 * block.expansion, num_classes)

def _make_layer(
self, block: Type[nn.Module], planes: int, num_blocks: int, stride: int
self, block: type[BasicBlock | Bottleneck], planes: int, num_blocks: int, stride: int
) -> nn.Sequential:
strides = [stride] + [1] * (num_blocks - 1)
layers = []
layers: List[BasicBlock|Bottleneck] = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
Expand Down Expand Up @@ -168,7 +169,14 @@ def forward(

if out_feature:
return x, feature
return x
return x # type: ignore


def resnet6(num_channels: int = 3, num_classes: int = 10) -> ResNet:
"""
Constructs a ResNet-6 model.
"""
return ResNet(BasicBlock, [1, 1, 1, 0], num_classes, num_channels, 16)


def resnet10(num_channels: int = 3, num_classes: int = 10) -> ResNet:
Expand Down
6 changes: 5 additions & 1 deletion src/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def get_model(
self.dset = dset
# TODO: add support for loading checkpointed models
model_name = model_name.lower()
if model_name == "resnet10":
if model_name == "resnet6":
if pretrained:
raise ValueError("Pretrained model not available for resnet6")
model = resnet.resnet6(**kwargs)
elif model_name == "resnet10":
if pretrained:
raise ValueError("Pretrained model not available for resnet10")
model = resnet.resnet10(**kwargs)
Expand Down

0 comments on commit 3649350

Please sign in to comment.