diff --git a/lib/models/seg_hrnet_ocr.py b/lib/models/seg_hrnet_ocr.py index ed9df629..7869729b 100644 --- a/lib/models/seg_hrnet_ocr.py +++ b/lib/models/seg_hrnet_ocr.py @@ -56,8 +56,8 @@ def __init__(self, cls_num=0, scale=1): self.scale = scale def forward(self, feats, probs): - batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) - probs = probs.view(batch_size, c, -1) + batch_size, k, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, k, -1) feats = feats.view(batch_size, feats.size(1), -1) feats = feats.permute(0, 2, 1) # batch x hw x c probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw