From 166097892797b7850a494d1c7b9f736f86ef5d99 Mon Sep 17 00:00:00 2001 From: Hongyuan Zhang <66273343+Alias-z@users.noreply.github.com> Date: Fri, 7 Jun 2024 22:33:40 +0200 Subject: [PATCH] Update has_mask method for mmdet models (#1054) --- sahi/models/mmdet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index c64ea87c..b4b363d3 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -190,7 +190,9 @@ def has_mask(self): """ Returns if model output contains segmentation mask """ - has_mask = self.model.model.with_mask + # has_mask = self.model.model.with_mask + train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"] + has_mask = any(isinstance(item, dict) and any("mask" in key for key in item.keys()) for item in train_pipeline) return has_mask @property