diff --git a/engine.py b/engine.py
index 4b6c25846..ac5ea6ff4 100644
--- a/engine.py
+++ b/engine.py
@@ -112,8 +112,15 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
         res = {target['image_id'].item(): output for target, output in zip(targets, results)}
         if coco_evaluator is not None:
             coco_evaluator.update(res)
+
         if panoptic_evaluator is not None:
-            res_pano = postprocessors["panoptic"](outputs, targets)
+            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
+            for i, target in enumerate(targets):
+                image_id = target["image_id"].item()
+                file_name = f"{image_id:012d}.png"
+                res_pano[i]["image_id"] = image_id
+                res_pano[i]["file_name"] = file_name
+
             panoptic_evaluator.update(res_pano)
 
     # gather the stats from all processes
diff --git a/hubconf.py b/hubconf.py
index ecda6cef9..d841d47c3 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -1,40 +1,46 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import torch
-from models.detr import DETR
+
 from models.backbone import Backbone, Joiner
-from models.transformer import Transformer
+from models.detr import DETR, PostProcess
 from models.position_encoding import PositionEmbeddingSine
-dependencies = ['torch', 'torchvision']
+from models.segmentation import DETRsegm, PostProcessPanoptic
+from models.transformer import Transformer
 
+dependencies = ["torch", "torchvision"]
 
-def _make_detr(backbone_name: str, dilation=False, num_classes=91):
+
+def _make_detr(backbone_name: str, dilation=False, num_classes=91, mask=False):
     hidden_dim = 256
-    backbone = Backbone(backbone_name, train_backbone=True,
-                        return_interm_layers=False, dilation=dilation)
+    backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=False, dilation=dilation)
     pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
     backbone_with_pos_enc = Joiner(backbone, pos_enc)
     backbone_with_pos_enc.num_channels = backbone.num_channels
     transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True)
-    return DETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100)
+    detr = DETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100)
+    if mask:
+        return DETRsegm(detr)
+    return detr
 
 
-def detr_resnet50(pretrained=False, num_classes=91):
+def detr_resnet50(pretrained=False, num_classes=91, return_postprocessor=False):
     """
     DETR R50 with 6 encoder and 6 decoder layers.
 
     Achieves 42/62.4 AP/AP50 on COCO val5k.
     """
-    model = _make_detr('resnet50', dilation=False, num_classes=num_classes)
+    model = _make_detr("resnet50", dilation=False, num_classes=num_classes)
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
-            map_location='cpu',
-            check_hash=True)
-        model.load_state_dict(checkpoint['model'])
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth", map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcess()
     return model
 
 
