SparseGCN performance issue #8

Open

nissy-dev opened this issue Jun 29, 2020 · 0 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

nissy-dev commented Jun 29, 2020

SparseGCN has a serious performance issue.
The training time per epoch of the Tox21 example is almost 30 times that of PadGCN.

Results on my local PC (CPU)

SparseGCN log

$ python gcn_sparse_pattern_example.py

Iter 0/50 (241.6028 s) valid loss: 0.1522             valid roc_auc score: 0.6647
Iter 1/50 (208.9176 s) valid loss: 0.1649             valid roc_auc score: 0.6955
Iter 2/50 (225.3157 s) valid loss: 0.1516             valid roc_auc score: 0.7013

PadGCN log

$ python gcn_pad_pattern_example.py

Iter 0/50 (18.0109 s) valid loss: 0.1648             valid roc_auc score: 0.5425
Iter 1/50 (7.1491 s) valid loss: 0.1645             valid roc_auc score: 0.5315
Iter 2/50 (6.5680 s) valid loss: 0.1512             valid roc_auc score: 0.5597

This performance issue is related to jax-ml/jax#2242.
SparseGCN uses jax.ops.index_add, and calling jax.ops.index_add inside a large Python "for" loop causes a serious slowdown.
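
For illustration only, here is a hypothetical sketch (not the actual SparseGCN code, using the jax.ops API this project relies on) of the kind of pattern that becomes slow: each index_add call issued from a Python loop is dispatched as a separate operation, whereas a single vectorized index_add over all edges is far cheaper.

import jax.numpy as jnp
from jax import ops

# Slow: one index_add per edge, dispatched one by one from Python.
def aggregate_slow(node_feat, edge_idx):
    out = jnp.zeros_like(node_feat)
    for k in range(edge_idx.shape[1]):
        src, dst = edge_idx[0, k], edge_idx[1, k]
        out = ops.index_add(out, ops.index[dst], node_feat[src])
    return out

# Faster: a single vectorized index_add over all edges at once
# (repeated destination indices are accumulated).
def aggregate_fast(node_feat, edge_idx):
    out = jnp.zeros_like(node_feat)
    return ops.index_add(out, ops.index[edge_idx[1]], node_feat[edge_idx[0]])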

According to the comments on that issue, I have to rewrite the training loop using lax.scan or lax.fori_loop in order to resolve this. If the training loop is rewritten with lax.scan or lax.fori_loop, I think it will improve the performance of not only SparseGCN but also PadGCN, so resolving this issue is really important.

However, lax.scan and lax.fori_loop enforce a functional programming style and are difficult to work with, so rewriting the training loop is hard and I'm struggling with it. Below I explain what is blocking my work.

1. lax.scan and lax.fori_loop don't allow side effects

DeepChem's DiskDataset provides the iterbatches method. We can use it to write the training loop like below.

for epoch in range(num_epochs):
    for batch in train_dataset.iterbatches(batch_size=batch_size):
        params, predict = forward(batch, params)

But lax.scan and lax.fori_loop don't allow side effects (such as iterators/generators), so I tried to implement it like below, but it didn't work. I opened an issue about this topic; please see jax-ml/jax#3567. (A possible workaround is sketched after this snippet.)

train_iterator = train_dataset.iterbatches(batch_size=batch_size)

def run_epoch(init_params):
    def body_fun(idx, params):
        # this iterator doesn't work: the batch value is the same on every iteration
        batch = next(train_iterator)
        params, predict = forward(batch, params)
        return params

    return lax.fori_loop(0, train_num_batches, body_fun, init_params)

for epoch in range(num_epochs):
    params = run_epoch(params)
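
One workaround I'm considering is sketched below. It assumes all batches for the epoch are materialized in memory beforehand, padded to the same shape, and stacked into arrays (batch_feats, batch_labels are hypothetical names), so lax.scan can index them and carry the params without touching a Python iterator; `forward` is the same hypothetical function as above.

from jax import lax

# batch_feats: (num_batches, batch_size, ...), batch_labels: (num_batches, batch_size, ...)
# Both are built once per epoch, outside the compiled loop.
def run_epoch(params, batch_feats, batch_labels):
    def body_fun(params, batch):
        params, _predict = forward(batch, params)
        return params, None          # carry the params, emit no per-step output

    params, _ = lax.scan(body_fun, params, (batch_feats, batch_labels))
    return params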

2. Values in the body_fun of lax.scan or lax.fori_loop cannot change shape

All values in lax.scan or lax.fori_loop, such as the return value and the arguments, must keep a fixed shape (see the documentation). This is a hard limitation of lax.scan and lax.fori_loop. (To be honest, there are also some additional limitations, like jax-ml/jax#2962.)

One of the pain points is that accumulation operations (like appending a value to a list on every loop iteration) are difficult to handle. Here is an example:

# NG: the list grows every iteration, so its shape changes
val = []
for i in range(10):
    val.append(i)

# OK: the array is pre-allocated, so the shape stays fixed
# (in JAX, the in-place update would be jax.ops.index_update(val, i, i))
val = np.zeros(10)
for i in range(10):
    val[i] = i

This may become a problem as the number of metrics we want to collect grows; sometimes a creative implementation is needed (see jax-ml/jax#1708).
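
As one idea for the accumulation case (a sketch only), lax.scan can collect per-step values itself by stacking the second return value of body_fun, which keeps every shape fixed:

import jax.numpy as jnp
from jax import lax

# Sketch: collect a per-step metric without appending to a Python list.
def collect_metrics(n=10):
    def body_fun(carry, i):
        metric = i * i               # hypothetical per-step metric
        return carry, metric         # the second value is stacked across all steps
    _, metrics = lax.scan(body_fun, None, jnp.arange(n))
    return metrics                   # shape (n,), known ahead of time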

Another pain point is that the sparse pattern mini-batch is incompatible with this limitation.
In sparse pattern modeling, the mini-batch data changes shape from batch to batch, as below.
This is an example from PyTorch Geometric (x is the node feature matrix).

>>> from torch_geometric.datasets import TUDataset
>>> from torch_geometric.data import DataLoader
>>> dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
>>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
>>> for batch in loader:
...         batch
...
Batch(batch=[1137], edge_index=[2, 4368], x=[1137, 21], y=[32])
Batch(batch=[1144], edge_index=[2, 4408], x=[1144, 21], y=[32])
Batch(batch=[1191], edge_index=[2, 4444], x=[1191, 21], y=[32])
Batch(batch=[1087], edge_index=[2, 4288], x=[1087, 21], y=[32])
Batch(batch=[644], edge_index=[2, 2540], x=[644, 21], y=[24])

Sparse pattern modeling constructs one big graph per batch by unifying all the graphs in that batch, so each mini-batch has a different shape. (A minimal sketch of this unification follows the example below.)

# batch_size = 3

graph 1 : node_feat (5, 100)  edge_idx (2, 4)
graph 2 : node_feat (9, 100)  edge_idx (2, 6)
graph 3 : node_feat (5, 100)  edge_idx (2, 7)

-> mini-batch graph : node_feat (19, 100)  edge_idx (2, 17)
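
For concreteness, here is a minimal sketch of that unification (a hypothetical helper in plain NumPy): node features are concatenated, and each graph's edge indices are shifted by the number of nodes merged before it.

import numpy as np

# Sketch: unify a list of (node_feat, edge_idx) graphs into one big graph.
# node_feat has shape (num_nodes, feat_dim); edge_idx has shape (2, num_edges)
# with indices local to each graph.
def merge_graphs(graphs):
    node_feats, edge_idxs, offset = [], [], 0
    for node_feat, edge_idx in graphs:
        node_feats.append(node_feat)
        edge_idxs.append(edge_idx + offset)   # shift by the nodes already merged
        offset += node_feat.shape[0]
    return np.concatenate(node_feats, axis=0), np.concatenate(edge_idxs, axis=1)

# graph 1: (5, 100) / (2, 4), graph 2: (9, 100) / (2, 6), graph 3: (5, 100) / (2, 7)
# -> merged node_feat (19, 100), merged edge_idx (2, 17)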

This is a serious problem for implementing the sparse pattern model, and I'm now thinking about how to resolve this shape issue. One solution is to pad the mini-batch graph data like below.

mini-batch graph : node_feat (19, 100) edge_idx (2, 17) -> node_feat (19, 100) edge_idx (2, 17)
mini-batch graph : node_feat (15, 100) edge_idx (2, 13) -> node_feat (19, 100) edge_idx (2, 17)
mini-batch graph : node_feat (12, 100) edge_idx (2, 11) -> node_feat (19, 100) edge_idx (2, 17)

However, we have to be careful about the padding values, because they could affect the node aggregation of the sparse pattern model.
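
One possible padding scheme (a sketch only, not an implemented solution) is to reserve an extra all-zero dummy node and point every padded edge at it, so that padded edges contribute nothing to a sum-style aggregation; a mean or normalized aggregation would still need an explicit mask.

import numpy as np

# Sketch: pad a merged graph to fixed sizes (max_nodes + 1, feat_dim) and (2, max_edges).
# Index max_nodes is a reserved all-zero dummy node; every padded edge points at it,
# so padded edges add nothing to a sum-style aggregation.
def pad_graph(node_feat, edge_idx, max_nodes, max_edges):
    feat_dim = node_feat.shape[1]
    padded_feat = np.zeros((max_nodes + 1, feat_dim), dtype=node_feat.dtype)
    padded_feat[: node_feat.shape[0]] = node_feat

    padded_edges = np.full((2, max_edges), max_nodes, dtype=edge_idx.dtype)
    padded_edges[:, : edge_idx.shape[1]] = edge_idx
    return padded_feat, padded_edges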

3. It is difficult to debug the body_fun in lax.scan or lax.fori_loop

It is difficult to debug body_fun, for example by adding a print call, inside lax.scan or lax.fori_loop. This point is also discussed in jax-ml/jax#999, but that issue is still open... (A possible workaround is sketched after the snippet below.)

train_iterator = train_dataset.iterbatches(batch_size=batch_size)

def run_epoch(init_params):
    def body_fun(idx, params):
        # this iterator doesn't work: the batch value is the same on every iteration
        batch = next(train_iterator)
        params, predict = forward(batch, params)
        # nothing useful is printed: print only sees an abstract tracer while body_fun is traced
        print(predict)
        return params

    return lax.fori_loop(0, train_num_batches, body_fun, init_params)

for epoch in range(num_epochs):
    params = run_epoch(params)
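
The only workaround I can think of for now (a sketch) is to return the values I want to inspect as scan outputs and print them after the loop has finished, outside the traced code:

import jax.numpy as jnp
from jax import lax

# Sketch: expose intermediate values as scan outputs instead of printing
# inside body_fun, then print the concrete stacked result afterwards.
def run_debug(init, xs):
    def body_fun(carry, x):
        carry = carry + x
        return carry, carry          # also emit the running value each step
    final, trace = lax.scan(body_fun, init, xs)
    print(trace)                     # concrete values, printed outside the loop
    return final

run_debug(jnp.float32(0.0), jnp.arange(5.0))
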
nissy-dev added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Jul 1, 2020