-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 0e9b715
Showing
11 changed files
with
467 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
/runs | ||
/runs/* | ||
*.pyc | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
.DS_store | ||
venv | ||
*.so |
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .vgg import VGG |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import torch | ||
import torch.nn as nn | ||
from functools import partial | ||
|
||
class ConvBnAct(nn.Module): | ||
def __init__( | ||
self, in_feat:int, out_feat:int, kernel_size:int, stride:int, padding:int, | ||
bn:bool=True, | ||
act:bool=True | ||
) -> None: | ||
super().__init__() | ||
self.conv = nn.Conv2d( | ||
in_feat, out_feat, kernel_size=kernel_size, stride=stride, padding=padding, bias=not bn | ||
) | ||
if bn: | ||
self.bn = nn.BatchNorm2d(out_feat) | ||
else: | ||
self.bn = None | ||
if act: | ||
self.act = nn.ReLU() | ||
else: | ||
self.act = None | ||
def forward(self, x:torch.Tensor) -> torch.Tensor: | ||
x = self.conv(x) | ||
if self.bn is not None: | ||
x = self.bn(x) | ||
if self.act is not None: | ||
x = self.act(x) | ||
return x | ||
|
||
class downsample(nn.Module): | ||
def __init__(self, feat): | ||
super().__init__() | ||
self.body = nn.Sequential( | ||
nn.LazyConv2d(feat, kernel_size=1, stride=2, padding=0), | ||
nn.BatchNorm2d(feat) | ||
) | ||
def forward(self, x:torch.Tensor) -> torch.Tensor: | ||
return self.body(x) | ||
|
||
class basicResBlock(nn.Module): | ||
def __init__(self, in_feat, out_feat, half:bool): | ||
super().__init__() | ||
self.conv_1 = ConvBnAct(in_feat, out_feat, 3, 1, 1) | ||
if not half: | ||
self.conv_2 = ConvBnAct(out_feat, out_feat, 3, 1, 1, bn=True, act=False) | ||
self.downsample = None | ||
else: | ||
# if half | ||
self.conv_2 = ConvBnAct(out_feat, out_feat, 3, 2, 1, bn=True, act=False) | ||
self.downsample = downsample(out_feat) | ||
|
||
self.final_act = nn.ReLU() | ||
|
||
def forward(self, x:torch.Tensor) -> torch.Tensor: | ||
residual = x | ||
x = self.conv_1(x) | ||
x = self.conv_2(x) | ||
if self.downsample is not None: | ||
residual = self.downsample(residual) | ||
x += residual | ||
return self.final_act(x) | ||
|
||
# if resnet depth >= 50 | ||
class bottleneck(nn.Module): | ||
pass | ||
|
||
|
||
class RESNET18(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.stem = ConvBnAct(3, 64, 7, 2, 3) | ||
self.stage1 = nn.Sequential( | ||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1), | ||
basicResBlock(64, 64, False), | ||
basicResBlock(64, 64, False) | ||
) | ||
self.stage2 = nn.Sequential( | ||
basicResBlock(64, 128, True), | ||
basicResBlock(128, 128, False) | ||
) | ||
self.stage3 = nn.Sequential( | ||
basicResBlock(128, 256, True), | ||
basicResBlock(256, 256, False), | ||
) | ||
self.stage4 = nn.Sequential( | ||
basicResBlock(256, 512, True), | ||
basicResBlock(512, 512, False), | ||
) | ||
self.head = nn.Sequential( | ||
nn.AdaptiveAvgPool2d((1, 1)), | ||
nn.Flatten(), | ||
nn.Linear(512, 1000) | ||
) | ||
|
||
def forward(self, x:torch.Tensor) -> torch.Tensor: | ||
x = self.stem(x) | ||
x = self.stage1(x) | ||
x = self.stage2(x) | ||
x = self.stage3(x) | ||
x = self.stage4(x) | ||
x = self.head(x) | ||
return x | ||
|
||
|
||
if __name__ == "__main__": | ||
# random_input = torch.randn(1, 3, 224, 224) | ||
# vanilla_conv = ConvBnAct(3, 6, 3, 1, 1, False, True) | ||
# random_output = vanilla_conv(random_input) | ||
# print(random_output.shape) | ||
|
||
# random_input = torch.randn(1, 64, 112, 112) | ||
# first_resblock = basicResBlock(64, False) | ||
# random_output = first_resblock(random_input) | ||
# print(random_output.shape) | ||
|
||
# second_resblock = basicResBlock(64, True) | ||
# random_output = second_resblock(random_output) | ||
# print(random_output.shape) | ||
random_input = torch.randn(1, 3, 224, 224) | ||
|
||
# test resnet18 | ||
resnet_18 = RESNET18() | ||
resnet_18_output = resnet_18(random_input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
# VGG 블록 정의 | ||
class VGGBLock(nn.Module): | ||
def __init__(self, num_convs, out_channels): | ||
super().__init__() | ||
layers = [] | ||
for _ in range(num_convs): | ||
layers.append( | ||
nn.LazyConv2d(out_channels, kernel_size=3, stride=1, padding=1) | ||
) | ||
layers.append( | ||
nn.LazyBatchNorm2d() | ||
) | ||
layers.append( | ||
nn.ReLU() | ||
) | ||
layers.append( | ||
# kernel=2, stride=2 이므로 H, W 값이 각각 0.5*H, 0.5*W로 축소 | ||
nn.MaxPool2d(kernel_size=2, stride=2) | ||
) | ||
self.layers = nn.Sequential(*layers) | ||
|
||
def forward(self, x): | ||
return self.layers(x) | ||
|
||
# VGG 네트워크 정의 | ||
class VGG(nn.Module): | ||
# 논문에 제시된 기본 설계 | ||
default_config = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)) | ||
|
||
def __init__(self, cfg=None, num_classes=1000): | ||
super().__init__() | ||
# config파일이 따로 주어지지 않는다면 기본값을 사용 | ||
cfg = self.default_config if cfg is None else cfg | ||
conv_blks = [] | ||
|
||
# config의 내용에 따라 네트워크를 구성 | ||
for (num_convs, out_channels) in cfg: | ||
conv_blks.append(VGGBLock(num_convs, out_channels)) | ||
|
||
# Iterable unpacking | ||
self.backbone = nn.Sequential( | ||
*conv_blks | ||
) | ||
|
||
# 분류를 위한 헤드 부분 | ||
self.head = nn.Sequential( | ||
nn.Flatten(), | ||
nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.1), | ||
nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.1), | ||
nn.LazyLinear(num_classes) | ||
) | ||
|
||
def forward(self, x): | ||
feature = self.backbone(x) | ||
preds = self.head(feature) | ||
return preds | ||
|
||
if __name__ == "__main__": | ||
net = VGG() | ||
random_input = torch.randn(1, 3, 224, 224) | ||
b = net(random_input) | ||
print(b.shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import argparse | ||
import shutil | ||
import torch | ||
import torch.nn as nn | ||
from torch.optim import SGD, Adam | ||
from tqdm import tqdm | ||
from models import VGG | ||
from utils import get_dataloader, get_current_datetime, AverageMeter, save_dict_as_json | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model", type=str, default="vgg") | ||
parser.add_argument("--imgsz", type=int, default=224) | ||
parser.add_argument("--batch-size", type=int, default=32) | ||
parser.add_argument("--num-workers", type=int, default=8) | ||
parser.add_argument("--optimizer", type=str, default='sgd') | ||
parser.add_argument("--lr", type=float, default=1e-3) | ||
parser.add_argument("--epochs", type=int, default=300) | ||
parser.add_argument("--save_interval", type=int, default=50) | ||
|
||
def prepare(opt): | ||
SAVE_PATH = os.path.join("./runs", opt.model.upper()+"_"+get_current_datetime()) | ||
if os.path.exists(SAVE_PATH): | ||
shutil.rmtree(SAVE_PATH) | ||
os.makedirs(SAVE_PATH) | ||
else: | ||
# os.path.exists(SAVE_PATH) == False | ||
os.makedirs(SAVE_PATH) | ||
TRAIN_CONFIG = os.path.join(SAVE_PATH, "config.json") | ||
save_dict_as_json(TRAIN_CONFIG, vars(opt)) | ||
|
||
PTH_DIR = os.path.join(SAVE_PATH, "weights") | ||
os.makedirs(PTH_DIR) | ||
|
||
return SAVE_PATH, PTH_DIR | ||
|
||
# optimizer 관련 코드 | ||
def get_optimizer(name:str): | ||
if name.lower() == 'sgd': | ||
return SGD | ||
elif name.lower() == 'adam': | ||
return Adam | ||
|
||
if __name__ == "__main__": | ||
if not os.path.exists("./runs"): | ||
os.makedirs("./runs") | ||
|
||
opt = parser.parse_args() | ||
# 모델 가중치 저장 디렉토리 | ||
SAVE_PATH, PTH_DIR = prepare(opt) | ||
|
||
# 신경망 | ||
net = eval(opt.model.upper())().cuda() | ||
|
||
# 손실함수 | ||
criterion = nn.CrossEntropyLoss() | ||
|
||
# 최적화 | ||
optimizer_type = get_optimizer(opt.optimizer) | ||
optimizer = optimizer_type(net.parameters(), lr=opt.lr) | ||
|
||
# 데이터 파이프라인 | ||
train_loader, val_loader = get_dataloader(opt.batch_size, opt.imgsz, opt.num_workers) | ||
|
||
for e in range(1, opt.epochs+1): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .datasets import get_dataloader | ||
from .misc import get_current_datetime, AverageMeter, save_dict_as_json |
Oops, something went wrong.