fixed training bug for learnable SGT
kyang-06 committed Dec 19, 2023
1 parent 7e0c315 commit f151b11
Showing 3 changed files with 67 additions and 5 deletions.
src/base_modules.py (3 changes: 2 additions & 1 deletion)
@@ -39,7 +39,8 @@ def get_lifting_model(opt):
model = AutoDynamicGridLiftingNetwork(hidden_size=opt.hidsize,
num_block=opt.num_block,
grid_shape=opt.grid_shape,
padding_mode=opt.padding_mode)
padding_mode=opt.padding_mode,
autosgt_prior=opt.autosgt_prior)
else:
raise Exception('Unexpected argument, %s' % opt.lifting_model)
model = model.cuda()
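For context, a minimal sketch of the updated constructor call, assuming the network class is importable from src/network/dgridconv_autogrids.py and that its remaining parameters (e.g. the joint count) keep their defaults; the opt namespace and the hidsize value are illustrative stand-ins for the project's parsed options:

from types import SimpleNamespace

from network.dgridconv_autogrids import AutoDynamicGridLiftingNetwork  # assumed import path

# Hypothetical options object; hidsize=1024 is an assumed value, the other
# defaults mirror the flags defined in src/tool/argument.py.
opt = SimpleNamespace(hidsize=1024, num_block=2, grid_shape=[5, 5],
                      padding_mode=['c', 'r'], autosgt_prior='standard')

model = AutoDynamicGridLiftingNetwork(hidden_size=opt.hidsize,
                                      num_block=opt.num_block,
                                      grid_shape=opt.grid_shape,
                                      padding_mode=opt.padding_mode,
                                      autosgt_prior=opt.autosgt_prior)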
src/network/dgridconv_autogrids.py (67 changes: 63 additions & 4 deletions)
@@ -22,7 +22,8 @@ def __init__(self,
output_dim = 3,
temperature=30,
grid_shape=(5,5),
padding_mode=('c','z')):
padding_mode=('c','z'),
autosgt_prior='standard'):
super(AutoDynamicGridLiftingNetwork, self).__init__()

self.linear_size = hidden_size
@@ -56,7 +57,7 @@ def __init__(self,
self.relu = ReLU(inplace=True)

self.grid_shape = list(grid_shape)
self.sgt_layer = AutoSGT(num_jts=num_jts, grid_shape=grid_shape)
self.sgt_layer = AutoSGT(num_jts=num_jts, grid_shape=grid_shape, autosgt_prior=autosgt_prior)

def net_update_temperature(self, temperature):
for m in self.modules():
@@ -240,13 +241,13 @@ def forward(self, x):


class AutoSGT(nn.Module):
def __init__(self, num_jts, grid_shape):
def __init__(self, num_jts, grid_shape, autosgt_prior):
super(AutoSGT, self).__init__()
self.grid_shape = grid_shape
self.J = num_jts
self.HW = np.prod(grid_shape)

self.register_parameter('sgt_trans_mat', torch.nn.Parameter(torch.rand(1, np.prod(grid_shape), num_jts)))
self.register_parameter('sgt_trans_mat', torch.nn.Parameter(self.init_sgt_prior(autosgt_prior)))

def forward(self, use_gumbel_noise, gumbel_temp, is_training=False):
sgt_trans_mat = self.sgt_trans_mat
@@ -262,3 +263,61 @@ def forward(self, use_gumbel_noise, gumbel_temp, is_training=False):
sgt_trans_mat_hard = F.one_hot(torch.argmax(sgt_trans_mat, -1)).float()

return sgt_trans_mat_hard

def init_sgt_prior(self, prior_type):
assert self.J == 17 and self.HW == 25
if prior_type == 'standard':
prior_sgt_mat = torch.zeros(self.grid_shape + [self.J])
# row 0
prior_sgt_mat[0, :, 7] = 1
# row 1
prior_sgt_mat[1, [0, -1], 0] = 1
prior_sgt_mat[1, [1, 2, 3], 8] = 1
# row 2
prior_sgt_mat[2, 0, 4] = 1
prior_sgt_mat[2, 1, 11] = 1
prior_sgt_mat[2, 2, 9] = 1
prior_sgt_mat[2, 3, 14] = 1
prior_sgt_mat[2, 4, 1] = 1
# row 3
prior_sgt_mat[3, 0, 5] = 1
prior_sgt_mat[3, 1, 12] = 1
prior_sgt_mat[3, 2, 9] = 1
prior_sgt_mat[3, 3, 15] = 1
prior_sgt_mat[3, 4, 2] = 1
# row 4
prior_sgt_mat[4, 0, 6] = 1
prior_sgt_mat[4, 1, 13] = 1
prior_sgt_mat[4, 2, 10] = 1
prior_sgt_mat[4, 3, 16] = 1
prior_sgt_mat[4, 4, 3] = 1
prior_sgt_mat = prior_sgt_mat.reshape(1, self.HW, self.J)
elif prior_type == 'learnt_type1':
prior_sgt_mat = torch.LongTensor([[7,4,7,1,0,
0,8,8,8,0,
4,11,9,14,1,
5,12,9,15,2,
6,13,10,16,3]])
prior_sgt_mat = F.one_hot(prior_sgt_mat, num_classes=self.J).float() # 1*self.HW*self.J
elif prior_type == 'learnt_type2':
prior_sgt_mat = torch.LongTensor([[0,15,7,1,0,
1,14,8,7,0,
4,0,9,13,1,
2,6,11,10,2,
5,12,14,16,3]])
prior_sgt_mat = F.one_hot(prior_sgt_mat, num_classes=self.J).float() # 1*self.HW*self.J
elif prior_type == 'learnt_type3':
prior_sgt_mat = torch.LongTensor([[9,7,7,10,7,
13,8,10,15,16,
9,12,7,14,1,
4,5,7,3,11,
7,6,9,2,14]])
prior_sgt_mat = F.one_hot(prior_sgt_mat, num_classes=self.J).float() # 1*self.HW*self.J
elif prior_type == 'random_prob':
prior_sgt_mat = torch.rand([self.HW, self.J])
prior_sgt_mat = F.softmax(prior_sgt_mat, dim=-1).unsqueeze(0)
else:
raise Exception()

return prior_sgt_mat
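
Since sgt_trans_mat is now initialized from init_sgt_prior rather than torch.rand, a quick sanity check (not part of the commit; the import path below is assumed) is that the 'standard' prior yields a learnable 1 x 25 x 17 matrix with exactly one joint selected per grid cell, and that the hard assignment picks one joint per cell:

import torch
import torch.nn.functional as F

from network.dgridconv_autogrids import AutoSGT  # assumed import path

sgt = AutoSGT(num_jts=17, grid_shape=[5, 5], autosgt_prior='standard')
prior = sgt.sgt_trans_mat                      # nn.Parameter of shape (1, 25, 17)
assert tuple(prior.shape) == (1, 25, 17)
assert torch.all(prior.sum(-1) == 1)           # each grid cell selects exactly one joint

# Hard grid-to-joint assignment, analogous to the non-training branch of forward():
hard = F.one_hot(prior.argmax(-1), num_classes=17).float()
print(hard.shape)                              # torch.Size([1, 25, 17])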

src/tool/argument.py (2 changes: 2 additions & 0 deletions)
@@ -42,6 +42,7 @@ def _initial(self):
self.parser.add_argument('--num_block', type=int, default=2, help='number of residual blocks')
self.parser.add_argument('--padding_mode', type=str, nargs='+', default=['c','r'])
self.parser.add_argument('--grid_shape', type=int, nargs='+', default=[5, 5])
self.parser.add_argument('--autosgt_prior', type=str, default='standard')


def _print(self):
@@ -55,6 +56,7 @@ def parse(self):
ckpt = os.path.join(self.opt.ckpt, self.opt.exp)
if not os.path.isdir(ckpt):
os.makedirs(ckpt)
self.opt.ckpt = ckpt
self.opt.prepare_grid = self.opt.lifting_model in ['gridconv', 'dgridconv']
self._print()

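For reference, the prior types accepted by init_sgt_prior above can be exposed through the new flag; the stand-alone parser below only mimics the project's option definitions, and the explicit choices list is added here for illustration:

import argparse

# Minimal parser mirroring the grid-related flags from src/tool/argument.py;
# the choices are derived from the branches of AutoSGT.init_sgt_prior.
parser = argparse.ArgumentParser()
parser.add_argument('--num_block', type=int, default=2, help='number of residual blocks')
parser.add_argument('--padding_mode', type=str, nargs='+', default=['c', 'r'])
parser.add_argument('--grid_shape', type=int, nargs='+', default=[5, 5])
parser.add_argument('--autosgt_prior', type=str, default='standard',
                    choices=['standard', 'learnt_type1', 'learnt_type2',
                             'learnt_type3', 'random_prob'])

opt = parser.parse_args(['--autosgt_prior', 'learnt_type1'])
print(opt.autosgt_prior)  # learnt_type1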
