-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
66 lines (58 loc) · 2.51 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from label_studio_ml.model import LabelStudioMLBase
import requests, os
from ultralytics import YOLO
from PIL import Image
from io import BytesIO
import numpy as np
# Label Studio server base URL and API token, read once at import time.
# Both variables are required: a missing one raises KeyError on import.
LS_URL, LS_API_TOKEN = (
    os.environ[var_name]
    for var_name in ('LABEL_STUDIO_BASEURL', 'LABEL_STUDIO_API_TOKEN')
)
class YOLOv8Model(LabelStudioMLBase):
    """Label Studio ML backend that pre-annotates images with a YOLO detector.

    Detections from the ultralytics model loaded from ``best.pt`` are
    returned to Label Studio as ``rectanglelabels`` regions whose
    coordinates are percentages of the original image size.
    """

    # Seconds to wait for an image download before giving up.
    request_timeout = 30

    def __init__(self, **kwargs):
        # Call base class constructor
        super().__init__(**kwargs)
        # Use the first control tag of the parsed labeling config; its
        # 'to_name' names the image object the regions are attached to.
        from_name, schema = next(iter(self.parsed_label_config.items()))
        self.from_name = from_name
        self.to_name = schema['to_name'][0]
        # Maps the model's integer class ids to label names.
        # NOTE(review): assumes best.pt predicts only class 0 == 'car';
        # any other class id raises IndexError -- confirm against model.names.
        self.labels = ['car']
        self.model = YOLO("best.pt")

    def _load_image(self, task):
        """Download and open the image referenced by ``task['data']['image']``.

        Relative paths are resolved by prefixing ``LS_URL``; absolute
        http(s) URLs are fetched unchanged. The Label Studio API token is
        only attached to URLs under ``LS_URL`` so it is never sent to a
        third-party host.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
            requests.Timeout: if the download exceeds ``request_timeout``.
        """
        path = task['data']['image']
        if path.startswith(('http://', 'https://')):
            url = path
        else:
            url = LS_URL + path
        headers = {}
        if url.startswith(LS_URL):
            headers["Authorization"] = "Token " + LS_API_TOKEN
        response = requests.get(url, headers=headers, timeout=self.request_timeout)
        # Fail loudly on an error status instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    def _predict_task(self, task):
        """Run the detector on one task and build its Label Studio prediction dict."""
        image = self._load_image(task)
        original_width, original_height = image.size
        regions = []
        total_score = 0.0
        for result in self.model.predict(image):
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                confidence = box.conf.item()
                regions.append({
                    # Running index over all results keeps region ids unique.
                    "id": str(len(regions)),
                    "from_name": self.from_name,
                    "to_name": self.to_name,
                    "type": "rectanglelabels",
                    "score": confidence,
                    "original_width": original_width,
                    "original_height": original_height,
                    "image_rotation": 0,
                    "value": {
                        "rotation": 0,
                        # Label Studio expects percentages of the image size.
                        "x": x1 / original_width * 100,
                        "y": y1 / original_height * 100,
                        "width": (x2 - x1) / original_width * 100,
                        "height": (y2 - y1) / original_height * 100,
                        "rectanglelabels": [self.labels[int(box.cls.item())]],
                    },
                })
                total_score += confidence
        return {
            "result": regions,
            # Mean detection confidence; 0.0 when nothing was detected
            # (the original divided by an unbound loop index in that case).
            "score": total_score / len(regions) if regions else 0.0,
            "model_version": "v8n",  # all predictions will be differentiated by model version
        }

    def predict(self, tasks, **kwargs):
        """ This is where inference happens: model returns
        the list of predictions based on input list of tasks.

        Returns one prediction dict per task, in the same order as ``tasks``.
        """
        return [self._predict_task(task) for task in tasks]