From 74ea46a3fc2f4ac4ddabf7454f7aadddd80f5c1e Mon Sep 17 00:00:00 2001 From: xiaolinf Date: Tue, 15 Dec 2020 22:31:45 -0500 Subject: [PATCH] add MSN-PCN modules --- model/model_blocks.py | 84 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/model/model_blocks.py b/model/model_blocks.py index 16bf14a..be71191 100644 --- a/model/model_blocks.py +++ b/model/model_blocks.py @@ -103,3 +103,87 @@ def forward(self, x, latent): for i in range(self.opt.num_layers): x = self.activation(self.bn_list[i](self.conv_list[i](x))) return self.last_conv(x) + + +# Modules from MSN: https://github.com/Colin97/MSN-Point-Cloud-Completion +class PointNetfeat(nn.Module): + def __init__(self, num_points = 8192, global_feat = True): + super(PointNetfeat, self).__init__() + self.conv1 = torch.nn.Conv1d(3, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + + self.bn1 = torch.nn.BatchNorm1d(64) + self.bn2 = torch.nn.BatchNorm1d(128) + self.bn3 = torch.nn.BatchNorm1d(1024) + + self.num_points = num_points + self.global_feat = global_feat + def forward(self, x): + batchsize = x.size()[0] + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = self.bn3(self.conv3(x)) + x,_ = torch.max(x, 2) + x = x.view(-1, 1024) + return x + +class PointGenCon(nn.Module): + def __init__(self, bottleneck_size = 8192): + self.bottleneck_size = bottleneck_size + super(PointGenCon, self).__init__() + self.conv1 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size, 1) + self.conv2 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size//2, 1) + self.conv3 = torch.nn.Conv1d(self.bottleneck_size//2, self.bottleneck_size//4, 1) + self.conv4 = torch.nn.Conv1d(self.bottleneck_size//4, 3, 1) + + self.th = nn.Tanh() + self.bn1 = torch.nn.BatchNorm1d(self.bottleneck_size) + self.bn2 = torch.nn.BatchNorm1d(self.bottleneck_size//2) + self.bn3 = torch.nn.BatchNorm1d(self.bottleneck_size//4) + + def forward(self, x): + batchsize = x.size()[0] + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = self.th(self.conv4(x)) + return x + +class PointNetRes(nn.Module): + def __init__(self): + super(PointNetRes, self).__init__() + self.conv1 = torch.nn.Conv1d(4, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.conv4 = torch.nn.Conv1d(1088, 512, 1) + self.conv5 = torch.nn.Conv1d(512, 256, 1) + self.conv6 = torch.nn.Conv1d(256, 128, 1) + self.conv7 = torch.nn.Conv1d(128, 3, 1) + + + self.bn1 = torch.nn.BatchNorm1d(64) + self.bn2 = torch.nn.BatchNorm1d(128) + self.bn3 = torch.nn.BatchNorm1d(1024) + self.bn4 = torch.nn.BatchNorm1d(512) + self.bn5 = torch.nn.BatchNorm1d(256) + self.bn6 = torch.nn.BatchNorm1d(128) + self.bn7 = torch.nn.BatchNorm1d(3) + self.th = nn.Tanh() + + def forward(self, x): + batchsize = x.size()[0] + npoints = x.size()[2] + x = F.relu(self.bn1(self.conv1(x))) + pointfeat = x + x = F.relu(self.bn2(self.conv2(x))) + x = self.bn3(self.conv3(x)) + x,_ = torch.max(x, 2) + x = x.view(-1, 1024) + x = x.view(-1, 1024, 1).repeat(1, 1, npoints) + x = torch.cat([x, pointfeat], 1) + x = F.relu(self.bn4(self.conv4(x))) + x = F.relu(self.bn5(self.conv5(x))) + x = F.relu(self.bn6(self.conv6(x))) + x = self.th(self.conv7(x)) + return x