Skip to content

Commit 12f3669

Browse files
committedFeb 11, 2024
add Choices for model, import all valid MODELS for Choices, refactor for new (temp) API endpoint, and line formatting
1 parent 8564b0a commit 12f3669

File tree

1 file changed

+61
-24
lines changed

1 file changed

+61
-24
lines changed
 

‎src/UltralyticsBot/cmds/actions.py

+61-24
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
import numpy as np
1515
from discord import app_commands
1616

17-
from UltralyticsBot import REQ_LIM, REQ_ENDPOINT, CMDS, RESPONSE_KEYS, HUB_KEY, DEFAULT_INFER, BOT_ID, OWNER_ID, GH, MAX_REQ
18-
from UltralyticsBot.utils.checks import model_chk
17+
from UltralyticsBot import REQ_LIM, REQ_ENDPOINT, CMDS, RESPONSE_KEYS, HUB_KEY, DEFAULT_INFER, BOT_ID, OWNER_ID, GH, MODELS
1918
from UltralyticsBot.utils.logging import Loggr
19+
from UltralyticsBot.cmds.client import MyClient
20+
from UltralyticsBot.utils.checks import model_chk
2021
from UltralyticsBot.utils.general import ReqImage, attach_file, files_age
2122
from UltralyticsBot.utils.plotting import nxy2xy, xcycwh2xyxy, rel_line_size, draw_all_boxes
2223
from UltralyticsBot.utils.msgs import IMG_ERR_MSG, API_ERR_MSG, NOT_OWNER, gen_line, ReqMessage, ResponseMsg, NEWLINE
23-
from UltralyticsBot.cmds.client import MyClient
2424

2525
TEMPFILE = 'detect_res.png' # fallback
2626
LIMITS = {k:app_commands.Range[type(v['min']), v['min'], v['max']] for k,v in REQ_LIM.items()}
@@ -39,9 +39,12 @@ def inference_req(imgbytes:bytes, req2:str=REQ_ENDPOINT, **kwargs) -> requests.R
3939
if any(kwargs):
4040
for k in kwargs:
4141
_ = req_dict.update({k:kwargs[k]}) if k in req_dict else None
42-
req_dict['key'] = HUB_KEY
43-
req_dict['image'] = base64.b64encode(imgbytes).decode()
44-
return requests.post(req2, json=req_dict)
42+
# req_dict['key'] = HUB_KEY
43+
# req_dict['image'] = base64.b64encode(imgbytes).decode()
44+
# req_dict['image'] = base64.b64encode(imgbytes).decode()
45+
_ = [req_dict.pop(i) for i in ["image", "key"]]
46+
# return requests.post(req2, json=req_dict)
47+
return requests.post(req2, headers={}, data=req_dict, files={"image":imgbytes})
4548
# return req_dict # NOTE might need to change in future
4649

4750
def process_result(img:np.ndarray, predictions:list, plot:bool, class_pad:int) -> tuple[np.ndarray, str]:
@@ -52,10 +55,26 @@ def process_result(img:np.ndarray, predictions:list, plot:bool, class_pad:int) -
5255

5356
msg = ''
5457
pred_boxes = np.zeros((1,5)) # x-center, y-center, width, height, class
58+
# TODO add post processing based on model used Segment, Key-point, Pose, OBB
5559
for p in predictions:
5660
cls_name, conf, idx, *(x, y, w, h) = get_values(p)
57-
x1, y1, x2, y2 = tuple(nxy2xy(xcycwh2xyxy(np.array((x, y, w, h))), imH, imW)) # n-xcycwh -> x1y1x2y2
58-
pred_boxes = np.vstack([pred_boxes, np.array((x1, y1, x2, y2, idx))])
61+
x1, y1, x2, y2 = tuple(
62+
nxy2xy(
63+
xcycwh2xyxy(
64+
np.array((x, y, w, h))
65+
),
66+
imH,
67+
imW
68+
)
69+
) # n-xcycwh -> x1y1x2y2
70+
pred_boxes = np.vstack(
71+
[
72+
pred_boxes,
73+
np.array(
74+
(x1, y1, x2, y2, idx)
75+
)
76+
]
77+
)
5978
msg += gen_line(cls_name, class_pad, conf, x1, y1, x2, y2)
6079

