Skip to content

Commit

Permalink
take back DGL example changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 7, 2024
1 parent 9e445c6 commit d04dde1
Showing 1 changed file with 5 additions and 19 deletions.
24 changes: 5 additions & 19 deletions examples/graphbolt/rgcn/hetero_rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroEmbedding
from evaluator import IGB_Evaluator
from ogb.lsc import MAG240MEvaluator
from ogb.nodeproppred import Evaluator
from tqdm import tqdm
Expand Down Expand Up @@ -142,10 +141,6 @@ def create_dataloader(
if name == "ogb-lsc-mag240m":
node_feature_keys["author"] = ["feat"]
node_feature_keys["institution"] = ["feat"]
if "igb-het" in name:
node_feature_keys["author"] = ["feat"]
node_feature_keys["institute"] = ["feat"]
node_feature_keys["fos"] = ["feat"]
datapipe = datapipe.fetch_feature(features, node_feature_keys)

# Create a DataLoader from the datapipe.
Expand All @@ -163,7 +158,7 @@ def extract_embed(node_embed, input_nodes):

def extract_node_features(name, block, data, node_embed, device):
"""Extract the node features from embedding layer or raw features."""
if name == "ogbn-mag" or "igb-het" in name:
if name == "ogbn-mag":
input_nodes = {
k: v.to(device) for k, v in block.srcdata[dgl.NID].items()
}
Expand Down Expand Up @@ -429,9 +424,7 @@ def evaluate(
model.eval()
category = "paper"
# An evaluator for the dataset.
if "igb-het" in name:
evaluator = IGB_Evaluator(name=name, num_tasks=1, eval_metric="acc")
elif name == "ogbn-mag":
if name == "ogbn-mag":
evaluator = Evaluator(name=name)
else:
evaluator = MAG240MEvaluator()
Expand Down Expand Up @@ -595,7 +588,7 @@ def main(args):
# `institution` are generated in advance and stored in the feature store.
# For `ogbn-mag`, we generate the features on the fly.
embed_layer = None
if args.dataset == "ogbn-mag" or "igb-het" in args.dataset:
if args.dataset == "ogbn-mag":
# Create the embedding layer and move it to the appropriate device.
embed_layer = rel_graph_embed(g, feat_size).to(device)
print(
Expand Down Expand Up @@ -670,15 +663,8 @@ def main(args):
"--dataset",
type=str,
default="ogbn-mag",
choices=[
"ogbn-mag",
"ogb-lsc-mag240m",
"igb-het-tiny",
"igb-het-small",
"igb-het-medium",
],
help="Dataset name. Possible values: ogbn-mag, ogb-lsc-mag240m, "
" igb-het-[tiny|small|medium].",
choices=["ogbn-mag", "ogb-lsc-mag240m"],
help="Dataset name. Possible values: ogbn-mag, ogb-lsc-mag240m",
)
parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--num_workers", type=int, default=0)
Expand Down

0 comments on commit d04dde1

Please sign in to comment.