From 9d6610b633c0ad9ed5cab96c01a2dba179b6fc42 Mon Sep 17 00:00:00 2001 From: "xiang song(charlie.song)" Date: Thu, 20 Jun 2024 11:44:12 -0700 Subject: [PATCH] Reduce the number of warning prints of heterogeneous GNN encoders when some nodes do not have in edges (#888) *Issue #, if available:* *Description of changes:* To avoid excessive warning logs from RGCN, RGAT, and HGT encoders, this PR ensures that identical warnings (warnings that share the same message) are printed only once. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Xiang Song --- python/graphstorm/model/hgt_encoder.py | 22 +++++++++++++++++++--- python/graphstorm/model/rgat_encoder.py | 22 +++++++++++++++++++--- python/graphstorm/model/rgcn_encoder.py | 22 +++++++++++++++++++--- 3 files changed, 57 insertions(+), 9 deletions(-) diff --git a/python/graphstorm/model/hgt_encoder.py b/python/graphstorm/model/hgt_encoder.py index b544f51561..5ae923f4d4 100644 --- a/python/graphstorm/model/hgt_encoder.py +++ b/python/graphstorm/model/hgt_encoder.py @@ -183,6 +183,21 @@ def __init__(self, # Dropout self.drop = nn.Dropout(dropout) + self.warn_msg = set() + + def warning_once(self, warn_msg): + """ Print same warning msg only once + + Parameters + ---------- + warn_msg: str + Warning message + """ + if warn_msg in self.warn_msg: + # Skip printing warning + return + self.warn_msg.add(warn_msg) + logging.warning(warn_msg) def forward(self, g, h): """Forward computation @@ -255,9 +270,10 @@ def forward(self, g, h): else: trans_out = trans_out * alpha + self.a_linears[k](h[k]) * (1-alpha) else: # Nodes not really in destination side. - logging.warning("Warning. Graph convolution returned empty " + \ - f"dictionary for nodes in type: {str(k)}. Please check your data" + \ - f" for no in-degree nodes in type: {str(k)}.") + warn_msg = "Warning. Graph convolution returned empty " \ + f"dictionary for nodes in type: {str(k)}. Please check your data" \ + f" for no in-degree nodes in type: {str(k)}." + self.warning_once(warn_msg) # So add psudo self-loop for the destination nodes with its own feature. dst_h = self.a_linears[k](h[k][:g.num_dst_nodes(k)]) trans_out = self.drop(dst_h) diff --git a/python/graphstorm/model/rgat_encoder.py b/python/graphstorm/model/rgat_encoder.py index bc05ad1c14..1b53357773 100644 --- a/python/graphstorm/model/rgat_encoder.py +++ b/python/graphstorm/model/rgat_encoder.py @@ -150,6 +150,21 @@ def __init__(self, # dropout self.dropout = nn.Dropout(dropout) + self.warn_msg = set() + + def warning_once(self, warn_msg): + """ Print same warning msg only once + + Parameters + ---------- + warn_msg: str + Warning message + """ + if warn_msg in self.warn_msg: + # Skip printing warning + return + self.warn_msg.add(warn_msg) + logging.warning(warn_msg) # pylint: disable=invalid-name def forward(self, g, inputs): @@ -196,9 +211,10 @@ def _apply(ntype, h): for k, _ in inputs.items(): if g.number_of_dst_nodes(k) > 0: if k not in hs: - logging.warning("Warning. Graph convolution returned empty " + \ - f"dictionary for nodes in type: {str(k)}. Please check your data" + \ - f" for no in-degree nodes in type: {str(k)}.") + warn_msg = "Warning. Graph convolution returned empty " \ + f"dictionary for nodes in type: {str(k)}. Please check your data" \ + f" for no in-degree nodes in type: {str(k)}." + self.warning_once(warn_msg) hs[k] = th.zeros((g.number_of_dst_nodes(k), self.out_feat), device=inputs[k].device) diff --git a/python/graphstorm/model/rgcn_encoder.py b/python/graphstorm/model/rgcn_encoder.py index 040658670b..ca43bee0fd 100644 --- a/python/graphstorm/model/rgcn_encoder.py +++ b/python/graphstorm/model/rgcn_encoder.py @@ -179,6 +179,21 @@ def __init__(self, num_ffn_layers_in_gnn, ffn_activation, dropout) self.dropout = nn.Dropout(dropout) + self.warn_msg = set() + + def warning_once(self, warn_msg): + """ Print same warning msg only once + + Parameters + ---------- + warn_msg: str + Warning message + """ + if warn_msg in self.warn_msg: + # Skip printing warning + return + self.warn_msg.add(warn_msg) + logging.warning(warn_msg) # pylint: disable=invalid-name def forward(self, g, inputs): @@ -245,9 +260,10 @@ def _apply(ntype, h): for k, _ in inputs.items(): if g.number_of_dst_nodes(k) > 0: if k not in hs: - logging.warning("Warning. Graph convolution returned empty " + \ - f"dictionary for nodes in type: {str(k)}. Please check your data" + \ - f" for no in-degree nodes in type: {str(k)}.") + warn_msg = "Warning. Graph convolution returned empty " \ + f"dictionary for nodes in type: {str(k)}. Please check your data" \ + f" for no in-degree nodes in type: {str(k)}." + self.warning_once(warn_msg) hs[k] = th.zeros((g.number_of_dst_nodes(k), self.out_feat), device=inputs[k].device)