From e43524e3d6238794eda089a4a48f706917e57114 Mon Sep 17 00:00:00 2001
From: Andrzej Uszok
Date: Tue, 31 Oct 2023 22:41:10 -0500
Subject: [PATCH] Update to GBI, datanode id

---
 domiknows/graph/dataNode.py    |  4 ++++
 domiknows/program/model/gbi.py | 39 +++++++++++++++++-----------------
 2 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/domiknows/graph/dataNode.py b/domiknows/graph/dataNode.py
index a7cf29b2..95792392 100644
--- a/domiknows/graph/dataNode.py
+++ b/domiknows/graph/dataNode.py
@@ -2,6 +2,7 @@
 from collections import OrderedDict, namedtuple
 from time import perf_counter, perf_counter_ns
 import re
+from itertools import count
 
 from .dataNodeConfig import dnConfig
 
@@ -87,6 +88,8 @@ class DataNode:
         - gurobiModel (NoneType): Placeholder for Gurobi model.
         - myLoggerTime (Logger): Logger for time measurement.
     """
+    _ids = count(1)
+
     def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, ontologyNode = None, graph = None, relationLinks = {}, attributes = {}):
         """Initialize a DataNode instance.
 
@@ -112,6 +115,7 @@ def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, on
             gurobiModel (NoneType): Placeholder for Gurobi model.
             myLoggerTime (Logger): Logger for time measurement.
         """
+        self.id = next(self._ids)
         self.myBuilder = myBuilder # DatanodeBuilder used to construct this datanode
         self.instanceID = instanceID # The data instance id (e.g. paragraph number, sentence number, phrase number, image number, etc.)
         self.instanceValue = instanceValue # Optional value of the instance (e.g. paragraph text, sentence text, phrase text, image bitmap, etc.)
diff --git a/domiknows/program/model/gbi.py b/domiknows/program/model/gbi.py
index 460adcd3..9b8da721 100644
--- a/domiknows/program/model/gbi.py
+++ b/domiknows/program/model/gbi.py
@@ -167,12 +167,14 @@ def forward(self, datanode, build=None, print_grads=False):
                 probs.append(F.log_softmax(var_val, dim=-1).flatten())
 
             log_probs = torch.cat(probs, dim=0).mean()
-            print('probs mean:')
-            print(log_probs)
+            if print_grads:
+                print('probs mean:')
+                print(log_probs)
 
-            argmax_vals = [torch.argmax(prob) for prob in probs]
-            print('argmax predictions:')
-            print(argmax_vals)
+            if print_grads:
+                argmax_vals = [torch.argmax(prob) for prob in probs]
+                print('argmax predictions:')
+                print(argmax_vals)
 
             # -- Constraint loss: NLL * binary satisfaction + regularization loss
             # reg loss is calculated based on L2 distance of weights between optimized model and original weights
@@ -181,20 +183,18 @@ def forward(self, datanode, build=None, print_grads=False):
 
             if c_loss != c_loss:
                 continue
-            
-            print("iter={}, c_loss={:.2f}, c_loss.grad_fn={}, num_constraints_l={}, satisfied={}".format(c_iter, c_loss.item(), c_loss.grad_fn.__class__.__name__, num_constraints_l, num_satisfied_l))
-            print("reg_loss={:.2f}, reg_loss.grad_fn={}, log_probs={:.2f}, log_probs.grad_fn={}\n".format(reg_loss.item(), reg_loss.grad_fn.__class__.__name__, log_probs.item(), log_probs.grad_fn.__class__.__name__))
+            if print_grads:
+                print("iter={}, c_loss={:.2f}, c_loss.grad_fn={}, num_constraints_l={}, satisfied={}".format(c_iter, c_loss.item(), c_loss.grad_fn.__class__.__name__, num_constraints_l, num_satisfied_l))
+                print("reg_loss={:.2f}, reg_loss.grad_fn={}, log_probs={:.2f}, log_probs.grad_fn={}\n".format(reg_loss.item(), reg_loss.grad_fn.__class__.__name__, log_probs.item(), log_probs.grad_fn.__class__.__name__))
 
             # --- Check if constraints are satisfied
             if num_satisfied_l == num_constraints_l:
                 # --- End early if constraints are satisfied
                 if model_has_GBI_inference:
                     self.server_model.inferTypes.append('GBI')
-                return c_loss, datanode, datanode.myBuilder
-            elif no_of_not_satisfied > 1000: # temporary change to see behavior c_iter loop finishes
-                if model_has_GBI_inference:
-                    self.server_model.inferTypes.append('GBI')
-                return c_loss, datanode, datanode.myBuilder # ? float("nan")
+
+                print(f'Finishing GBI - Constraints are satisfied after {c_iter} iterations')
+                return c_loss, node_l, node_l.myBuilder
 
             # --- Backward pass on self.server_model
             if c_loss.requires_grad:
@@ -215,19 +215,17 @@ def forward(self, datanode, build=None, print_grads=False):
                 c_opt.step()
 
                 if print_grads:
-                    # Print the params of the model parameters which have grad 
+                    # Print the params of the model parameters which have grad
                     print("Params after model step which have grad")
                     for name, param in self.server_model.named_parameters():
                         if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
                             print(name, 'param sum ', torch.sum(torch.abs(param)).item())
-
-        node_l_builder = None
-        if node_l is not None:
-            node_l_builder = node_l.myBuilder
 
         if model_has_GBI_inference:
             self.server_model.inferTypes.append('GBI')
-        return c_loss, node_l, node_l_builder # ? float("nan")
+
+        print(f'Finishing GBI - Constraints are not satisfied after {self.gbi_iters} iterations')
+        return c_loss, node_l, node_l.myBuilder
 
     def calculateGBISelection(self, datanode, conceptsRelations):
         c_loss, updatedDatanode, updatedBuilder = self.forward(datanode)
@@ -249,12 +247,14 @@ def calculateGBISelection(self, datanode, conceptsRelations):
 
             for i, (dn, originalDn) in enumerate(zip(dns, originalDns)):
                 v = dn.getAttribute(currentConcept) # Get learned probabilities
+                #print(f'Net(depth={i}inGBI); pred: {torch.argmax(v, dim=-1)}')
                 if v is None:
                     continue
 
                 # Calculate GBI results
                 vGBI = torch.zeros(v.size(), dtype=torch.float, device=updatedDatanode.current_device)
                 vArgmaxIndex = torch.argmax(v).item()
+                #print(f'vArgmaxIndex: {vArgmaxIndex}')
                 vGBI[vArgmaxIndex] = 1
 
                 # Add GBI inference result to the original datanode
@@ -272,3 +272,4 @@ def calculateGBISelection(self, datanode, conceptsRelations):
 
                 datanode.attributes["variableSet"][keyGBIInVariableSet]=gbiForConcept
                 updatedDatanode.attributes["variableSet"][keyGBIInVariableSet]=gbiForConcept
+        return
\ No newline at end of file
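
Note on the new DataNode id: below is a minimal, self-contained sketch of the class-level counter pattern the dataNode.py hunks introduce, using itertools.count so that every instance constructed in a process receives a distinct, monotonically increasing id, independent of instanceID. The class name Node is a stand-in used only for illustration, not the actual DataNode class.

    from itertools import count

    class Node:
        # Shared, class-level iterator: all instances draw from the same
        # counter, which starts at 1 (matching _ids = count(1) in the patch).
        _ids = count(1)

        def __init__(self, instanceID=None):
            # next() advances the shared counter once per constructed instance.
            self.id = next(self._ids)
            self.instanceID = instanceID

    a, b, c = Node(), Node(), Node()
    print(a.id, b.id, c.id)  # -> 1 2 3

Because _ids lives on the class rather than on the instance, the id is unique within a single process run; it is not persisted across runs and is separate from instanceID, which still carries the data-level index (paragraph number, sentence number, etc.).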