diff --git a/python/graphstorm/wholegraph/utils.py b/python/graphstorm/wholegraph/utils.py index b67d4f8261..0d1d030998 100644 --- a/python/graphstorm/wholegraph/utils.py +++ b/python/graphstorm/wholegraph/utils.py @@ -22,7 +22,7 @@ def is_wholegraph_embedding(data): is required to use wholegraph framework. """ try: - import pylibwholegraph + import pylibwholegraph.torch return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding) except ImportError: return False @@ -33,7 +33,7 @@ def is_wholegraph_embedding_module(data): is required to use wholegraph framework. """ try: - import pylibwholegraph + import pylibwholegraph.torch return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbeddingModule) except: # pylint: disable=bare-except return False @@ -44,7 +44,7 @@ def is_wholegraph_optimizer(data): is required to use wholegraph framework. """ try: - import pylibwholegraph + import pylibwholegraph.torch return isinstance(data, pylibwholegraph.torch.WholeMemoryOptimizer) except: # pylint: disable=bare-except return False