onnx problem with edge_index is [2, 0] #9808

Open
LQY404 opened this issue Nov 27, 2024 · 0 comments
LQY404 commented Nov 27, 2024

🐛 Describe the bug

My model:

import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class MultiHeadAttentionPool(torch.nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadAttentionPool, self).__init__()
        self.num_heads = num_heads
        self.att_mlp = torch.nn.Linear(input_dim, num_heads)

    def forward(self, x, batch):
        alpha = self.att_mlp(x)
        alpha = F.leaky_relu(alpha)
        alpha = F.softmax(alpha, dim=0)

        out = 0
        for head in range(self.num_heads):
            out += alpha[:, head].unsqueeze(-1) * x

        # Sum-pool the weighted node features into one row per patch (per unique batch id).
        res = torch.zeros((torch.unique(batch).shape[0], out.shape[-1]), dtype=out.dtype).to(out.device)
        return torch.scatter_add(res, 0, batch.unsqueeze(-1).expand(-1, out.shape[-1]), out)
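For context, the scatter_add call above is a per-patch sum pool: each node's weighted feature row is added into the output row selected by its batch id. A minimal standalone check of that semantics (toy shapes chosen here for illustration, not from the original report):

import torch

x = torch.randn(5, 4)                     # 5 nodes, 4 features
batch = torch.tensor([0, 0, 1, 1, 1])     # two patches
res = torch.zeros(2, 4)
pooled = torch.scatter_add(res, 0, batch.unsqueeze(-1).expand(-1, 4), x)

# matches manual per-patch sums
assert torch.allclose(pooled[0], x[:2].sum(0))
assert torch.allclose(pooled[1], x[2:].sum(0))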

        

class GraphTransformerClassifierAttentionPool(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes, heads=4, device='cpu', use_patch_loss=True):
        super(GraphTransformerClassifierAttentionPool, self).__init__()
        self.device = device

        self.conv1 = TransformerConv(input_dim, hidden_dim, heads=heads, concat=False, dropout=0.1)
        self.conv2 = TransformerConv(hidden_dim, hidden_dim, heads=heads, concat=False, dropout=0.1)

        self.att_pool = MultiHeadAttentionPool(hidden_dim, 8)
        

    def forward(self, x, edge_index, batch):
        x = x.to(self.device)
        batch = batch.to(dtype=torch.int64).to(self.device)
        edge_index = edge_index.to(self.device)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        patch_emb = self.att_pool(x, batch)
        return patch_emb
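For reference, a minimal plain-PyTorch sanity check of this model (a sketch; the tiny shapes and random edges are assumptions for illustration, not from the original report):

import torch

model = GraphTransformerClassifierAttentionPool(256, 256, 20, heads=8, device='cpu')
model.eval()
x = torch.randn(8, 256)                    # 8 nodes, 256-dim features
edge_index = torch.randint(0, 8, (2, 20))  # small random graph
batch = torch.zeros(8, dtype=torch.long)   # all nodes in one patch
with torch.no_grad():
    patch_emb = model(x, edge_index, batch)
print(patch_emb.shape)                     # torch.Size([1, 256])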

ONNX wrapper:

import torch.nn as nn

class ONNXWrapper_inner(nn.Module):
    def __init__(self, ori_model):
        super().__init__()
        self.model = ori_model
        self.model.eval()

    def forward(self, x, edge_index):
        # x: [M, 256]
        # edge_index: [2, k]
        # The model as shown returns a single tensor, so no tuple unpacking is needed.
        patch_feature = self.model(x, edge_index, torch.zeros(x.shape[0]))
        return patch_feature

ONNX export code:

import random
import torch
import torch.nn.functional as F
from torch_geometric.nn import knn_graph

model_inner = ONNXWrapper_inner(GraphTransformerClassifierAttentionPool(256, 256, 20, heads=8, device='cuda'))
dummy_x = torch.randn((random.randint(1, 100), 256))
dummy_edge_index = knn_graph(F.normalize(dummy_x, p=2, dim=-1), k=10, loop=True).to(torch.int64).to('cuda')
torch.onnx.export(
    model_inner,
    (dummy_x, dummy_edge_index),
    'model_inner.onnx',
    export_params=True,
    opset_version=16,
    do_constant_folding=True,
    input_names=['x', 'edge_index'],
    output_names=['patch_feature'],
    dynamic_axes={
        'x': {0: 'cell_num', 1: 'cell_dia'},
        'edge_index': {1: 'edge_num'},
        'patch_feature': {0: 'patch_num', 1: 'dia'},
    },
    verbose=True,
)
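As a sanity check before running inference, the exported file can be validated with the standard onnx checker (this snippet is an addition for illustration, not part of the original report):

import onnx

onnx_model = onnx.load('model_inner.onnx')
onnx.checker.check_model(onnx_model)   # raises if the exported graph is malformed
print(onnx.helper.printable_graph(onnx_model.graph)[:500])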

ONNX Runtime test code:

import onnxruntime

session1 = onnxruntime.InferenceSession('./tmp/model_inner.onnx')
input_name1 = session1.get_inputs()[0].name
input_name2 = session1.get_inputs()[1].name
print(input_name1, input_name2)
output_name = session1.get_outputs()[0].name
print(output_name)

patch_list = []  # collects the raw inputs for later inspection
for _ in range(10):
    x = torch.randn((random.randint(1, 2), 256))
    patch_list.append(x)
    x = F.normalize(x, p=2, dim=-1)
    edge_index = knn_graph(x, k=10, loop=False)
    print(x.shape, edge_index.shape)

    outputs = session1.run([e.name for e in session1.get_outputs()],
                           {input_name1: x.numpy(), input_name2: edge_index.numpy()})

In the 10-iteration runtime test loop, the ONNX model runs without problems whenever the randomly generated x has shape[0] greater than 1.
But when shape[0] of x is 1 (in which case edge_index has shape [2, 0], since knn_graph with loop=False finds no neighbors for a single node), the following ONNXRuntimeError occurs:

2024-11-27 10:59:23.162258055 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Expand node. Name:'/model/conv1/Expand' Status Message: invalid expand shape
return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Expand node. Name:'/model/conv1/Expand' Status Message: invalid expand shape
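The Expand failure is consistent with edge_index being empty: the exported graph cannot broadcast to a zero-sized shape. One possible workaround (a sketch, not a confirmed fix) is to guarantee at least self-loop edges before calling the session, either via loop=True in knn_graph or with torch_geometric's add_self_loops helper:

from torch_geometric.utils import add_self_loops

edge_index = knn_graph(x, k=10, loop=False)
if edge_index.shape[1] == 0:
    # empty graph: fall back to self-loops so edge_index becomes [2, num_nodes]
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])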

Versions

The relevant package versions in my environment:

torch                         2.4.1
torch_cluster                 1.6.3
torch-geometric               2.6.1
torch_scatter                 2.1.2
torch_sparse                  0.6.18
onnx                          1.17.0
numpy                         1.26.4
@LQY404 LQY404 added the bug label Nov 27, 2024