
Add ONNX parsing for SkipSimplifiedLayerNormalization #3140

Merged
merged 18 commits into develop from skip-simplified-layernorm on Aug 23, 2024

Conversation

turneram
Contributor

No description provided.

@turneram turneram requested a review from causten as a code owner May 31, 2024 17:49

codecov bot commented May 31, 2024

Codecov Report

Attention: Patch coverage is 93.93939% with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.25%. Comparing base (03c43e5) to head (a3dae61).
Report is 154 commits behind head on develop.

Files with missing lines | Patch % | Lines
...onnx/parse_skip_simplified_layer_normalization.cpp | 93.93% | 2 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3140   +/-   ##
========================================
  Coverage    92.25%   92.25%           
========================================
  Files          500      501    +1     
  Lines        20054    20087   +33     
========================================
+ Hits         18500    18531   +31     
- Misses        1554     1556    +2     


@turneram turneram linked an issue Jun 5, 2024 that may be closed by this pull request
@turneram turneram requested review from umangyadav and CharlieL7 June 5, 2024 17:18
Collaborator

@CharlieL7 CharlieL7 left a comment


Looks fine; I have the same questions as for #3129. Answering one should answer the other.

@causten causten requested a review from pfultz2 June 11, 2024 21:35
// bias (optional) : T
// 1D bias tensor with shape (hidden_size) - not used by ORT

if(args.size() != 3)
Collaborator

Is there not a 4th input to handle, then? Based on the doc, it specifies 3-4 inputs.

Contributor Author

Onnxruntime didn't implement theirs to use the bias input, possibly because the only model that currently uses this op doesn't call it with a bias. I could add it if we want to be in line with the spec, but since I was using theirs as a baseline for correctness, I wrote ours the same way.

Collaborator

Yeah, I get it. Nothing wrong with what you did. I think it's better to cover our bases on this, though, as we've seen fun outcomes with some of these optional inputs.

My only real concern is that Microsoft wrote:
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.SkipSimplifiedLayerNormalization
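
For reference, a minimal sketch of what covering the optional 4th input could look like in the parser. This is illustrative only (not the code merged in this PR) and assumes the onnx parser's usual info.add_common_op helper and the input order input, skip, gamma, bias:

if(args.size() < 3 or args.size() > 4)
    MIGRAPHX_THROW("PARSE_SKIPSIMPLIFIEDLAYERNORMALIZATION: 3 or 4 inputs required");

// x = input + skip, then optionally + bias; add_common_op broadcasts the 1D bias
auto x = info.add_common_op("add", args[0], args[1]);
if(args.size() == 4)
    x = info.add_common_op("add", x, args[3]);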

Collaborator

@TedThemistokleous TedThemistokleous left a comment

Add the optional input parameter, attributes, and outputs. The extra outputs can likely be composed from existing ops we already have, such as ReduceSum for the mean and bias-sum outputs (see the sketch after this list):

  • bias input
  • mean (output)
  • inv_std_var (output)
  • input_skip_bias_sum (output)
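
A rough sketch of how those extra outputs might be composed from existing MIGraphX ops; reduce_mean and rsqrt are real MIGraphX operators, but the wiring below is a hypothetical illustration, not the merged implementation:

// x = input + skip (+ bias); x itself is the input_skip_bias_sum output
auto x_sq    = info.add_common_op("mul", x, x);
auto mean_sq = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {-1}}}), x_sq);
auto eps_lit = info.add_literal(
    migraphx::literal{migraphx::shape{x->get_shape().type()}, {epsilon}});
auto inv_std = info.add_instruction(migraphx::make_op("rsqrt"),
                                    info.add_common_op("add", mean_sq, eps_lit));
// output = x * inv_std * gamma; the mean output could similarly be a reduce_mean over x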

@TedThemistokleous TedThemistokleous added the Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase label Jul 10, 2024
Comment on lines +33 to +48
std::vector<half> x{half{0.8},
half{-0.5},
half{0.0},
half{1.0},
half{0.5},
half{0.2},
half{0.3},
half{-0.6},
half{10.0},
half{-1.0},
half{0.0},
half{1.0},
half{1.2},
half{3.2},
half{-4.1},
half{5.3}};
Member

You shouldn't need the casts.

Contributor Author

Removing the casts causes compilation errors. It looks like the way I have it is what we do in other tests using the half type.

Member

It does look like we use it like this, but as an alternative you could try passing a float vector to migraphx::argument with a shape that has half_type.
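
For instance, a hedged sketch of that alternative. The reviewer suggested migraphx::argument; migraphx::literal has a similar converting constructor and is what parse tests typically build. The shape dims here are illustrative:

migraphx::shape s{migraphx::shape::half_type, {2, 2, 4}};
std::vector<float> data{0.8, -0.5, 0.0, 1.0, 0.5, 0.2, 0.3, -0.6,
                        10.0, -1.0, 0.0, 1.0, 1.2, 3.2, -4.1, 5.3};
// the literal converts each float to half when filling a half_type shape
auto lit = migraphx::literal{s, data};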

@CharlieL7 CharlieL7 self-requested a review July 23, 2024 13:53

if(x_rank < 2 or x_rank > 3 or x_rank != skip_rank or gamma_rank != 1)
{
MIGRAPHX_THROW("PARSE_SKIPSIMPLIFIEDLAYERNORMALIZATION: invalid input shape");
Collaborator

Add some test cases for the invalid inputs that trigger this throw and the ones above.
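
A minimal sketch of such a test, following the EXPECT(test::throws(...)) pattern used in the onnx parse tests; the .onnx file name here is hypothetical:

TEST_CASE(skip_simplified_layer_normalization_invalid_input_test)
{
    // an x_rank outside [2, 3], or mismatched skip/gamma ranks, should throw
    EXPECT(test::throws([&] {
        migraphx::parse_onnx("skip_simplified_layer_normalization_invalid_test.onnx");
    }));
}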

Collaborator

@TedThemistokleous TedThemistokleous left a comment

CI for clang, hiprtc, etc. is all green; not sure why that's not reflected on this PR.

Add the two minor tests to check the throw cases; apart from that it looks good. I'll kick off a merge from develop to see if that resolves the CI blockage on Windows.

@umangyadav
Member

@turneram fix tidy

@CharlieL7 CharlieL7 removed their request for review August 21, 2024 18:43
@migraphx-bot
Collaborator

Test | Batch | Rate new (a3dae6) | Rate old (273d81) | Diff | Compare
torchvision-resnet50 64 3,246.16 3,251.64 -0.17%
torchvision-resnet50_fp16 64 6,984.28 6,991.78 -0.11%
torchvision-densenet121 32 2,434.09 2,434.61 -0.02%
torchvision-densenet121_fp16 32 4,074.48 4,073.09 0.03%
torchvision-inceptionv3 32 1,632.80 1,636.36 -0.22%
torchvision-inceptionv3_fp16 32 2,736.68 2,738.76 -0.08%
cadene-inceptionv4 16 774.68 776.27 -0.20%
cadene-resnext64x4 16 808.47 809.36 -0.11%
slim-mobilenet 64 7,437.76 7,455.99 -0.24%
slim-nasnetalarge 64 207.99 208.09 -0.05%
slim-resnet50v2 64 3,333.96 3,343.59 -0.29%
bert-mrpc-onnx 8 1,152.90 1,147.91 0.44%
bert-mrpc-tf 1 310.08 305.42 1.53%
pytorch-examples-wlang-gru 1 410.50 430.33 -4.61% 🔴
pytorch-examples-wlang-lstm 1 441.06 440.22 0.19%
torchvision-resnet50_1 1 804.23 765.01 5.13% 🔆
cadene-dpn92_1 1 426.81 434.49 -1.77%
cadene-resnext101_1 1 378.78 372.41 1.71%
onnx-taau-downsample 1 342.68 343.93 -0.36%
dlrm-criteoterabyte 1 35.01 35.07 -0.15%
dlrm-criteoterabyte_fp16 1 58.11 58.07 0.08%
agentmodel 1 7,799.89 8,142.66 -4.21% 🔴
unet_fp16 2 57.78 57.95 -0.29%
resnet50v1_fp16 1 1,003.70 913.83 9.83% 🔆
resnet50v1_int8 1 976.57 961.09 1.61%
bert_base_cased_fp16 64 1,147.32 1,148.86 -0.13%
bert_large_uncased_fp16 32 354.43 355.19 -0.21%
bert_large_fp16 1 209.92 211.63 -0.81%
distilgpt2_fp16 16 2,155.79 2,153.35 0.11%
yolov5s 1 534.49 502.62 6.34% 🔆
tinyllama 1 43.39 43.39 0.00%
vicuna-fastchat 1 170.91 174.27 -1.93%
whisper-tiny-encoder 1 408.45 408.53 -0.02%
whisper-tiny-decoder 1 430.59 431.61 -0.24%

This build is not recommended to merge 🔴

@migraphx-bot
Collaborator


✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
✅ tinyllama: PASSED: MIGraphX meets tolerance
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@causten causten merged commit c9bbea1 into develop Aug 23, 2024
22 of 23 checks passed
@causten causten deleted the skip-simplified-layernorm branch August 23, 2024 13:42
Labels
Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase

Successfully merging this pull request may close these issues: Add onnx parser for SkipSimplifiedLayerNormalization

6 participants