Skip to content

Commit

Permalink
Added url input feature
Browse files Browse the repository at this point in the history
  • Loading branch information
tamnvhust1 committed Apr 22, 2021
1 parent 281d58d commit 9f5c0cc
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 27 deletions.
77 changes: 60 additions & 17 deletions api/app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import sys
import os
import wget
import mimetypes
sys.path.insert(0, os.path.realpath(os.path.pardir))
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from celery_tasks.tasks import predict_image
from celery_tasks.tasks import predict_image, predict_video
from celery.result import AsyncResult
from models import Task, Prediction
import uuid
import logging
from pydantic.typing import List
from pydantic.typing import List, Optional
import numpy as np

UPLOAD_FOLDER = 'uploads'
Expand Down Expand Up @@ -41,30 +43,71 @@


@app.post('/api/process')
async def process(files: List[UploadFile] = File(...)):
async def process(files: Optional[List[UploadFile]] = File(None), url: Optional[str] = Form(None)):
tasks = []
try:
for file in files:
if files is None and url is None:
raise Exception('No input found')

if files is not None:
for file in files:
d = {}
try:
name = str(uuid.uuid4()).split('-')[0]
ext = file.filename.split('.')[-1]
file_name = f'{UPLOAD_FOLDER}/{name}.{ext}'
# start task prediction
# Check file type (image or video)
# TODO: other cases
mimestart = mimetypes.guess_type(file_name)[0]
if mimestart is not None:
mimestart = mimestart.split('/')[0]
if mimestart in ('video', 'image'):
with open(file_name, 'wb+') as f:
f.write(file.file.read())
f.close()

if mimestart == 'image':
task_id = predict_image.delay(os.path.join('api', file_name))
d['task_id'] = str(task_id)
d['status'] = 'PROCESSING'
d['url_result'] = f'/api/result/{task_id}'
elif mimestart == 'video':
task_id = predict_video.delay(os.path.join('api', file_name))
d['task_id'] = str(task_id)
d['status'] = 'PROCESSING'
d['url_result'] = f'/api/result/{task_id}'
except Exception as ex:
logging.info(ex)
d['task_id'] = str(task_id)
d['status'] = 'ERROR'
d['url_result'] = ''
tasks.append(d)
elif url is not None:
d = {}
try:
name = str(uuid.uuid4()).split('-')[0]
ext = file.filename.split('.')[-1]
file_name = f'{UPLOAD_FOLDER}/{name}.{ext}'
with open(file_name, 'wb+') as f:
f.write(file.file.read())
f.close()

# start task prediction
task_id = predict_image.delay(os.path.join('api', file_name))
d['task_id'] = str(task_id)
d['status'] = 'PROCESSING'
d['url_result'] = f'/api/result/{task_id}'
path = wget.download(url, out=f'{UPLOAD_FOLDER}')
print('path', path)
mimestart = mimetypes.guess_type(path)[0]
if mimestart is not None:
mimestart = mimestart.split('/')[0]
if mimestart == 'image':
task_id = predict_image.delay(os.path.join('api', path))
d['task_id'] = str(task_id)
d['status'] = 'PROCESSING'
d['url_result'] = f'/api/result/{task_id}'
elif mimestart == 'video':
task_id = predict_video.delay(os.path.join('api', path))
d['task_id'] = str(task_id)
d['status'] = 'PROCESSING'
d['url_result'] = f'/api/result/{task_id}'
except Exception as ex:
logging.info(ex)
d['task_id'] = str(task_id)
d['status'] = 'ERROR'
d['url_result'] = ''
tasks.append(d)

return JSONResponse(status_code=202, content=tasks)
except Exception as ex:
logging.info(ex)
Expand Down
12 changes: 12 additions & 0 deletions celery_tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ def predict_image(self, data):
self.retry(countdown=2)
except MaxRetriesExceededError as ex:
return {'status': 'FAIL', 'result': 'max retried achieved'}


@app.task(ignore_result=False, bind=True, base=PredictTask)
def predict_video(self, data):
try:
data_pred = self.model.predict_video(data)
return {'status': 'SUCCESS', 'result': data_pred}
except Exception as ex:
try:
self.retry(countdown=2)
except MaxRetriesExceededError as ex:
return {'status': 'FAIL', 'result': 'max retried achieved'}
57 changes: 56 additions & 1 deletion celery_tasks/yolo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import logging
from typing import final
import cv2
import torch
import matplotlib.pyplot as plt

Expand All @@ -7,6 +10,7 @@ class YoloModel:
def __init__(self):
self.model = torch.hub.load('ultralytics/yolov5', 'yolov5x', pretrained=True)
self.model.eval()
self.output_video_fps = 15

def predict(self, img):
try:
Expand All @@ -28,7 +32,58 @@ def predict(self, img):
preds['class'] = result.names[int(cls)]
data.append(preds)

return {'file_name': file_name, 'bbox': data}
return {'file_name': file_name, 'bbox': data, 'mimetype': 'image'}
except Exception as ex:
logging.error(str(ex))
return None

def predict_video(self, video_url):
try:
cap = None
writer = None

