Skip to content

Commit

Permalink
Update to GBI, datanode id
Browse files — browse the repository at this point in the history
  • Loading branch information
Andrzej Uszok committed Nov 1, 2023
1 parent 6c310c5 commit e43524e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
4 changes: 4 additions & 0 deletions domiknows/graph/dataNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict, namedtuple
from time import perf_counter, perf_counter_ns
import re
from itertools import count

from .dataNodeConfig import dnConfig

Expand Down Expand Up @@ -87,6 +88,8 @@ class DataNode:
- gurobiModel (NoneType): Placeholder for Gurobi model.
- myLoggerTime (Logger): Logger for time measurement.
"""
_ids = count(1)

def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, ontologyNode = None, graph = None, relationLinks = {}, attributes = {}):
"""Initialize a DataNode instance.
Expand All @@ -112,6 +115,7 @@ def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, on
gurobiModel (NoneType): Placeholder for Gurobi model.
myLoggerTime (Logger): Logger for time measurement.
"""
self.id = next(self._ids)
self.myBuilder = myBuilder # DatanodeBuilder used to construct this datanode
self.instanceID = instanceID # The data instance id (e.g. paragraph number, sentence number, phrase number, image number, etc.)
self.instanceValue = instanceValue # Optional value of the instance (e.g. paragraph text, sentence text, phrase text, image bitmap, etc.)
Expand Down
39 changes: 20 additions & 19 deletions domiknows/program/model/gbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ def forward(self, datanode, build=None, print_grads=False):
probs.append(F.log_softmax(var_val, dim=-1).flatten())

log_probs = torch.cat(probs, dim=0).mean()
print('probs mean:')
print(log_probs)
if print_grads:
print('probs mean:')
print(log_probs)

argmax_vals = [torch.argmax(prob) for prob in probs]
print('argmax predictions:')
print(argmax_vals)
if print_grads:
argmax_vals = [torch.argmax(prob) for prob in probs]
print('argmax predictions:')
print(argmax_vals)

# -- Constraint loss: NLL * binary satisfaction + regularization loss
# reg loss is calculated based on L2 distance of weights between optimized model and original weights
Expand All @@ -181,20 +183,18 @@ def forward(self, datanode, build=None, print_grads=False):

if c_loss != c_loss:
continue

print("iter={}, c_loss={:.2f}, c_loss.grad_fn={}, num_constraints_l={}, satisfied={}".format(c_iter, c_loss.item(), c_loss.grad_fn.__class__.__name__, num_constraints_l, num_satisfied_l))
print("reg_loss={:.2f}, reg_loss.grad_fn={}, log_probs={:.2f}, log_probs.grad_fn={}\n".format(reg_loss.item(), reg_loss.grad_fn.__class__.__name__, log_probs.item(), log_probs.grad_fn.__class__.__name__))
if print_grads:
print("iter={}, c_loss={:.2f}, c_loss.grad_fn={}, num_constraints_l={}, satisfied={}".format(c_iter, c_loss.item(), c_loss.grad_fn.__class__.__name__, num_constraints_l, num_satisfied_l))
print("reg_loss={:.2f}, reg_loss.grad_fn={}, log_probs={:.2f}, log_probs.grad_fn={}\n".format(reg_loss.item(), reg_loss.grad_fn.__class__.__name__, log_probs.item(), log_probs.grad_fn.__class__.__name__))

# --- Check if constraints are satisfied
if num_satisfied_l == num_constraints_l:
# --- End early if constraints are satisfied
if model_has_GBI_inference:
self.server_model.inferTypes.append('GBI')
return c_loss, datanode, datanode.myBuilder
elif no_of_not_satisfied > 1000: # temporary change to see behavior c_iter loop finishes
if model_has_GBI_inference:
self.server_model.inferTypes.append('GBI')
return c_loss, datanode, datanode.myBuilder # ? float("nan")

print(f'Finishing GBI - Constraints are satisfied after {c_iter} iteration')
return c_loss, node_l, node_l.myBuilder

# --- Backward pass on self.server_model
if c_loss.requires_grad:
Expand All @@ -215,19 +215,17 @@ def forward(self, datanode, build=None, print_grads=False):
c_opt.step()

if print_grads:
# Print the params of the model parameters which have grad
# Print the params of the model parameters which have grad
print("Params after model step which have grad")
for name, param in self.server_model.named_parameters():
if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
print(name, 'param sum ', torch.sum(torch.abs(param)).item())

node_l_builder = None
if node_l is not None:
node_l_builder = node_l.myBuilder

if model_has_GBI_inference:
self.server_model.inferTypes.append('GBI')
return c_loss, node_l, node_l_builder # ? float("nan")

print(f'Finishing GBI - Constraints not are satisfied after {self.gbi_iters} iteration')
return c_loss, node_l, node_l.myBuilder

def calculateGBISelection(self, datanode, conceptsRelations):
c_loss, updatedDatanode, updatedBuilder = self.forward(datanode)
Expand All @@ -249,12 +247,14 @@ def calculateGBISelection(self, datanode, conceptsRelations):

for i, (dn, originalDn) in enumerate(zip(dns, originalDns)):
v = dn.getAttribute(currentConcept) # Get learned probabilities
#print(f'Net(depth={i}inGBI); pred: {torch.argmax(v, dim=-1)}')
if v is None:
continue

# Calculate GBI results
vGBI = torch.zeros(v.size(), dtype=torch.float, device=updatedDatanode.current_device)
vArgmaxIndex = torch.argmax(v).item()
#print(f'vArgmaxIndex: {vArgmaxIndex}')
vGBI[vArgmaxIndex] = 1

# Add GBI inference result to the original datanode
Expand All @@ -272,3 +272,4 @@ def calculateGBISelection(self, datanode, conceptsRelations):
datanode.attributes["variableSet"][keyGBIInVariableSet]=gbiForConcept
updatedDatanode.attributes["variableSet"][keyGBIInVariableSet]=gbiForConcept

return

0 comments on commit e43524e

Please sign in to comment.