Patch for model/rcnn_discriminator.py. Text before the "diff --git" line is
commentary only; git apply and patch skip it.

Hunk 1: comments out the import of ROIAlign/ROIPool from the package-local
.roi_layers module and defines nn.Module wrappers with the same two names
that forward to torchvision.ops.roi_align and torchvision.ops.roi_pool.

Hunk 2: in forward(), selects the complementary rows with ~s_idx instead of
1-s_idx. s_idx is the product of two comparisons, so it is presumably a bool
mask, for which "1 - mask" is rejected by recent PyTorch releases -- confirm
against the PyTorch version in use.

NOTE(review): line breaks, indentation and blank-line positions were
reconstructed from a whitespace-collapsed copy using the hunk line counts;
run "git apply --check" against the target tree before relying on it.

diff --git a/model/rcnn_discriminator.py b/model/rcnn_discriminator.py
index 5a2738a..41a4549 100644
--- a/model/rcnn_discriminator.py
+++ b/model/rcnn_discriminator.py
@@ -1,11 +1,36 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .roi_layers import ROIAlign, ROIPool
+#from .roi_layers import ROIAlign, ROIPool
+from torchvision.ops import roi_align
+from torchvision.ops import roi_pool
 from utils.util import *
 from utils.bilinear import *
 
 
+class ROIAlign(nn.Module):
+    def __init__(self, output_size, spatial_scale, sampling_ratio):
+        super(ROIAlign, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+        self.sampling_ratio = sampling_ratio
+
+    def forward(self, input, rois):
+        return roi_align(
+            input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
+        )
+
+
+class ROIPool(nn.Module):
+    def __init__(self, output_size, spatial_scale):
+        super(ROIPool, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+
+    def forward(self, input, rois):
+        return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+
 def conv2d(in_feat, out_feat, kernel_size=3, stride=1, pad=1, spectral_norm=True):
     conv = nn.Conv2d(in_feat, out_feat, kernel_size, stride, pad)
     if spectral_norm:
@@ -58,8 +83,8 @@ def forward(self, x, y=None, bbox=None):
         # obj path
         # seperate different path
         s_idx = ((bbox[:, 3] - bbox[:, 1]) < 64) * ((bbox[:, 4] - bbox[:, 2]) < 64)
-        bbox_l, bbox_s = bbox[1-s_idx], bbox[s_idx]
-        y_l, y_s = y[1-s_idx], y[s_idx]
+        bbox_l, bbox_s = bbox[~s_idx], bbox[s_idx]
+        y_l, y_s = y[~s_idx], y[s_idx]
 
         obj_feat_s = self.block_obj3(x1)
         obj_feat_s = self.block_obj4(obj_feat_s)