cap = cv2.VideoCapture(video_url)
frame_cnt = 0
final_result = {}
out_filename = video_url.split('/')[-1]
avi_filename = os.path.splitext(out_filename)[0] + '.mp4'
out_filepath = os.path.join('api/static/results', avi_filename)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
writer = cv2.VideoWriter(out_filepath, cv2.VideoWriter_fourcc(*'mp4v'), self.output_video_fps, (frame_width, frame_height))
filename = f'static/{avi_filename}'
with torch.no_grad():
while True:
ret, frame = cap.read()
if not ret:
break

if frame_cnt > 25:
break
frame_cnt += 1
result = self.model(frame)
result.render()
writer.write(result.imgs[0])
data = []

for i in range(len(result.xywhn[0])):
x, y, w, h, prob, cls = result.xywhn[0][i].numpy()
preds = {}
preds['x'] = str(x)
preds['y'] = str(y)
preds['w'] = str(w)
preds['h'] = str(h)
preds['prob'] = str(prob)
preds['class'] = result.names[int(cls)]
data.append(preds)
final_result[frame_cnt] = data

return {'file_name': filename, 'bbox': final_result, 'mimetype': 'video'}
except Exception as e:
logging.error(str(e))
return None
finally:
if writer is not None:
writer.release()
if cap is not None:
cap.release()
3 changes: 3 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
version: "3"
services:
rabbitmq:
container_name: rabbitmq
Expand Down Expand Up @@ -29,5 +30,7 @@ services:
dockerfile: webapp/Dockerfile
context: .
command: sh -c "cd /app && uvicorn app:app --host 0.0.0.0 --port 80 --reload"
volumes:
- ./webapp:/app
ports:
- "80:80"
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ python-multipart
pyyaml
seaborn
uvicorn
redis
redis
wget
19 changes: 16 additions & 3 deletions webapp/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ <h2> Yolo v5 - Object Detection </h2>
<div class="container">
<div class="col-sm-12">
<div class="form-row">
<div class="form-group col-md-8">
<label for="file" id="label_file">&nbsp;</label>
<div class="form-group col-md-4">
<label for="file" id="label_file">From local files</label>
<input class="form-control" type="file" id="input_file" multiple />
</div>
<div class="form-group col-md-4">
<label for="input_url" id="label_url">or url</label>
<input class="form-control" type="text" id="input_url"/>
</div>
<div class="form-group col-md-4">
<label>&nbsp;</label>
<button type="button" class="btn btn-primary btn-block" id="btn-process">Upload and Process</button>
Expand All @@ -44,7 +48,16 @@ <h2> Yolo v5 - Object Detection </h2>
</div>
<div class="form-row" id="row_detail">
<div class="col-md-3"><textarea id="result_txt" style="height: 800px; font-size: 12px;"></textarea></div>
<div class="col-md-9"><a id="result_link"><img id="result_img" width="600px" /></a></div>
<div class="col-md-9">
<div>
<a id="result_image_link">
<img id="result_img" width="600px" />
</a>
<a id="result_video_link">
</a>
</div>
</a>
</div>
</div>

</div>
Expand Down
33 changes: 28 additions & 5 deletions webapp/templates/static/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ jQuery(document).ready(function () {
$("#row_results").hide();
$('#btn-process').on('click', function () {
var form_data = new FormData();
files = $('#input_file').prop('files')
for (i = 0; i < files.length; i++)
form_data.append('files', $('#input_file').prop('files')[i]);
url = $('#input_url')[0].value
if (url.length > 0) {
form_data.append('url', url)
}
else {
files = $('#input_file').prop('files')
for (i = 0; i < files.length; i++)
form_data.append('files', $('#input_file').prop('files')[i]);
}

$.ajax({
url: URL + '/api/process',
Expand Down Expand Up @@ -75,8 +81,25 @@ jQuery(document).ready(function () {
if (data['status'] == 'SUCCESS') {
$('#row_detail').show()
$('#result_txt').val(JSON.stringify(res.result['bbox'], undefined, 4))
$('#result_img').attr('src', URL + '/' + res.result.file_name)
$('#result_link').attr('href', URL + '/' + res.result.file_name)
console.log(res.result['mimetype']);
if (res.result['mimetype'] == 'image') {
$('#result_img').attr('src', URL + '/' + res.result.file_name)
$('#result_image_link').attr('href', URL + '/' + res.result.file_name)
$('#result_image_link').show()
$('#result_video_link').hide()
} else if (res.result['mimetype'] == 'video') {
// $('#result_video').attr('src', URL + '/' + res.result.file_name)
// $('#result_video')[0].load();
// $('#result_video_link').attr('href', URL + '/' + res.result.file_name)
// $('#result_image_link').hide()

// $('#result_video').attr('src', URL + '/' + res.result.file_name)
// $('#result_video_link video')[0].load()
$('#result_video_link').attr('href', URL + '/' + res.result.file_name)
$('#result_video_link').text(URL + '/' + res.result.file_name)
$('#result_video_link').show()
$('#result_image_link').hide()
}
} else {
alert('Result not ready or already consumed!')
$('#row_detail').hide()
Expand Down

0 comments on commit 9f5c0cc

Please sign in to comment.