Commit

[Misc] Black auto fix. (dmlc#4641)
* [Misc] Black auto fix.

* sort

Co-authored-by: Steve <[email protected]>
frozenbugs and Steve authored Sep 26, 2022
1 parent 08c50eb commit a9f2acf
Showing 47 changed files with 3,910 additions and 2,101 deletions.
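
This commit is a mechanical reformatting pass. As a rough sketch only (not necessarily how the maintainers ran it, and with configuration values that are assumptions rather than the repository's own settings), an equivalent pass can be reproduced with the Python APIs of isort and Black:

# Sketch: re-creating an import sort + Black auto-fix on one of the changed files.
# Line length, isort profile, etc. are left at defaults here; the repo's real
# settings may differ, and the CLI tools may have been used instead.
import black
import isort

path = "benchmarks/benchmarks/model_speed/bench_sage.py"
with open(path) as f:
    src = f.read()

src = isort.code(src)                               # sort and group imports
src = black.format_str(src, mode=black.FileMode())  # apply Black formatting

with open(path, "w") as f:
    f.write(src)
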
benchmarks/benchmarks/model_speed/bench_sage.py: 50 changes (28 additions, 22 deletions)
@@ -1,21 +1,26 @@
 import time
-import dgl
-from dgl.nn.pytorch import SAGEConv
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
+import dgl
+from dgl.nn.pytorch import SAGEConv
+
 from .. import utils


 class GraphSAGE(nn.Module):
-    def __init__(self,
-                 in_feats,
-                 n_hidden,
-                 n_classes,
-                 n_layers,
-                 activation,
-                 dropout,
-                 aggregator_type):
+    def __init__(
+        self,
+        in_feats,
+        n_hidden,
+        n_classes,
+        n_layers,
+        activation,
+        dropout,
+        aggregator_type,
+    ):
         super(GraphSAGE, self).__init__()
         self.layers = nn.ModuleList()
         self.dropout = nn.Dropout(dropout)
@@ -27,7 +32,9 @@ def __init__(self,
         for i in range(n_layers - 1):
             self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
         # output layer
-        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation None
+        self.layers.append(
+            SAGEConv(n_hidden, n_classes, aggregator_type)
+        )  # activation None

     def forward(self, graph, inputs):
         h = self.dropout(inputs)
@@ -38,20 +45,21 @@ def forward(self, graph, inputs):
                 h = self.dropout(h)
         return h

-@utils.benchmark('time')
-@utils.parametrize('data', ['cora', 'pubmed'])
+
+@utils.benchmark("time")
+@utils.parametrize("data", ["cora", "pubmed"])
 def track_time(data):
     data = utils.process_data(data)
     device = utils.get_bench_device()
     num_epochs = 200

     g = data[0].to(device)

-    features = g.ndata['feat']
-    labels = g.ndata['label']
-    train_mask = g.ndata['train_mask']
-    val_mask = g.ndata['val_mask']
-    test_mask = g.ndata['test_mask']
+    features = g.ndata["feat"]
+    labels = g.ndata["label"]
+    train_mask = g.ndata["train_mask"]
+    val_mask = g.ndata["val_mask"]
+    test_mask = g.ndata["test_mask"]

     in_feats = features.shape[1]
     n_classes = data.num_classes
@@ -60,16 +68,14 @@ def track_time(data):
     g = dgl.add_self_loop(g)

     # create model
-    model = GraphSAGE(in_feats, 16, n_classes, 1, F.relu, 0.5, 'gcn')
+    model = GraphSAGE(in_feats, 16, n_classes, 1, F.relu, 0.5, "gcn")
     loss_fcn = torch.nn.CrossEntropyLoss()

     model = model.to(device)
     model.train()

     # optimizer
-    optimizer = torch.optim.Adam(model.parameters(),
-                                 lr=1e-2,
-                                 weight_decay=5e-4)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

     # dry run
     for i in range(10):
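
The truncated diff above covers only the model definition and setup portion of bench_sage.py. To make the reformatted code concrete, here is a minimal, hypothetical usage sketch of the GraphSAGE class as defined in this file; the toy graph, feature width, and class count are invented for illustration and are not part of the benchmark:

# Hypothetical usage of the GraphSAGE class from bench_sage.py.
# Assumes that class (as shown in the diff above) is already defined in scope.
import dgl
import torch
import torch.nn.functional as F

g = dgl.rand_graph(100, 500)   # toy graph: 100 nodes, 500 random edges (illustrative)
g = dgl.add_self_loop(g)       # the benchmark adds self-loops before using the "gcn" aggregator
feats = torch.randn(100, 32)   # 32-dimensional node features (illustrative)

# Same constructor shape as in track_time(): 16 hidden units, 1 layer,
# ReLU, dropout 0.5, "gcn" aggregator; the 7 output classes are made up here.
model = GraphSAGE(32, 16, 7, 1, F.relu, 0.5, "gcn")
model.eval()
with torch.no_grad():
    logits = model(g, feats)   # tensor of shape (100, 7)
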