6180
if plot:
@@ -114,41 +133,59 @@ async def msg_predict(message:discord.Message):
114133
await message.reply(text, file=file)
115134

116135
###-----Slash Commands-----###
136+
@app_commands.choices(
137+
model=[app_commands.Choice(name=m, value=str(m).lower()) for m in MODELS]
138+
)
117139
@app_commands.describe(
118-
img_url='Valid HTTP/S link to a supported image type.',
119-
show="Enable/disable showing annotated result image.",
120-
conf="Confidence threshold for class predictions.",
121-
iou="Intersection over union threshold for detections.",
122-
size="Inference image size (single dimension only).",
123-
model="One of YOLOv(5|8)(n|s|m|l|x) models to use for inference."
140+
img_url='Valid HTTP/S link to a supported image type.',
141+
show="Enable/disable showing annotated result image.",
142+
conf="Confidence threshold for class predictions.",
143+
iou="Intersection over union threshold for detections.",
144+
size="Inference image size (single dimension only).",
145+
model="One of YOLOv(5|8)(n|s|m|l|x) models to use for inference."
124146
)
125147
async def im_predict(interaction:discord.Interaction,
126-
img_url:str,
127-
show:bool=True,
128-
conf:LIMITS['conf']=0.35,
129-
iou:LIMITS['iou']=0.45,
130-
size:LIMITS['size']=640,
131-
model:str='yolov8n',
132-
):
148+
img_url:str,
149+
show:bool=True,
150+
conf:LIMITS['conf']=0.35, # type: ignore
151+
iou:LIMITS['iou']=0.45, # type: ignore
152+
size:LIMITS['size']=640, # type: ignore
153+
model:app_commands.Choice[str]='yolov8n',
154+
):
133155
await interaction.response.defer(thinking=True) # permits longer response time
134156

135-
model = model_chk(model)
157+
model = model_chk(model.value)
136158
image = ReqImage(img_url)
137159
# infer_im, infer_data, infer_ratio = image.inference_img(int(size))
138160
if not image.image_error:
139161

140162
try:
141163
# if not image.image_error:
142164
infer_im, infer_data, infer_ratio = image.inference_img(int(size))
143-
req = inference_req(infer_data, req2=REQ_ENDPOINT, confidence=str(conf), iou=str(iou), size=str(size), model=str(model))
165+
req = inference_req(
166+
infer_data,
167+
req2=REQ_ENDPOINT.replace("yolov8n", model.lower()),
168+
confidence=str(conf),
169+
iou=str(iou),
170+
size=str(size),
171+
model=str(model)
172+
)
144173
req.raise_for_status()
145174
if req.status_code != 200: # Catch all other non-good return codes and make sure to reply
146175
Loggr.debug(f"{API_ERR_MSG.format(req.status_code, req.reason)}")
147176
await interaction.followup.send(API_ERR_MSG.format(req.status_code, req.reason))
148177

149178
else:
150179
Reply = ResponseMsg(req, show, True, infer_ratio)
151-
file, text = Reply.start_msg(partial(process_result, img=infer_im, plot=show, class_pad=Reply.cls_pad), infer_ratio=infer_ratio)
180+
file, text = Reply.start_msg(
181+
partial(
182+
process_result, # NOTE will need to update for all model tasks
183+
img=infer_im,
184+
plot=show,
185+
class_pad=Reply.cls_pad
186+
),
187+
infer_ratio=infer_ratio
188+
)
152189
file = attach_file(file) if show else None
153190

154191
except requests.HTTPError as e:

0 commit comments

Comments
 (0)
Please sign in to comment.