
Commit

video and readme with imgs
VCasecnikovs committed May 30, 2020
1 parent 8a7cd85 commit cfc8781
Showing 5 changed files with 49 additions and 11 deletions.
46 changes: 37 additions & 9 deletions README.md
@@ -1,4 +1,5 @@
# Yet-Another-YOLOv4-Pytorch
![](github_imgs/from_net.png)

This is an implementation of the YOLOv4 object detection neural network in PyTorch. I'll try to implement all features of the original paper.

@@ -14,12 +15,26 @@ This is implementation of YOLOv4 object detection neural network on pytorch. I'l
- [ ] Self attention attack
- [ ] Notebook with guide

## What you can already do
You can use video_demo.py to see real-time object detection with the original weights (about 9 FPS on my GTX 1060 laptop!).
![](/github_imgs/realtime.jpg)

You can train your own model with mosaic augmentation. A guide on how to do this is written below. On some datasets, the borders between the stitched images are even hard to find.
![](/github_imgs/mosaic.png)
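For intuition, here is a minimal sketch of the mosaic idea (four images stitched into one canvas around a centre point). This is a simplified illustration with a fixed centre and equal-sized inputs; the repo's own implementation in dataset.py also samples the centre randomly and rescales the bounding boxes.

```python
import numpy as np

def mosaic(imgs, out_size=416):
    """Stitch four equally sized square images into one mosaic canvas.

    imgs: list of four (out_size, out_size, 3) uint8 arrays.
    Real mosaic augmentation picks a random centre and adjusts the
    bboxes accordingly; here the centre is fixed at the middle.
    """
    assert len(imgs) == 4
    half = out_size // 2
    canvas = np.zeros((out_size, out_size, 3), dtype=np.uint8)
    # Take the matching quadrant from each source image:
    canvas[:half, :half] = imgs[0][:half, :half]  # top-left
    canvas[:half, half:] = imgs[1][:half, half:]  # top-right
    canvas[half:, :half] = imgs[2][half:, :half]  # bottom-left
    canvas[half:, half:] = imgs[3][half:, half:]  # bottom-right
    return canvas
```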


You can run inference; see the guide below.


## Initialize NN

import model
# If n_classes differs from the pretrained model, one error will be caught while loading; don't panic, it is OK
m = model.YOLOv4(n_classes=1, weights_path="weights/yolov4.pth")

## Download weights
You can download the weights from this link: https://drive.google.com/open?id=12AaR4fvIQPZ468vhm0ZYZSLgWac2HBnq

## Initialize dataset

import dataset
@@ -43,22 +58,35 @@ dataset has collate_function
paths_b, xb, yb = d.collate_fn((y1, y2))
# yb has 6 columns
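The collate step is what prepends the image-index column to each target row. A minimal sketch of how such a collate_fn could work, as an illustration only (the repo's dataset.py is the authoritative version):

```python
import torch

def collate_fn(batch):
    """Stack a list of (path, img, bboxes) samples into a batch.

    Each sample's bboxes tensor has rows [class, x, y, w, h]; the
    collate step prepends the in-batch image index, giving the
    6-column yb described below.
    """
    paths, imgs, targets = zip(*batch)
    xb = torch.stack(imgs, dim=0)
    rows = []
    for i, t in enumerate(targets):
        idx = torch.full((t.shape[0], 1), float(i))  # image index column
        rows.append(torch.cat([idx, t], dim=1))
    yb = torch.cat(rows, dim=0)
    return list(paths), xb, yb
```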

### Bboxes format
## Y's format
1. Num of img to which this anchor belongs
2. BBox class
3. x center
4. y center
5. width
6. height
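Concretely, a yb for three ground-truth boxes spread over a batch of two images could look like this (the values are illustrative, and the coordinates are assumed to be normalised):

```python
import torch

# One row per ground-truth box, six columns:
# [img_index_in_batch, class_id, x_center, y_center, width, height]
yb = torch.tensor([
    [0, 1, 0.50, 0.50, 0.20, 0.30],  # box of class 1 on image 0
    [0, 3, 0.25, 0.40, 0.10, 0.10],  # second box on image 0
    [1, 1, 0.70, 0.60, 0.30, 0.25],  # box on image 1
])

# The first column lets you recover the boxes of any one image:
boxes_img0 = yb[yb[:, 0] == 0]
```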

### Forward with loss
(y_hat1, y_hat2, y_hat3), (loss_1, loss_2, loss_3) = m(xb, yb)

### Forward without loss
(y_hat1, y_hat2, y_hat3), _ = m(img_batch) #_ is (0, 0, 0)
## Forward with loss
y_hat, loss = m(xb, yb)

Note: y_hat already contains anchors rescaled to image-size bounding boxes.

### Check if bboxes are correct
## Forward without loss
y_hat, _ = m(img_batch)  # _ is (0, 0, 0)

## Check if bboxes are correct
import utils
from PIL import Image
path, img, bboxes = d[0]
img_with_bboxes = utils.get_img_with_bboxes(img, bboxes) #PIL image

img_with_bboxes = utils.get_img_with_bboxes(img, bboxes[:, 2:]) #Returns numpy array
Image.fromarray(img_with_bboxes)

## Get predicted bboxes
anchors, loss = m(xb.cuda(), yb.cuda())
confidence_threshold = 0.05
iou_threshold = 0.5
bboxes, labels = utils.get_bboxes_from_anchors(anchors, confidence_threshold, iou_threshold, coco_dict)  # coco_dict maps id -> class name (e.g. 0 -> person)
# For the first image
arr = utils.get_img_with_bboxes(xb[0].cpu(), bboxes[0].cpu(), resize=False, labels=labels[0])
Image.fromarray(arr)
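utils.get_bboxes_from_anchors filters anchors by the confidence threshold and then suppresses overlapping detections with non-maximum suppression at the IoU threshold. A generic sketch of that filtering step, not the repo's exact code:

```python
import torch

def iou(box, boxes):
    """IoU of one [x1, y1, x2, y2] box against N others."""
    x1 = torch.maximum(box[0], boxes[:, 0])
    y1 = torch.maximum(box[1], boxes[:, 1])
    x2 = torch.minimum(box[2], boxes[:, 2])
    y2 = torch.minimum(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: keep the highest-scoring box, drop overlaps, repeat."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        # Keep only boxes that overlap the chosen one below the threshold
        order = rest[iou(boxes[i], boxes[rest]) <= iou_threshold]
    return keep
```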

Binary file added github_imgs/from_net.png
Binary file added github_imgs/mosaic.png
Binary file added github_imgs/realtime.jpg
14 changes: 12 additions & 2 deletions video_demo.py
@@ -3,6 +3,7 @@
from torch.backends import cudnn
import torch
import utils
import time

coco_dict = {0: 'person',
1: 'bicycle',
@@ -97,11 +98,15 @@

m = m.cuda()

# Warm up the JIT with a dummy input
m(torch.zeros((1, 3, 608, 608)).cuda())

cap = cv2.VideoCapture(0)

frames_n = 0
start_time = time.time()

while True:
print("Frame got")
ret, frame = cap.read()
if not ret:
break
@@ -123,14 +128,19 @@

bboxes, labels = utils.get_bboxes_from_anchors(anchors, 0.4, 0.5, coco_dict)
arr = utils.get_img_with_bboxes(x.cpu(), bboxes[0].cpu(), resize=False, labels=labels[0])

arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

frames_n += 1

arr = cv2.putText(arr, "FPS: " + str(frames_n / (time.time() - start_time)), (100, 100), cv2.FONT_HERSHEY_DUPLEX, 0.75, (255, 255, 255))

cv2.imshow("test", arr)
if cv2.waitKey(1) & 0xFF == ord('q'):
break






