Skip to content

Commit

Permalink
add MSN-PCN modules
Browse files Browse the repository at this point in the history
  • Loading branch information
FANG-Xiaolin committed Dec 16, 2020
1 parent cfaf61a commit 74ea46a
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions model/model_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,87 @@ def forward(self, x, latent):
for i in range(self.opt.num_layers):
x = self.activation(self.bn_list[i](self.conv_list[i](x)))
return self.last_conv(x)


# Modules from MSN: https://github.com/Colin97/MSN-Point-Cloud-Completion
class PointNetfeat(nn.Module):
def __init__(self, num_points = 8192, global_feat = True):
super(PointNetfeat, self).__init__()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)

self.bn1 = torch.nn.BatchNorm1d(64)
self.bn2 = torch.nn.BatchNorm1d(128)
self.bn3 = torch.nn.BatchNorm1d(1024)

self.num_points = num_points
self.global_feat = global_feat
def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x,_ = torch.max(x, 2)
x = x.view(-1, 1024)
return x

class PointGenCon(nn.Module):
def __init__(self, bottleneck_size = 8192):
self.bottleneck_size = bottleneck_size
super(PointGenCon, self).__init__()
self.conv1 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size, 1)
self.conv2 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size//2, 1)
self.conv3 = torch.nn.Conv1d(self.bottleneck_size//2, self.bottleneck_size//4, 1)
self.conv4 = torch.nn.Conv1d(self.bottleneck_size//4, 3, 1)

self.th = nn.Tanh()
self.bn1 = torch.nn.BatchNorm1d(self.bottleneck_size)
self.bn2 = torch.nn.BatchNorm1d(self.bottleneck_size//2)
self.bn3 = torch.nn.BatchNorm1d(self.bottleneck_size//4)

def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.th(self.conv4(x))
return x

class PointNetRes(nn.Module):
def __init__(self):
super(PointNetRes, self).__init__()
self.conv1 = torch.nn.Conv1d(4, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.conv4 = torch.nn.Conv1d(1088, 512, 1)
self.conv5 = torch.nn.Conv1d(512, 256, 1)
self.conv6 = torch.nn.Conv1d(256, 128, 1)
self.conv7 = torch.nn.Conv1d(128, 3, 1)


self.bn1 = torch.nn.BatchNorm1d(64)
self.bn2 = torch.nn.BatchNorm1d(128)
self.bn3 = torch.nn.BatchNorm1d(1024)
self.bn4 = torch.nn.BatchNorm1d(512)
self.bn5 = torch.nn.BatchNorm1d(256)
self.bn6 = torch.nn.BatchNorm1d(128)
self.bn7 = torch.nn.BatchNorm1d(3)
self.th = nn.Tanh()

def forward(self, x):
batchsize = x.size()[0]
npoints = x.size()[2]
x = F.relu(self.bn1(self.conv1(x)))
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x,_ = torch.max(x, 2)
x = x.view(-1, 1024)
x = x.view(-1, 1024, 1).repeat(1, 1, npoints)
x = torch.cat([x, pointfeat], 1)
x = F.relu(self.bn4(self.conv4(x)))
x = F.relu(self.bn5(self.conv5(x)))
x = F.relu(self.bn6(self.conv6(x)))
x = self.th(self.conv7(x))
return x

0 comments on commit 74ea46a

Please sign in to comment.