@@ -591,6 +591,373 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
591
591
binder.op , resultType, lhs, rhs);
592
592
return success ();
593
593
});
594
+
595
+ patterns.onOp (
596
+ " MelWeightMatrix" , 17 ,
597
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
598
+ llvm::SmallVector<Value> operands;
599
+ Torch::ValueTensorType resultType;
600
+ int64_t output_dtype_attr;
601
+ if (binder.tensorOperands (operands, 5 ) ||
602
+ binder.tensorResultType (resultType) || operands.size () != 5 ||
603
+ binder.s64IntegerAttr (output_dtype_attr, " output_datatype" , 1 )) {
604
+ return failure ();
605
+ }
606
+ // operands sequence :
607
+ // num_mel_bins, dft_length, sample_rate -> int32/64 tensors
608
+ // lower_edge_hertz, upper_edge_hertz -> f16/32/64
609
+
610
+ // Need to backtrack the values of num_mel_bins and dft_length//2+1 from
611
+ // result shape since the inputs are tensors and we cannot know their
612
+ // values at compile time. if the result type does not contain static
613
+ // shapes, then the implementation will be unsupported.
614
+ if (!resultType.areAllSizesKnown ())
615
+ return rewriter.notifyMatchFailure (
616
+ binder.op , " Unknown result sizes, not supported." );
617
+
618
+ ArrayRef<int64_t > resShape = resultType.getSizes ();
619
+ if (resShape.size () != 2 )
620
+ return rewriter.notifyMatchFailure (
621
+ binder.op ,
622
+ " Expected result rank to be 2, not supported for other ranks." );
623
+
624
+ std::optional<int64_t > torchDTypeInt =
625
+ onnxDtypeIntToTorchDtypeInt (output_dtype_attr);
626
+ if (!torchDTypeInt.has_value ()) {
627
+ return rewriter.notifyMatchFailure (
628
+ binder.op , " conversion to given output dtype unsupported" );
629
+ }
630
+
631
+ // Here Onwards all shapes will be computed using these sizes
632
+ int64_t numSpectrogramBinsInt = resShape[0 ];
633
+ int64_t numMelBinsInt = resShape[1 ];
634
+ Torch::ValueTensorType inputIntType = binder.toValidTensorType (
635
+ operands[0 ].getType ()); // Since operands[0 / 1 / 2] will have the
636
+ // same int type.
637
+ Torch::ValueTensorType inputFloatType = binder.toValidTensorType (
638
+ operands[3 ].getType ()); // Since operands[3 / 4] will have the same
639
+ // float type
640
+
641
+ Value numMelBinsItem =
642
+ getItemOp<Torch::IntType>(binder, rewriter, operands[0 ]);
643
+ Value dftLengthItem =
644
+ getItemOp<Torch::IntType>(binder, rewriter, operands[1 ]);
645
+ Value sampleRateItem =
646
+ getItemOp<Torch::IntType>(binder, rewriter, operands[2 ]);
647
+ Value lowerEdgeHzItem =
648
+ getItemOp<Torch::FloatType>(binder, rewriter, operands[3 ]);
649
+ Value upperEdgeHzItem =
650
+ getItemOp<Torch::FloatType>(binder, rewriter, operands[4 ]);
651
+
652
+ // Helpers
653
+ ImplicitLocOpBuilder b (binder.getLoc (), rewriter);
654
+ auto ctx = binder.op ->getContext ();
655
+
656
+ // Recurring shapes
657
+ SmallVector<int64_t > unranked ({});
658
+ SmallVector<int64_t > shapeNMB ({numMelBinsInt});
659
+ SmallVector<int64_t > shapeNMBp2 ({numMelBinsInt + 2 });
660
+ SmallVector<int64_t > shape1xNMB ({1 , numMelBinsInt});
661
+ SmallVector<int64_t > shapeNSB ({numSpectrogramBinsInt});
662
+ SmallVector<int64_t > shapeNSBxNMB (
663
+ {numSpectrogramBinsInt, numMelBinsInt});
664
+
665
+ // Recurring DTypes
666
+ Type inpFpDType = inputFloatType.getDtype ();
667
+ Type inpIntDType = inputIntType.getDtype ();
668
+ Type si32Ty = rewriter.getIntegerType (32 , true );
669
+ Type f32Ty = rewriter.getF32Type ();
670
+ Type i1Ty = rewriter.getI1Type ();
671
+
672
+ // Value constants
673
+ Value noneConst = b.create <Torch::ConstantNoneOp>();
674
+ Value negTwoConst =
675
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (-2 ));
676
+ Value negOneConst =
677
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (-1 ));
678
+ Value zeroConst =
679
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (0 ));
680
+ Value oneConst =
681
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (1 ));
682
+ Value twoConst =
683
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (2 ));
684
+ Value float32DTypeConst =
685
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (6 ));
686
+
687
+ Torch::ValueTensorType dftLenType =
688
+ Torch::ValueTensorType::get (ctx, unranked, inpIntDType);
689
+ Type freqBinsIntType =
690
+ Torch::ValueTensorType::get (ctx, shapeNMBp2, si32Ty);
691
+ Type freqBinsFltType =
692
+ Torch::ValueTensorType::get (ctx, shapeNMBp2, f32Ty);
693
+
694
+ Value dftLengthDivTwoFlt =
695
+ b.create <Torch::AtenDivIntOp>(dftLengthItem, twoConst);
696
+ Value dftLengthDivTwo =
697
+ b.create <Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
698
+ Value numSpectrogramBins =
699
+ b.create <Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
700
+ Value numSpectrogramBinsItem = numSpectrogramBins;
701
+ Value freqBinsInit = b.create <Torch::AtenArangeOp>(
702
+ freqBinsIntType, numMelBinsItem, /* dtype=*/ float32DTypeConst,
703
+ /* layout=*/ noneConst, /* device=*/ noneConst,
704
+ /* pin_memory=*/ noneConst);
705
+
706
+ // From Ref Impl of Onnx.MelWeightMatrix:
707
+ // https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
708
+ // convert input Freq Hz to Mel
709
+ Value twoFiveNineFiveConst =
710
+ b.create <Torch::ConstantFloatOp>(rewriter.getF64FloatAttr (2595 ));
711
+ Value sevenHConst =
712
+ b.create <Torch::ConstantFloatOp>(rewriter.getF64FloatAttr (700 ));
713
+ Value tenConst =
714
+ b.create <Torch::ConstantFloatOp>(rewriter.getF64FloatAttr (10 ));
715
+
716
+ Value lfDiv7Hfloat =
717
+ b.create <Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
718
+ Type freqType = Torch::ValueTensorType::get (ctx, unranked, inpFpDType);
719
+ Value lfDiv7H =
720
+ b.create <Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
721
+ Value lfDiv7HAdd1 = b.create <Torch::AtenAddScalarOp>(
722
+ freqType, lfDiv7H, oneConst, /* alpha =*/ oneConst);
723
+ Value lfDiv7HAdd1Log10 =
724
+ b.create <Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
725
+ Value lfMel = b.create <Torch::AtenMulScalarOp>(
726
+ freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);
727
+
728
+ Value hfDiv7Hfloat =
729
+ b.create <Torch::AtenDivFloatOp>(upperEdgeHzItem, sevenHConst);
730
+ Value hfDiv7H =
731
+ b.create <Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
732
+ Value hfDiv7HAdd1 = b.create <Torch::AtenAddScalarOp>(
733
+ freqType, hfDiv7H, oneConst, /* alpha =*/ oneConst);
734
+ Value hfDiv7HAdd1Log10 =
735
+ b.create <Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
736
+ Value hfMel = b.create <Torch::AtenMulScalarOp>(
737
+ freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);
738
+
739
+ Value hfSubLf = b.create <Torch::AtenSubTensorOp>(
740
+ hfMel.getType (), hfMel, lfMel, /* alpha=*/ oneConst);
741
+ Value melStep = b.create <Torch::AtenDivScalarOp>(
742
+ hfSubLf.getType (), hfSubLf, numMelBinsItem);
743
+
744
+ Value freqBinsMulMelStep = b.create <Torch::AtenMulTensorOp>(
745
+ freqBinsFltType, freqBinsInit, melStep);
746
+ Value freqBinsScaled = b.create <Torch::AtenAddTensorOp>(
747
+ freqBinsFltType, freqBinsMulMelStep, lfMel, /* alpha=*/ oneConst);
748
+
749
+ // Mel to Hz conv
750
+
751
+ Value fbDiv = b.create <Torch::AtenDivScalarOp>(
752
+ freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
753
+ Value fbClone = b.create <Torch::AtenCloneOp>(
754
+ freqBinsFltType, freqBinsScaled, /* memory_format=*/ noneConst);
755
+ Value tenTensor = b.create <Torch::AtenFillScalarOp>(freqBinsFltType,
756
+ fbClone, tenConst);
757
+ Value fbPow = b.create <Torch::AtenPowTensorTensorOp>(freqBinsFltType,
758
+ tenTensor, fbDiv);
759
+ Value fbPowSubOne = b.create <Torch::AtenSubScalarOp>(
760
+ freqBinsFltType, fbPow, oneConst, /* alpha=*/ oneConst);
761
+ Value freqBinsHz = b.create <Torch::AtenMulScalarOp>(
762
+ freqBinsFltType, fbPowSubOne, sevenHConst);
763
+
764
+ // Normalize freqBinsHz
765
+ Value dftLenPlusOne = b.create <Torch::AtenAddScalarOp>(
766
+ dftLenType, operands[1 ], oneConst, /* alpha=*/ oneConst);
767
+ Value dftLenPlusOneItem =
768
+ getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
769
+ Value fbMulDft = b.create <Torch::AtenMulScalarOp>(
770
+ freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
771
+ Value freqBinsNormalized = b.create <Torch::AtenDivScalarOp>(
772
+ freqBinsFltType, fbMulDft, sampleRateItem);
773
+
774
+ // cast to int32
775
+ Value int32DTypeConst =
776
+ b.create <Torch::ConstantIntOp>(rewriter.getI64IntegerAttr (3 ));
777
+ Value falseConst = b.create <Torch::ConstantBoolOp>(false );
778
+ Value freqBins = b.create <Torch::AtenToDtypeOp>(
779
+ freqBinsIntType, freqBinsNormalized, /* dtype=*/ int32DTypeConst,
780
+ /* non_blocking=*/ falseConst, /* copy=*/ falseConst,
781
+ /* memory_format=*/ noneConst);
782
+
783
+ Torch::ValueTensorType sliceResType =
784
+ Torch::ValueTensorType::get (ctx, shapeNMB, si32Ty);
785
+ Type unsqueezeResType =
786
+ sliceResType.getWithSizesAndDtype (shape1xNMB, si32Ty);
787
+ Value lfTensor = b.create <Torch::AtenSliceTensorOp>(
788
+ sliceResType, freqBins, /* dim=*/ zeroConst, /* start=*/ zeroConst,
789
+ /* end=*/ negTwoConst, /* step=*/ oneConst);
790
+ Value lowFreqTensor = b.create <Torch::AtenUnsqueezeOp>(
791
+ unsqueezeResType, lfTensor, /* dim=*/ zeroConst);
792
+
793
+ Value cfTensor = b.create <Torch::AtenSliceTensorOp>(
794
+ sliceResType, freqBins, /* dim=*/ zeroConst, /* start=*/ oneConst,
795
+ /* end=*/ negOneConst, /* step=*/ oneConst);
796
+ Value centerFreqTensor = b.create <Torch::AtenUnsqueezeOp>(
797
+ unsqueezeResType, cfTensor, /* dim=*/ zeroConst);
798
+
799
+ Value hfTensor = b.create <Torch::AtenSliceTensorOp>(
800
+ sliceResType, freqBins, /* dim=*/ zeroConst, /* start=*/ zeroConst,
801
+ /* end=*/ noneConst, /* step=*/ oneConst);
802
+ Value highFreqTensor = b.create <Torch::AtenUnsqueezeOp>(
803
+ unsqueezeResType, hfTensor, /* dim=*/ zeroConst);
804
+
805
+ Value lowToCenter =
806
+ b.create <Torch::AtenSubTensorOp>(unsqueezeResType, centerFreqTensor,
807
+ lowFreqTensor, /* alpha=*/ oneConst);
808
+ Value centerToHigh = b.create <Torch::AtenSubTensorOp>(
809
+ unsqueezeResType, highFreqTensor, centerFreqTensor,
810
+ /* alpha=*/ oneConst);
811
+
812
+ Type zeroToNInitType =
813
+ inputIntType.getWithSizesAndDtype (shapeNSB, f32Ty);
814
+ Value zeroToNInit = b.create <Torch::AtenArangeOp>(
815
+ zeroToNInitType, numSpectrogramBinsItem,
816
+ /* dtype=*/ float32DTypeConst,
817
+ /* layout=*/ noneConst, /* device=*/ noneConst,
818
+ /* pin_memory=*/ noneConst);
819
+
820
+ Type zeroToNBaseType = inputIntType.getWithSizesAndDtype (
821
+ ArrayRef<int64_t >{numSpectrogramBinsInt, 1 }, f32Ty);
822
+ Value zeroToNBase = b.create <Torch::AtenUnsqueezeOp>(
823
+ zeroToNBaseType, zeroToNInit, /* dim=*/ oneConst);
824
+ Type zeroToNumElesType =
825
+ inputIntType.getWithSizesAndDtype (shapeNSBxNMB, f32Ty);
826
+ Value expandShapeList = b.create <Torch::PrimListConstructOp>(
827
+ rewriter.getType <Torch::ListType>(
828
+ rewriter.getType <Torch::IntType>()),
829
+ SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
830
+ Value zeroToNumEles = b.create <Torch::AtenExpandOp>(
831
+ zeroToNumElesType, zeroToNBase, expandShapeList,
832
+ /* implicit=*/ falseConst);
833
+
834
+ Type maskType = inputIntType.getWithSizesAndDtype (shape1xNMB, i1Ty);
835
+ Value maskLowToCenterZero =
836
+ b.create <Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);
837
+
838
+ // L2C computation
839
+ Value lowToCenterNoZero = b.create <Torch::AtenWhereScalarSelfOp>(
840
+ unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
841
+ Type maskL2CAfterCType =
842
+ inputIntType.getWithSizesAndDtype (shapeNSBxNMB, i1Ty);
843
+ Value maskL2CAfterC = b.create <Torch::AtenGtTensorOp>(
844
+ maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
845
+ Type maxLFResTy =
846
+ inputIntType.getWithSizesAndDtype (ArrayRef<int64_t >{1 }, si32Ty);
847
+ Value maxLowerFreq =
848
+ b.create <Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
849
+ Value maxLowerFreqItem =
850
+ getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
851
+ Value zeroToNumElesL2C = b.create <Torch::AtenWhereScalarSelfOp>(
852
+ zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
853
+ zeroToNumEles);
854
+ Value upslopeDiff = b.create <Torch::AtenSubTensorOp>(
855
+ zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
856
+ /* alpha=*/ oneConst);
857
+ Type l2cNZFltTy = inputIntType.getWithSizesAndDtype (shape1xNMB, f32Ty);
858
+ Value l2cNZFlt = b.create <Torch::AtenToDtypeOp>(
859
+ l2cNZFltTy, lowToCenterNoZero, /* dtype=*/ float32DTypeConst,
860
+ /* non_blocking=*/ falseConst, /* copy=*/ falseConst,
861
+ /* memory_format=*/ noneConst);
862
+ Value upslopeL2C0 = b.create <Torch::AtenDivTensorOp>(
863
+ zeroToNumElesType, upslopeDiff, l2cNZFlt);
864
+ Type maskUpslopeL2C0PosType =
865
+ inputIntType.getWithSizesAndDtype (shapeNSBxNMB, i1Ty);
866
+ Value maskUpslopeL2C0Pos = b.create <Torch::AtenGtScalarOp>(
867
+ maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
868
+ Value upslopeL2C0PosRanged = b.create <Torch::AtenWhereScalarOtherOp>(
869
+ zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
870
+ Value maskIdxL2CAfterCList = b.create <Torch::PrimListConstructOp>(
871
+ rewriter.getType <Torch::ListType>(maskL2CAfterC.getType ()),
872
+ ValueRange{maskL2CAfterC});
873
+ Value zeroConstTensor = Torch::createRank0Tensor (
874
+ rewriter, binder.getLoc (),
875
+ Torch::ValueTensorType::get (ctx, std::nullopt, f32Ty), zeroConst);
876
+ Value upslopeL2C1 = b.create <Torch::AtenIndexPutOp>(
877
+ zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
878
+ zeroConstTensor, falseConst);
879
+ Value maskIdxL2CZeroList = b.create <Torch::PrimListConstructOp>(
880
+ rewriter.getType <Torch::ListType>(maskLowToCenterZero.getType ()),
881
+ ValueRange{maskLowToCenterZero});
882
+ Type centerFreqTensorL2CZeroType =
883
+ inputIntType.getWithSizesAndDtype (ArrayRef<int64_t >{-1 }, si32Ty);
884
+ Value centerFreqTensorL2CZero = b.create <Torch::AtenIndexTensorOp>(
885
+ centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
886
+ Type maskSqueezeType =
887
+ inputIntType.getWithSizesAndDtype (shapeNMB, i1Ty);
888
+ Value maskLowToCenterZeroSqueeze = b.create <Torch::AtenSqueezeOp>(
889
+ maskSqueezeType, maskLowToCenterZero);
890
+ Type maskL2CIntTy = inputIntType.getWithSizesAndDtype (shapeNMB, si32Ty);
891
+ Value maskLowToCenterInt = b.create <Torch::AtenToDtypeOp>(
892
+ maskL2CIntTy, maskLowToCenterZeroSqueeze, /* dtype=*/ int32DTypeConst,
893
+ /* non_blocking=*/ falseConst, /* copy=*/ falseConst,
894
+ /* memory_format=*/ noneConst);
895
+ Value upslopeOneIdxList = b.create <Torch::PrimListConstructOp>(
896
+ rewriter.getType <Torch::ListType>(
897
+ centerFreqTensorL2CZero.getType ()),
898
+ ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
899
+ Value oneConstTensor = Torch::createRank0Tensor (
900
+ rewriter, binder.getLoc (),
901
+ Torch::ValueTensorType::get (ctx, std::nullopt, f32Ty), oneConst);
902
+ Value upslopeL2C = b.create <Torch::AtenIndexPutOp>(
903
+ zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
904
+ falseConst);
905
+
906
+ // H2C computation
907
+ Value maskCenterToHighZero =
908
+ b.create <Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
909
+ Value maskH2CBeforeC = b.create <Torch::AtenLtTensorOp>(
910
+ maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
911
+ Value centerToHighNoZero = b.create <Torch::AtenWhereScalarSelfOp>(
912
+ unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
913
+ Value c2hNZFlt = b.create <Torch::AtenToDtypeOp>(
914
+ l2cNZFltTy, centerToHighNoZero, /* dtype=*/ float32DTypeConst,
915
+ /* non_blocking=*/ falseConst, /* copy=*/ falseConst,
916
+ /* memory_format=*/ noneConst);
917
+ Value zeroToNumElesC2H = b.create <Torch::AtenWhereScalarSelfOp>(
918
+ zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
919
+ Value downslopeDiff = b.create <Torch::AtenSubTensorOp>(
920
+ zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
921
+ /* alpha=*/ oneConst);
922
+ Value downslopeC2H0 = b.create <Torch::AtenDivTensorOp>(
923
+ zeroToNumElesType, downslopeDiff, c2hNZFlt);
924
+ Value maskDownslopeC2H0Pos = b.create <Torch::AtenGtScalarOp>(
925
+ maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
926
+ Value downslopeC2H0Pos = b.create <Torch::AtenWhereScalarOtherOp>(
927
+ zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
928
+ Value idxH2CBeforeCList = b.create <Torch::PrimListConstructOp>(
929
+ rewriter.getType <Torch::ListType>(maskH2CBeforeC.getType ()),
930
+ ValueRange{maskH2CBeforeC});
931
+ Value downslopeC2H = b.create <Torch::AtenIndexPutOp>(
932
+ zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
933
+ zeroConstTensor, falseConst);
934
+
935
+ // final result Calculation
936
+ Value maskH2CNonZero = b.create <Torch::AtenNeScalarOp>(
937
+ maskL2CAfterCType, downslopeC2H, zeroConst);
938
+ Value idxH2CNZList = b.create <Torch::PrimListConstructOp>(
939
+ rewriter.getType <Torch::ListType>(maskH2CNonZero.getType ()),
940
+ ValueRange{maskH2CNonZero});
941
+ Value upslopeL2CMasked = b.create <Torch::AtenIndexPutOp>(
942
+ zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
943
+ falseConst);
944
+
945
+ Value slopesFinal = b.create <Torch::AtenAddTensorOp>(
946
+ zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
947
+ /* alpha=*/ oneConst);
948
+
949
+ Value outputDTypeConst = b.create <Torch::ConstantIntOp>(
950
+ rewriter.getType <Torch::IntType>(),
951
+ rewriter.getI64IntegerAttr (torchDTypeInt.value ()));
952
+ Value finalOutput = b.create <Torch::AtenToDtypeOp>(
953
+ resultType, slopesFinal, /* dtype=*/ outputDTypeConst,
954
+ /* non_blocking=*/ falseConst, /* copy=*/ falseConst,
955
+ /* memory_format=*/ noneConst);
956
+
957
+ rewriter.replaceOp (binder.op , finalOutput);
958
+ return success ();
959
+ });
960
+
594
961
patterns.onOp (
595
962
" Multinomial" , 7 ,
596
963
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
0 commit comments