From fbf9b4f8eb6222d2f01d9fa5e057885a492ac797 Mon Sep 17 00:00:00 2001 From: BigBigDJ <40659980+KelvinCPChiu@users.noreply.github.com> Date: Thu, 27 May 2021 15:09:47 +0800 Subject: [PATCH 1/2] Update rcnn_discriminator.py Replace the compilation of ROIAlign and ROIPooling with Torchvision function. --- model/rcnn_discriminator.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/model/rcnn_discriminator.py b/model/rcnn_discriminator.py index 5a2738a..34b7fa0 100644 --- a/model/rcnn_discriminator.py +++ b/model/rcnn_discriminator.py @@ -1,11 +1,36 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .roi_layers import ROIAlign, ROIPool +#from .roi_layers import ROIAlign, ROIPool +from torchvision.ops import roi_align +from torchvision.ops import roi_pool from utils.util import * from utils.bilinear import * +class ROIAlign(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio): + super(ROIAlign, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + def forward(self, input, rois): + return roi_align( + input, rois, self.output_size, self.spatial_scale, self.sampling_ratio + ) + + +class ROIPool(nn.Module): + def __init__(self, output_size, spatial_scale): + super(ROIPool, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def conv2d(in_feat, out_feat, kernel_size=3, stride=1, pad=1, spectral_norm=True): conv = nn.Conv2d(in_feat, out_feat, kernel_size, stride, pad) if spectral_norm: From 0bb32d2b7e903410acb9ba099350fd613064012e Mon Sep 17 00:00:00 2001 From: BigBigDJ <40659980+KelvinCPChiu@users.noreply.github.com> Date: Thu, 27 May 2021 15:14:21 +0800 Subject: [PATCH 2/2] Update rcnn_discriminator.py --- model/rcnn_discriminator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/rcnn_discriminator.py b/model/rcnn_discriminator.py index 34b7fa0..41a4549 100644 --- a/model/rcnn_discriminator.py +++ b/model/rcnn_discriminator.py @@ -83,8 +83,8 @@ def forward(self, x, y=None, bbox=None): # obj path # seperate different path s_idx = ((bbox[:, 3] - bbox[:, 1]) < 64) * ((bbox[:, 4] - bbox[:, 2]) < 64) - bbox_l, bbox_s = bbox[1-s_idx], bbox[s_idx] - y_l, y_s = y[1-s_idx], y[s_idx] + bbox_l, bbox_s = bbox[~s_idx], bbox[s_idx] + y_l, y_s = y[~s_idx], y[s_idx] obj_feat_s = self.block_obj3(x1) obj_feat_s = self.block_obj4(obj_feat_s)