-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
48 lines (41 loc) · 2.63 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from models.DN_DAB_DETR.DABDETR import build_DABDETR
from util.slconfig import SLConfig
from util.visualizer import COCOVisualizer
from util import box_ops
# --- Model setup -------------------------------------------------------------
# Edit these two paths to point at your own config and checkpoint files.
model_config_path = "config.json"  # change the path of the model config
model_checkpoint_path = "checkpoint_optimized_44.7ap.pth"  # change the path of the model checkpoint

# Build the model, together with its criterion and post-processors, from the config.
args = SLConfig.fromfile(model_config_path)
model, criterion, postprocessors = build_DABDETR(args)

# Restore the trained weights (tensors mapped onto the CPU), then switch to eval mode.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

# Category id -> display name. The ids are not contiguous (e.g. 12, 26, 29, 30 are absent).
id2name = {
    1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane',
    6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light',
    11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
    16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow',
    22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack',
    28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee',
    35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite',
    39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard',
    42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass',
    47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl',
    52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli',
    57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake',
    62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
    67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
    75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave',
    79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book',
    85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear',
    89: 'hair drier', 90: 'toothbrush',
}

# Visualizer used below to draw the predicted boxes.
vslzr = COCOVisualizer()
from PIL import Image
import datasets.transforms as T
# --- Load and preprocess the input image ------------------------------------
# The context manager closes the underlying file once convert() has copied the pixels.
with Image.open("bus.jpg") as raw_image:
    image = raw_image.convert("RGB")

# Transform the image: resize, convert to a tensor, and normalize with the
# ImageNet mean/std. NOTE(review): RandomResize([800], max_size=1333) presumably
# resizes the shorter side to 800 px with the longer side capped at 1333 —
# confirm in datasets/transforms.py.
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image, _ = transform(image, None)

# --- Predict -----------------------------------------------------------------
# Inference only: disable autograd so no graph or activations are retained.
with torch.no_grad():
    # NOTE(review): the second positional argument (0) is presumably DN-DETR's
    # denoising argument, disabling denoising queries at inference — confirm
    # against the model's forward().
    output, _ = model(image[None], 0)
    # A target size of [[1.0, 1.0]] presumably leaves the boxes in normalized
    # [0, 1] coordinates; the visualizer is given the real image size below.
    output = postprocessors['bbox'](output, torch.Tensor([[1.0, 1.0]]))[0]

# --- Visualize outputs ---------------------------------------------------------
threshold = 0.3  # keep only detections scoring above this value
scores = output['scores']
labels = output['labels']
# The visualizer is fed boxes converted from (x0, y0, x1, y1) to (cx, cy, w, h).
boxes = box_ops.box_xyxy_to_cxcywh(output['boxes'])
select_mask = scores > threshold
# id2name has gaps in its ids, so fall back to the numeric id instead of
# raising KeyError for a label that is not in the map.
box_label = [
    id2name.get(label_id, str(label_id))
    for label_id in labels[select_mask].tolist()
]
pred_dict = {
    'boxes': boxes[select_mask],
    # presumably (height, width) of the transformed (C, H, W) tensor — confirm
    'size': torch.Tensor([image.shape[1], image.shape[2]]),
    'box_label': box_label,
}
vslzr.visualize(image, pred_dict, savedir=None)