14
14
import numpy as np
15
15
from discord import app_commands
16
16
17
- from UltralyticsBot import REQ_LIM , REQ_ENDPOINT , CMDS , RESPONSE_KEYS , HUB_KEY , DEFAULT_INFER , BOT_ID , OWNER_ID , GH , MAX_REQ
18
- from UltralyticsBot .utils .checks import model_chk
17
+ from UltralyticsBot import REQ_LIM , REQ_ENDPOINT , CMDS , RESPONSE_KEYS , HUB_KEY , DEFAULT_INFER , BOT_ID , OWNER_ID , GH , MODELS
19
18
from UltralyticsBot .utils .logging import Loggr
19
+ from UltralyticsBot .cmds .client import MyClient
20
+ from UltralyticsBot .utils .checks import model_chk
20
21
from UltralyticsBot .utils .general import ReqImage , attach_file , files_age
21
22
from UltralyticsBot .utils .plotting import nxy2xy , xcycwh2xyxy , rel_line_size , draw_all_boxes
22
23
from UltralyticsBot .utils .msgs import IMG_ERR_MSG , API_ERR_MSG , NOT_OWNER , gen_line , ReqMessage , ResponseMsg , NEWLINE
23
- from UltralyticsBot .cmds .client import MyClient
24
24
25
25
TEMPFILE = 'detect_res.png' # fallback filename — presumably used when saving/attaching the annotated result image; confirm against attach_file usage
26
26
# Build an app_commands.Range annotation per request-limit entry, typed from the
# entry's 'min' value and bounded by its 'min'/'max'. Used as slash-command
# parameter annotations (e.g. LIMITS['conf'], LIMITS['iou'], LIMITS['size']).
LIMITS = {
    name: app_commands.Range[type(spec['min']), spec['min'], spec['max']]
    for name, spec in REQ_LIM.items()
}
@@ -39,9 +39,12 @@ def inference_req(imgbytes:bytes, req2:str=REQ_ENDPOINT, **kwargs) -> requests.R
39
39
if any (kwargs ):
40
40
for k in kwargs :
41
41
_ = req_dict .update ({k :kwargs [k ]}) if k in req_dict else None
42
- req_dict ['key' ] = HUB_KEY
43
- req_dict ['image' ] = base64 .b64encode (imgbytes ).decode ()
44
- return requests .post (req2 , json = req_dict )
42
+ # req_dict['key'] = HUB_KEY
43
+ # req_dict['image'] = base64.b64encode(imgbytes).decode()
44
+ # req_dict['image'] = base64.b64encode(imgbytes).decode()
45
+ _ = [req_dict .pop (i ) for i in ["image" , "key" ]]
46
+ # return requests.post(req2, json=req_dict)
47
+ return requests .post (req2 , headers = {}, data = req_dict , files = {"image" :imgbytes })
45
48
# return req_dict # NOTE might need to change in future
46
49
47
50
def process_result (img :np .ndarray , predictions :list , plot :bool , class_pad :int ) -> tuple [np .ndarray , str ]:
@@ -52,10 +55,26 @@ def process_result(img:np.ndarray, predictions:list, plot:bool, class_pad:int) -
52
55
53
56
msg = ''
54
57
pred_boxes = np .zeros ((1 ,5 )) # x-center, y-center, width, height, class
58
+ # TODO add post processing based on model used Segment, Key-point, Pose, OBB
55
59
for p in predictions :
56
60
cls_name , conf , idx , * (x , y , w , h ) = get_values (p )
57
- x1 , y1 , x2 , y2 = tuple (nxy2xy (xcycwh2xyxy (np .array ((x , y , w , h ))), imH , imW )) # n-xcycwh -> x1y1x2y2
58
- pred_boxes = np .vstack ([pred_boxes , np .array ((x1 , y1 , x2 , y2 , idx ))])
61
+ x1 , y1 , x2 , y2 = tuple (
62
+ nxy2xy (
63
+ xcycwh2xyxy (
64
+ np .array ((x , y , w , h ))
65
+ ),
66
+ imH ,
67
+ imW
68
+ )
69
+ ) # n-xcycwh -> x1y1x2y2
70
+ pred_boxes = np .vstack (
71
+ [
72
+ pred_boxes ,
73
+ np .array (
74
+ (x1 , y1 , x2 , y2 , idx )
75
+ )
76
+ ]
77
+ )
59
78
msg += gen_line (cls_name , class_pad , conf , x1 , y1 , x2 , y2 )
60
79
61
80
if plot :
@@ -114,41 +133,59 @@ async def msg_predict(message:discord.Message):
114
133
await message .reply (text , file = file )
115
134
116
135
###-----Slash Commands-----###
136
+ @app_commands .choices (
137
+ model = [app_commands .Choice (name = m , value = str (m ).lower ()) for m in MODELS ]
138
+ )
117
139
@app_commands .describe (
118
- img_url = 'Valid HTTP/S link to a supported image type.' ,
119
- show = "Enable/disable showing annotated result image." ,
120
- conf = "Confidence threshold for class predictions." ,
121
- iou = "Intersection over union threshold for detections." ,
122
- size = "Inference image size (single dimension only)." ,
123
- model = "One of YOLOv(5|8)(n|s|m|l|x) models to use for inference."
140
+ img_url = 'Valid HTTP/S link to a supported image type.' ,
141
+ show = "Enable/disable showing annotated result image." ,
142
+ conf = "Confidence threshold for class predictions." ,
143
+ iou = "Intersection over union threshold for detections." ,
144
+ size = "Inference image size (single dimension only)." ,
145
+ model = "One of YOLOv(5|8)(n|s|m|l|x) models to use for inference."
124
146
)
125
147
async def im_predict (interaction :discord .Interaction ,
126
- img_url :str ,
127
- show :bool = True ,
128
- conf :LIMITS ['conf' ]= 0.35 ,
129
- iou :LIMITS ['iou' ]= 0.45 ,
130
- size :LIMITS ['size' ]= 640 ,
131
- model :str = 'yolov8n' ,
132
- ):
148
+ img_url :str ,
149
+ show :bool = True ,
150
+ conf :LIMITS ['conf' ]= 0.35 , # type: ignore
151
+ iou :LIMITS ['iou' ]= 0.45 , # type: ignore
152
+ size :LIMITS ['size' ]= 640 , # type: ignore
153
+ model :app_commands . Choice [ str ] = 'yolov8n' ,
154
+ ):
133
155
await interaction .response .defer (thinking = True ) # permits longer response time
134
156
135
- model = model_chk (model )
157
+ model = model_chk (model . value )
136
158
image = ReqImage (img_url )
137
159
# infer_im, infer_data, infer_ratio = image.inference_img(int(size))
138
160
if not image .image_error :
139
161
140
162
try :
141
163
# if not image.image_error:
142
164
infer_im , infer_data , infer_ratio = image .inference_img (int (size ))
143
- req = inference_req (infer_data , req2 = REQ_ENDPOINT , confidence = str (conf ), iou = str (iou ), size = str (size ), model = str (model ))
165
+ req = inference_req (
166
+ infer_data ,
167
+ req2 = REQ_ENDPOINT .replace ("yolov8n" , model .lower ()),
168
+ confidence = str (conf ),
169
+ iou = str (iou ),
170
+ size = str (size ),
171
+ model = str (model )
172
+ )
144
173
req .raise_for_status ()
145
174
if req .status_code != 200 : # Catch all other non-good return codes and make sure to reply
146
175
Loggr .debug (f"{ API_ERR_MSG .format (req .status_code , req .reason )} " )
147
176
await interaction .followup .send (API_ERR_MSG .format (req .status_code , req .reason ))
148
177
149
178
else :
150
179
Reply = ResponseMsg (req , show , True , infer_ratio )
151
- file , text = Reply .start_msg (partial (process_result , img = infer_im , plot = show , class_pad = Reply .cls_pad ), infer_ratio = infer_ratio )
180
+ file , text = Reply .start_msg (
181
+ partial (
182
+ process_result , # NOTE will need to update for all model tasks
183
+ img = infer_im ,
184
+ plot = show ,
185
+ class_pad = Reply .cls_pad
186
+ ),
187
+ infer_ratio = infer_ratio
188
+ )
152
189
file = attach_file (file ) if show else None
153
190
154
191
except requests .HTTPError as e :
0 commit comments