Skip to content

Commit

Permalink
Reduce the number of warning prints of heterogeneous GNN encoders whe…
Browse files Browse the repository at this point in the history
…n some nodes do not have in edges (#888)

*Issue #, if available:*

*Description of changes:*
To avoid excessive warning logs from RGCN, RGAT, and HGT encoders, this
PR ensures that identical warnings (warnings that share the same
message) are printed only once.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Jun 20, 2024
1 parent 63d3051 commit 9d6610b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 9 deletions.
22 changes: 19 additions & 3 deletions python/graphstorm/model/hgt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ def __init__(self,

# Dropout
self.drop = nn.Dropout(dropout)
self.warn_msg = set()

def warning_once(self, warn_msg):
""" Print same warning msg only once
Parameters
----------
warn_msg: str
Warning message
"""
if warn_msg in self.warn_msg:
# Skip printing warning
return
self.warn_msg.add(warn_msg)
logging.warning(warn_msg)

def forward(self, g, h):
"""Forward computation
Expand Down Expand Up @@ -255,9 +270,10 @@ def forward(self, g, h):
else:
trans_out = trans_out * alpha + self.a_linears[k](h[k]) * (1-alpha)
else: # Nodes not really in destination side.
logging.warning("Warning. Graph convolution returned empty " + \
f"dictionary for nodes in type: {str(k)}. Please check your data" + \
f" for no in-degree nodes in type: {str(k)}.")
warn_msg = "Warning. Graph convolution returned empty " \
f"dictionary for nodes in type: {str(k)}. Please check your data" \
f" for no in-degree nodes in type: {str(k)}."
self.warning_once(warn_msg)
# So add psudo self-loop for the destination nodes with its own feature.
dst_h = self.a_linears[k](h[k][:g.num_dst_nodes(k)])
trans_out = self.drop(dst_h)
Expand Down
22 changes: 19 additions & 3 deletions python/graphstorm/model/rgat_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,21 @@ def __init__(self,

# dropout
self.dropout = nn.Dropout(dropout)
self.warn_msg = set()

def warning_once(self, warn_msg):
""" Print same warning msg only once
Parameters
----------
warn_msg: str
Warning message
"""
if warn_msg in self.warn_msg:
# Skip printing warning
return
self.warn_msg.add(warn_msg)
logging.warning(warn_msg)

# pylint: disable=invalid-name
def forward(self, g, inputs):
Expand Down Expand Up @@ -196,9 +211,10 @@ def _apply(ntype, h):
for k, _ in inputs.items():
if g.number_of_dst_nodes(k) > 0:
if k not in hs:
logging.warning("Warning. Graph convolution returned empty " + \
f"dictionary for nodes in type: {str(k)}. Please check your data" + \
f" for no in-degree nodes in type: {str(k)}.")
warn_msg = "Warning. Graph convolution returned empty " \
f"dictionary for nodes in type: {str(k)}. Please check your data" \
f" for no in-degree nodes in type: {str(k)}."
self.warning_once(warn_msg)
hs[k] = th.zeros((g.number_of_dst_nodes(k),
self.out_feat),
device=inputs[k].device)
Expand Down
22 changes: 19 additions & 3 deletions python/graphstorm/model/rgcn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,21 @@ def __init__(self,
num_ffn_layers_in_gnn, ffn_activation, dropout)

self.dropout = nn.Dropout(dropout)
self.warn_msg = set()

def warning_once(self, warn_msg):
""" Print same warning msg only once
Parameters
----------
warn_msg: str
Warning message
"""
if warn_msg in self.warn_msg:
# Skip printing warning
return
self.warn_msg.add(warn_msg)
logging.warning(warn_msg)

# pylint: disable=invalid-name
def forward(self, g, inputs):
Expand Down Expand Up @@ -245,9 +260,10 @@ def _apply(ntype, h):
for k, _ in inputs.items():
if g.number_of_dst_nodes(k) > 0:
if k not in hs:
logging.warning("Warning. Graph convolution returned empty " + \
f"dictionary for nodes in type: {str(k)}. Please check your data" + \
f" for no in-degree nodes in type: {str(k)}.")
warn_msg = "Warning. Graph convolution returned empty " \
f"dictionary for nodes in type: {str(k)}. Please check your data" \
f" for no in-degree nodes in type: {str(k)}."
self.warning_once(warn_msg)
hs[k] = th.zeros((g.number_of_dst_nodes(k),
self.out_feat),
device=inputs[k].device)
Expand Down

0 comments on commit 9d6610b

Please sign in to comment.