Skip to content

Commit

Permalink
Added torchhub
Browse files Browse the repository at this point in the history
  • Loading branch information
VCasecnikovs committed May 31, 2020
1 parent 5869720 commit 5cd9fa5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
21 changes: 21 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from model import YOLOv4

dependencies = ['torch']

def yolov4(pretrained=False, n_classes=80):
"""
YOLOv4 model
pretrained (bool): kwargs, load pretrained weights into the model
n_classes(int): amount of classes
"""
m = YOLOv4(n_classes=n_classes)
if pretrained:
try: #If we change input or output layers amount, we will have an option to use pretrained weights
m.load_state_dict(torch.hub.load_state_dict_from_url("https://github.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/releases/download/V1.0/yolov4.pth"), strict=False)
except RuntimeError as e:
print(f'[Warning] Ignoring {e}')

return m


10 changes: 8 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def forward(self, x, targets=None):


class YOLOv4(nn.Module):
def __init__(self, in_channels = 3, n_classes = 80, weights_path=None, img_dim=608, anchors=None):
def __init__(self, in_channels = 3, n_classes = 80, weights_path=None, pretrained=False, img_dim=608, anchors=None):
super().__init__()
if anchors is None:
anchors = [[[10, 13], [16, 30], [33, 23]],
Expand All @@ -478,9 +478,15 @@ def __init__(self, in_channels = 3, n_classes = 80, weights_path=None, img_dim=6

if weights_path:
try: #If we change input or output layers amount, we will have an option to use pretrained weights
ret = self.load_state_dict(torch.load(weights_path), strict=False)
self.load_state_dict(torch.load(weights_path), strict=False)
except RuntimeError as e:
print(f'[Warning] Ignoring {e}')
elif pretrained:
try: #If we change input or output layers amount, we will have an option to use pretrained weights
self.load_state_dict(torch.hub.load_state_dict_from_url("https://github.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/releases/download/V1.0/yolov4.pth"), strict=False)
except RuntimeError as e:
print(f'[Warning] Ignoring {e}')



def forward(self, x, y=None):
Expand Down

0 comments on commit 5cd9fa5

Please sign in to comment.