-def detr_resnet50_dc5(pretrained=False, num_classes=91):
+def detr_resnet50_dc5(pretrained=False, num_classes=91, return_postprocessor=False):
     """
     DETR-DC5 R50 with 6 encoder and 6 decoder layers.
 
@@ -42,33 +48,35 @@ def detr_resnet50_dc5(pretrained=False, num_classes=91):
     output resolution.
     Achieves 43.3/63.1 AP/AP50 on COCO val5k.
""" - model = _make_detr('resnet50', dilation=True, num_classes=num_classes) + model = _make_detr("resnet50", dilation=True, num_classes=num_classes) if pretrained: checkpoint = torch.hub.load_state_dict_from_url( - url='https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth', - map_location='cpu', - check_hash=True) - model.load_state_dict(checkpoint['model']) + url="https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth", map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + if return_postprocessor: + return model, PostProcess() return model -def detr_resnet101(pretrained=False, num_classes=91): +def detr_resnet101(pretrained=False, num_classes=91, return_postprocessor=False): """ DETR-DC5 R101 with 6 encoder and 6 decoder layers. Achieves 43.5/63.8 AP/AP50 on COCO val5k. """ - model = _make_detr('resnet101', dilation=False, num_classes=num_classes) + model = _make_detr("resnet101", dilation=False, num_classes=num_classes) if pretrained: checkpoint = torch.hub.load_state_dict_from_url( - url='https://dl.fbaipublicfiles.com/detr/detr-r101-2c7b67e5.pth', - map_location='cpu', - check_hash=True) - model.load_state_dict(checkpoint['model']) + url="https://dl.fbaipublicfiles.com/detr/detr-r101-2c7b67e5.pth", map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + if return_postprocessor: + return model, PostProcess() return model -def detr_resnet101_dc5(pretrained=False, num_classes=91): +def detr_resnet101_dc5(pretrained=False, num_classes=91, return_postprocessor=False): """ DETR-DC5 R101 with 6 encoder and 6 decoder layers. @@ -76,11 +84,85 @@ def detr_resnet101_dc5(pretrained=False, num_classes=91): output resolution. Achieves 44.9/64.7 AP/AP50 on COCO val5k. """ - model = _make_detr('resnet101', dilation=True, num_classes=num_classes) + model = _make_detr("resnet101", dilation=True, num_classes=num_classes) + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/detr/detr-r101-dc5-a2e86def.pth", map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + if return_postprocessor: + return model, PostProcess() + return model + + +def detr_resnet50_panoptic( + pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False +): + """ + DETR R50 with 6 encoder and 6 decoder layers. + Achieves 43.4 PQ on COCO val5k. + + threshold is the minimum confidence required for keeping segments in the prediction + """ + model = _make_detr("resnet50", dilation=False, num_classes=num_classes, mask=True) + is_thing_map = {i: i <= 90 for i in range(250)} + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/detr/detr-r50-panoptic-00ce5173.pth", + map_location="cpu", + check_hash=True, + ) + model.load_state_dict(checkpoint["model"]) + if return_postprocessor: + return model, PostProcessPanoptic(is_thing_map, threshold=threshold) + return model + + +def detr_resnet50_dc5_panoptic( + pretrained=False, num_classes=91, threshold=0.85, return_postprocessor=False +): + """ + DETR-DC5 R50 with 6 encoder and 6 decoder layers. + + The last block of ResNet-50 has dilation to increase + output resolution. + Achieves 44.6 on COCO val5k. 
+
+    threshold is the minimum confidence required for keeping segments in the prediction
+    """
+    model = _make_detr("resnet50", dilation=True, num_classes=num_classes, mask=True)
+    is_thing_map = {i: i <= 90 for i in range(250)}
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-panoptic-da08f1b1.pth",
+            map_location="cpu",
+            check_hash=True,
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcessPanoptic(is_thing_map, threshold=threshold)
+    return model
+
+
+def detr_resnet101_panoptic(
+    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
+):
+    """
+    DETR R101 with 6 encoder and 6 decoder layers.
+
+    Achieves 45.1 PQ on COCO val5k.
+
+    threshold is the minimum confidence required for keeping segments in the prediction
+    """
+    model = _make_detr("resnet101", dilation=False, num_classes=num_classes, mask=True)
+    is_thing_map = {i: i <= 90 for i in range(250)}
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url='https://dl.fbaipublicfiles.com/detr/detr-r101-dc5-a2e86def.pth',
-            map_location='cpu',
-            check_hash=True)
-        model.load_state_dict(checkpoint['model'])
+            url="https://dl.fbaipublicfiles.com/detr/detr-r101-panoptic-40021d53.pth",
+            map_location="cpu",
+            check_hash=True,
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcessPanoptic(is_thing_map, threshold=threshold)
     return model
diff --git a/models/detr.py b/models/detr.py
index 2471da646..1a43faed9 100644
--- a/models/detr.py
+++ b/models/detr.py
@@ -310,7 +310,7 @@ def build(args):
         aux_loss=args.aux_loss,
     )
     if args.masks:
-        model = DETRsegm(model)
+        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
     matcher = build_matcher(args)
     weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
     weight_dict['loss_giou'] = args.giou_loss_coef
@@ -335,6 +335,6 @@ def build(args):
         postprocessors['segm'] = PostProcessSegm()
         if args.dataset_file == "coco_panoptic":
             is_thing_map = {i: i <= 90 for i in range(201)}
-            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, True, threshold=0.85)
+            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)
 
     return model, criterion, postprocessors
diff --git a/models/segmentation.py b/models/segmentation.py
index 589a58ea4..2c03c72d9 100644
--- a/models/segmentation.py
+++ b/models/segmentation.py
@@ -223,17 +223,15 @@ def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
         assert len(orig_target_sizes) == len(max_target_sizes)
         max_h, max_w = max_target_sizes.max(0)[0].tolist()
         outputs_masks = outputs["pred_masks"].squeeze(2)
-        outputs_masks = F.interpolate(
-            outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
-        )
+        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
 
         for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
             img_h, img_w = t[0], t[1]
             results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
-            results[i]["masks"] = F.interpolate(results[i]["masks"].float(),
-                                                size=tuple(tt.tolist()),
-                                                mode="nearest").byte()
+            results[i]["masks"] = F.interpolate(
+                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+            ).byte()
 
         return results
 
@@ -242,25 +240,41 @@ class PostProcessPanoptic(nn.Module):
     """This class converts the output of the model to the final panoptic result, in the format expected by the
     coco panoptic API """
 
-    def __init__(self, is_thing_map, rescale_to_orig_size=False, threshold=0.85):
+    def __init__(self, is_thing_map, threshold=0.85):
         """
         Parameters:
            is_thing_map: This is a dict whose keys are the class ids, and whose values are booleans indicating whether
                          the class is a thing (True) or a stuff (False) class
