
Merge pull request #386 from HLR/gbi_tests
GBI minor algorithm changes
auszok authored Sep 20, 2023
2 parents ae63fa2 + 38f3d14 commit 8d05d35
Showing 1 changed file with 27 additions and 38 deletions.
domiknows/program/model/gbi.py
@@ -108,12 +108,12 @@ def find_last_layers_in_submodels(self, model, name=""):
     def reset_last_layers_in_submodels(self, model, last_layers):
         for name, last_layer in last_layers.items():
             if isinstance(last_layer, (nn.Linear, nn.Conv2d)):
-                last_layer.bias.data += torch.randn_like(last_layer.bias.data) * 0.01
-                last_layer.weight.data += torch.randn_like(last_layer.weight.data) * 0.01
+                last_layer.bias.data += torch.randn_like(last_layer.bias.data) * 1e-4
+                last_layer.weight.data += torch.randn_like(last_layer.weight.data) * 1e-4
 
     # ----
 
-    def forward(self, datanode, build=None):
+    def forward(self, datanode, build=None, print_grads=False):
         # Get constraint satisfaction for the current DataNode
         num_satisfied, num_constraints = self.get_constraints_satisfaction(datanode)
         model_has_GBI_inference = False
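The hunk above shrinks the Gaussian restart noise from 0.01 to 1e-4 and adds an opt-in print_grads flag to forward. Below is a minimal sketch of the perturbation step, assuming a dict of final nn.Linear/nn.Conv2d layers like the one find_last_layers_in_submodels returns (perturb_last_layers and its scale argument are illustrative names, not DomiKnowS API):

    import torch
    import torch.nn as nn

    def perturb_last_layers(last_layers, scale=1e-4):
        # Nudge each final Linear/Conv2d layer with small Gaussian noise;
        # the commit lowers the scale from 0.01 to 1e-4, so restarts stay
        # close to the trained weights instead of overwriting them.
        for name, layer in last_layers.items():
            if isinstance(layer, (nn.Linear, nn.Conv2d)):
                layer.weight.data += torch.randn_like(layer.weight.data) * scale
                if layer.bias is not None:
                    layer.bias.data += torch.randn_like(layer.bias.data) * scale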
@@ -144,7 +144,7 @@ def forward(self, datanode, build=None):
         self.server_model.reset()
 
         modelLParams = self.server_model.parameters()
-        c_opt = Adam(modelLParams, lr=1e-1, betas=[0.9, 0.999], eps=1e-07, amsgrad=False) #SGD(modelLParams, lr=1e-1)
+        c_opt = Adam(modelLParams, lr=1e-2, betas=[0.9, 0.999], eps=1e-07, amsgrad=False) #SGD(modelLParams, lr=1e-1)
 
         # Remove "GBI" from the list of inference types if model has it
         if hasattr(self.server_model, 'inferTypes'):
@@ -160,37 +160,24 @@ def forward(self, datanode, build=None):
 
             num_satisfied_l, num_constraints_l = self.get_constraints_satisfaction(node_l)
 
-            if num_satisfied_l == num_constraints_l:
-                is_satisfied = 1
-                no_of_not_satisfied = 0
-            else:
-                is_satisfied = 0
-                no_of_not_satisfied += 1
-
+            is_satisfied = num_satisfied_l/num_constraints_l
             # -- collect probs from datanode (in skeleton mode)
-            probs = {}
+            probs = []
             for var_name, var_val in node_l.getAttribute('variableSet').items():
                 if var_name.endswith('>'):# and var_val.requires_grad:
-                    probs[var_name] = torch.sum(F.log_softmax(var_val, dim=-1))
+                    probs.append(F.log_softmax(var_val, dim=-1).flatten())
 
-            # print probs with the keys
-            print("probs:")
-            for key, value in probs.items():
-                print(key, value.item())
-
-            # get total log prob
-            log_probs = 0.0
-            for c_prob in probs.values():
-                eps = 1e-7
-                t = F.relu(c_prob)
-                tLog = torch.log(t + eps)
-                log_probs += torch.sum(tLog)
+            log_probs = torch.cat(probs, dim=0).mean()
+            print('probs mean:')
+            print(log_probs)
+
+            argmax_vals = [torch.argmax(prob) for prob in probs]
+            print('argmax predictions:')
+            print(argmax_vals)
 
             # -- Constraint loss: NLL * binary satisfaction + regularization loss
             # reg loss is calculated based on L2 distance of weights between optimized model and original weights
             reg_loss = self.reg_loss(optimized_parameters, original_parameters)
-            c_loss = -1 * log_probs * is_satisfied + reg_loss
+            c_loss = log_probs * ((num_constraints_l - num_satisfied_l) / num_constraints_l) + reg_loss
 
             if c_loss != c_loss:
                 continue
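The loss itself changes shape in this hunk: per-variable summed log-probabilities gated by a binary is_satisfied flag are replaced by the mean log-softmax over all candidate variables, scaled by the fraction of violated constraints, so the data term drops to zero once every constraint holds and only the regularizer remains. A self-contained sketch of the new formulation (gbi_constraint_loss and its argument names are hypothetical; the L2 form of reg_loss follows the comment in the diff):

    import torch
    import torch.nn.functional as F

    def gbi_constraint_loss(raw_scores, num_satisfied, num_constraints,
                            optimized_params, original_params):
        # Mean log-probability over all candidate variables, mirroring
        # log_probs = torch.cat(probs, dim=0).mean() in the new code.
        probs = [F.log_softmax(v, dim=-1).flatten() for v in raw_scores]
        log_probs = torch.cat(probs, dim=0).mean()

        # Fraction of violated constraints; zero when all are satisfied.
        unsat_frac = (num_constraints - num_satisfied) / num_constraints

        # L2 distance between the optimized and original weights, per the
        # "reg loss" comment in the diff.
        reg_loss = sum(torch.sum((p - q) ** 2)
                       for p, q in zip(optimized_params, original_params))

        return log_probs * unsat_frac + reg_loss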
@@ -204,7 +191,7 @@ def forward(self, datanode, build=None):
                 if model_has_GBI_inference:
                     self.server_model.inferTypes.append('GBI')
                 return c_loss, datanode, datanode.myBuilder
-            elif no_of_not_satisfied > 10: # three consecutive iterations where constraints are not satisfied
+            elif no_of_not_satisfied > 1000: # temporary change to see behavior when the c_iter loop finishes
                 if model_has_GBI_inference:
                     self.server_model.inferTypes.append('GBI')
                 return c_loss, datanode, datanode.myBuilder # ? float("nan")
@@ -217,20 +204,22 @@ def forward(self, datanode, build=None):
             # Compute gradients
             c_loss.backward()
 
-            # Print the params of the model parameters which have grad
-            print("Params before model step which have grad")
-            for name, param in self.server_model.named_parameters():
-                if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
-                    print(name, 'param sum ', torch.sum(torch.abs(param)).item())
+            if print_grads:
+                # Print the model parameters which have grad
+                print("Params before model step which have grad")
+                for name, param in self.server_model.named_parameters():
+                    if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
+                        print(name, 'param sum ', torch.sum(torch.abs(param)).item())
 
             # Update self.server_model params based on gradients
             c_opt.step()
 
-            # Print the params of the model parameters which have grad
-            print("Params after model step which have grad")
-            for name, param in self.server_model.named_parameters():
-                if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
-                    print(name, 'param sum ', torch.sum(torch.abs(param)).item())
+            if print_grads:
+                # Print the model parameters which have grad
+                print("Params after model step which have grad")
+                for name, param in self.server_model.named_parameters():
+                    if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
+                        print(name, 'param sum ', torch.sum(torch.abs(param)).item())
 
             node_l_builder = None
             if node_l is not None:
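With print_grads defaulting to False, the per-step parameter dumps above are now opt-in. A hypothetical call site that restores the old verbose behavior (gbi_model and datanode are assumed to already exist):

    # Debugging a GBI run: request the per-step parameter summaries.
    c_loss, datanode, builder = gbi_model.forward(datanode, print_grads=True)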