-           rescale_to_orig_size: If true, we use rescale the prediction to the size of the original image.
-                                 Otherwise, we keep the size after data augmentation
           threshold: confidence threshold: segments with confidence lower than this will be deleted
         """
         super().__init__()
-        self.rescale_to_orig_size = rescale_to_orig_size
         self.threshold = threshold
         self.is_thing_map = is_thing_map
 
-    def forward(self, outputs, targets):
+    def forward(self, outputs, processed_sizes, target_sizes=None):
+        """ This function computes the panoptic prediction from the model's predictions.
+        Parameters:
+            outputs: This is a dict coming directly from the model. See the model doc for the content.
+            processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
+                             model, i.e. the size after data augmentation but before batching.
+            target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
+                          of each prediction. If left to None, it will default to the processed_sizes
+        """
+        if target_sizes is None:
+            target_sizes = processed_sizes
+        assert len(processed_sizes) == len(target_sizes)
         out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
-        assert len(out_logits) == len(raw_masks) == len(targets)
+        assert len(out_logits) == len(raw_masks) == len(target_sizes)
         preds = []
-        for cur_logits, cur_masks, cur_boxes, target in zip(out_logits, raw_masks, raw_boxes, targets):
+
+        def to_tuple(tup):
+            if isinstance(tup, tuple):
+                return tup
+            return tuple(tup.cpu().tolist())
+
+        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+        ):
             # we filter empty queries and detections below the threshold
             scores, labels = cur_logits.softmax(-1).max(-1)
             keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
@@ -268,7 +282,7 @@ def forward(self, outputs, targets):
             cur_scores = cur_scores[keep]
             cur_classes = cur_classes[keep]
             cur_masks = cur_masks[keep]
-            cur_masks = interpolate(cur_masks[None], tuple(target["size"].tolist()), mode="bilinear").squeeze(0)
+            cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
             cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])
 
             h, w = cur_masks.shape[-2:]
@@ -301,8 +315,7 @@ def get_ids_area(masks, scores, dedup=False):
                             for eq_id in equiv:
                                 m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
 
-                field = "orig_size" if self.rescale_to_orig_size else "size"
-                final_h, final_w = target[field].cpu().unbind(0)
+                final_h, final_w = to_tuple(target_size)
 
                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
@@ -335,22 +348,14 @@ def get_ids_area(masks, scores, dedup=False):
             else:
                 cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
 
-            image_id = target["image_id"].item()
-
             segments_info = []
             for i, a in enumerate(area):
                 cat = cur_classes[i].item()
                 segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
             del cur_classes
 
f"{image_id:012d}.png" with io.BytesIO() as out: seg_img.save(out, format="PNG") - predictions = { - "image_id": image_id, - "file_name": file_name, - "png_string": out.getvalue(), - "segments_info": segments_info, - } + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} preds.append(predictions) return preds