
Commit 026dfad

onnx.MelWeightMatrix TorchOnnxToTorch (llvm#3503)
Just uploading what I have till now. [Gist](https://gist.github.com/PhaneeshB/761f75f5522d9f4a40ef949a328e93fe) of the PyTorch implementation I'm following to implement the OnnxToTorch lowering.

Additional details (also pasted as a comment in the gist):
- [Op description](https://github.com/onnx/onnx/blob/main/docs/Operators.md#melweightmatrix) in the ONNX documentation.
- [Example](https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-93): the same example is used in this file, and the expected output is shown there.
- [Reference ONNX implementation](https://github.com/onnx/onnx/blob/4c3ed5e08be75bbe1eeb6818e490b1b6a370183e/onnx/reference/ops/op_mel_weight_matrix.py#L13): the base for the gist code above.
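For orientation: the constants 2595 and 700 and the log10/pow pattern in the lowering below implement the HTK mel scale used by the ONNX reference implementation. A minimal standalone sketch of the two scalar conversions (not part of this commit; the helper names are mine):

#include <cmath>
#include <cstdio>

// Hz -> mel, HTK scale: mel = 2595 * log10(1 + hz / 700).
double hzToMel(double hz) { return 2595.0 * std::log10(1.0 + hz / 700.0); }

// mel -> Hz, the inverse used when mapping mel bin edges back to frequencies:
// hz = 700 * (10^(mel / 2595) - 1).
double melToHz(double mel) {
  return 700.0 * (std::pow(10.0, mel / 2595.0) - 1.0);
}

int main() {
  // Round-trip sanity check: prints approximately 1000.
  std::printf("%.3f\n", melToHz(hzToMel(1000.0)));
  return 0;
}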
1 parent 334633b commit 026dfad

File tree

2 files changed: +483 −0 lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp (+367)
@@ -591,6 +591,373 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, lhs, rhs);
        return success();
      });

  patterns.onOp(
      "MelWeightMatrix", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        llvm::SmallVector<Value> operands;
        Torch::ValueTensorType resultType;
        int64_t output_dtype_attr;
        if (binder.tensorOperands(operands, 5) ||
            binder.tensorResultType(resultType) || operands.size() != 5 ||
            binder.s64IntegerAttr(output_dtype_attr, "output_datatype", 1)) {
          return failure();
        }
        // Operand sequence:
        // num_mel_bins, dft_length, sample_rate -> int32/64 tensors
        // lower_edge_hertz, upper_edge_hertz -> f16/32/64

        // We need to backtrack the values of num_mel_bins and dft_length//2+1
        // from the result shape, since the inputs are tensors and their values
        // cannot be known at compile time. If the result type does not have
        // static shapes, the lowering is unsupported.
        if (!resultType.areAllSizesKnown())
          return rewriter.notifyMatchFailure(
              binder.op, "Unknown result sizes, not supported.");

        ArrayRef<int64_t> resShape = resultType.getSizes();
        if (resShape.size() != 2)
          return rewriter.notifyMatchFailure(
              binder.op,
              "Expected result rank to be 2, not supported for other ranks.");

        std::optional<int64_t> torchDTypeInt =
            onnxDtypeIntToTorchDtypeInt(output_dtype_attr);
        if (!torchDTypeInt.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op, "conversion to given output dtype unsupported");
        }

        // Here onwards all shapes will be computed using these sizes
        int64_t numSpectrogramBinsInt = resShape[0];
        int64_t numMelBinsInt = resShape[1];
        Torch::ValueTensorType inputIntType = binder.toValidTensorType(
            operands[0].getType()); // Since operands[0 / 1 / 2] will have the
                                    // same int type.
        Torch::ValueTensorType inputFloatType = binder.toValidTensorType(
            operands[3].getType()); // Since operands[3 / 4] will have the same
                                    // float type.

        Value numMelBinsItem =
            getItemOp<Torch::IntType>(binder, rewriter, operands[0]);
        Value dftLengthItem =
            getItemOp<Torch::IntType>(binder, rewriter, operands[1]);
        Value sampleRateItem =
            getItemOp<Torch::IntType>(binder, rewriter, operands[2]);
        Value lowerEdgeHzItem =
            getItemOp<Torch::FloatType>(binder, rewriter, operands[3]);
        Value upperEdgeHzItem =
            getItemOp<Torch::FloatType>(binder, rewriter, operands[4]);

        // Helpers
        ImplicitLocOpBuilder b(binder.getLoc(), rewriter);
        auto ctx = binder.op->getContext();

        // Recurring shapes
        SmallVector<int64_t> unranked({});
        SmallVector<int64_t> shapeNMB({numMelBinsInt});
        SmallVector<int64_t> shapeNMBp2({numMelBinsInt + 2});
        SmallVector<int64_t> shape1xNMB({1, numMelBinsInt});
        SmallVector<int64_t> shapeNSB({numSpectrogramBinsInt});
        SmallVector<int64_t> shapeNSBxNMB(
            {numSpectrogramBinsInt, numMelBinsInt});

        // Recurring DTypes
        Type inpFpDType = inputFloatType.getDtype();
        Type inpIntDType = inputIntType.getDtype();
        Type si32Ty = rewriter.getIntegerType(32, true);
        Type f32Ty = rewriter.getF32Type();
        Type i1Ty = rewriter.getI1Type();

        // Value constants
        Value noneConst = b.create<Torch::ConstantNoneOp>();
        Value negTwoConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-2));
        Value negOneConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-1));
        Value zeroConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(0));
        Value oneConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(1));
        Value twoConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(2));
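        // Torch scalar type codes used below: 6 == float32, 3 == int32
        // (c10::ScalarType encoding).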
        Value float32DTypeConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(6));

        Torch::ValueTensorType dftLenType =
            Torch::ValueTensorType::get(ctx, unranked, inpIntDType);
        Type freqBinsIntType =
            Torch::ValueTensorType::get(ctx, shapeNMBp2, si32Ty);
        Type freqBinsFltType =
            Torch::ValueTensorType::get(ctx, shapeNMBp2, f32Ty);

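        // Per the ONNX spec, num_spectrogram_bins = floor(dft_length / 2) + 1,
        // the number of unique bins of a real-valued DFT.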
        Value dftLengthDivTwoFlt =
            b.create<Torch::AtenDivIntOp>(dftLengthItem, twoConst);
        Value dftLengthDivTwo =
            b.create<Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
        Value numSpectrogramBins =
            b.create<Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
        Value numSpectrogramBinsItem = numSpectrogramBins;
        Value freqBinsInit = b.create<Torch::AtenArangeOp>(
            freqBinsIntType, numMelBinsItem, /*dtype=*/float32DTypeConst,
            /*layout=*/noneConst, /*device=*/noneConst,
            /*pin_memory=*/noneConst);

        // From Ref Impl of Onnx.MelWeightMatrix:
        // https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
        // convert input Freq Hz to Mel
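        //   mel(f) = 2595 * log10(1 + f / 700)   (HTK mel scale)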
        Value twoFiveNineFiveConst =
            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2595));
        Value sevenHConst =
            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(700));
        Value tenConst =
            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(10));

        Value lfDiv7Hfloat =
            b.create<Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
        Type freqType = Torch::ValueTensorType::get(ctx, unranked, inpFpDType);
        Value lfDiv7H =
            b.create<Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
        Value lfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
            freqType, lfDiv7H, oneConst, /*alpha=*/oneConst);
        Value lfDiv7HAdd1Log10 =
            b.create<Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
        Value lfMel = b.create<Torch::AtenMulScalarOp>(
            freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);

        Value hfDiv7Hfloat =
            b.create<Torch::AtenDivFloatOp>(upperEdgeHzItem, sevenHConst);
        Value hfDiv7H =
            b.create<Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
        Value hfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
            freqType, hfDiv7H, oneConst, /*alpha=*/oneConst);
        Value hfDiv7HAdd1Log10 =
            b.create<Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
        Value hfMel = b.create<Torch::AtenMulScalarOp>(
            freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);

        Value hfSubLf = b.create<Torch::AtenSubTensorOp>(
            hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst);
        Value melStep = b.create<Torch::AtenDivScalarOp>(
            hfSubLf.getType(), hfSubLf, numMelBinsItem);

        Value freqBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
            freqBinsFltType, freqBinsInit, melStep);
        Value freqBinsScaled = b.create<Torch::AtenAddTensorOp>(
            freqBinsFltType, freqBinsMulMelStep, lfMel, /*alpha=*/oneConst);

        // Mel to Hz conversion:
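        //   hz(m) = 700 * (10^(m / 2595) - 1)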
        Value fbDiv = b.create<Torch::AtenDivScalarOp>(
            freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
        Value fbClone = b.create<Torch::AtenCloneOp>(
            freqBinsFltType, freqBinsScaled, /*memory_format=*/noneConst);
        Value tenTensor = b.create<Torch::AtenFillScalarOp>(freqBinsFltType,
                                                            fbClone, tenConst);
        Value fbPow = b.create<Torch::AtenPowTensorTensorOp>(freqBinsFltType,
                                                             tenTensor, fbDiv);
        Value fbPowSubOne = b.create<Torch::AtenSubScalarOp>(
            freqBinsFltType, fbPow, oneConst, /*alpha=*/oneConst);
        Value freqBinsHz = b.create<Torch::AtenMulScalarOp>(
            freqBinsFltType, fbPowSubOne, sevenHConst);

        // Normalize freqBinsHz
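        //   bin = freq_hz * (dft_length + 1) / sample_rate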
        Value dftLenPlusOne = b.create<Torch::AtenAddScalarOp>(
            dftLenType, operands[1], oneConst, /*alpha=*/oneConst);
        Value dftLenPlusOneItem =
            getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
        Value fbMulDft = b.create<Torch::AtenMulScalarOp>(
            freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
        Value freqBinsNormalized = b.create<Torch::AtenDivScalarOp>(
            freqBinsFltType, fbMulDft, sampleRateItem);

        // cast to int32
        Value int32DTypeConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
        Value falseConst = b.create<Torch::ConstantBoolOp>(false);
        Value freqBins = b.create<Torch::AtenToDtypeOp>(
            freqBinsIntType, freqBinsNormalized, /*dtype=*/int32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);

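        // Slice the integer bin positions into the lower / center / upper
        // corners of the triangular mel filters.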
        Torch::ValueTensorType sliceResType =
            Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
        Type unsqueezeResType =
            sliceResType.getWithSizesAndDtype(shape1xNMB, si32Ty);
        Value lfTensor = b.create<Torch::AtenSliceTensorOp>(
            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
            /*end=*/negTwoConst, /*step=*/oneConst);
        Value lowFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
            unsqueezeResType, lfTensor, /*dim=*/zeroConst);

        Value cfTensor = b.create<Torch::AtenSliceTensorOp>(
            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/oneConst,
            /*end=*/negOneConst, /*step=*/oneConst);
        Value centerFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
            unsqueezeResType, cfTensor, /*dim=*/zeroConst);

        Value hfTensor = b.create<Torch::AtenSliceTensorOp>(
            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/twoConst,
            /*end=*/noneConst, /*step=*/oneConst);
        Value highFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
            unsqueezeResType, hfTensor, /*dim=*/zeroConst);

        Value lowToCenter = b.create<Torch::AtenSubTensorOp>(
            unsqueezeResType, centerFreqTensor, lowFreqTensor,
            /*alpha=*/oneConst);
        Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
            unsqueezeResType, highFreqTensor, centerFreqTensor,
            /*alpha=*/oneConst);

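        // Build a [num_spectrogram_bins, num_mel_bins] grid whose columns all
        // hold the spectrogram bin indices 0 .. num_spectrogram_bins - 1.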
        Type zeroToNInitType =
            inputIntType.getWithSizesAndDtype(shapeNSB, f32Ty);
        Value zeroToNInit = b.create<Torch::AtenArangeOp>(
            zeroToNInitType, numSpectrogramBinsItem,
            /*dtype=*/float32DTypeConst,
            /*layout=*/noneConst, /*device=*/noneConst,
            /*pin_memory=*/noneConst);

        Type zeroToNBaseType = inputIntType.getWithSizesAndDtype(
            ArrayRef<int64_t>{numSpectrogramBinsInt, 1}, f32Ty);
        Value zeroToNBase = b.create<Torch::AtenUnsqueezeOp>(
            zeroToNBaseType, zeroToNInit, /*dim=*/oneConst);
        Type zeroToNumElesType =
            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
        Value expandShapeList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(
                rewriter.getType<Torch::IntType>()),
            SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
        Value zeroToNumEles = b.create<Torch::AtenExpandOp>(
            zeroToNumElesType, zeroToNBase, expandShapeList,
            /*implicit=*/falseConst);

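        // Mark degenerate filters whose lower edge coincides with the center
        // bin; these need special handling to avoid division by zero below.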
        Type maskType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
        Value maskLowToCenterZero =
            b.create<Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);

        // L2C (low-to-center) upslope computation:
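        //   upslope = (bin - lower_edge) / (center - lower_edge), kept only
        //   where positive and zeroed for bins past the center.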
        Value lowToCenterNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
            unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
        Type maskL2CAfterCType =
            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
        Value maskL2CAfterC = b.create<Torch::AtenGtTensorOp>(
            maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
        Type maxLFResTy =
            inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{1}, si32Ty);
        Value maxLowerFreq =
            b.create<Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
        Value maxLowerFreqItem =
            getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
        Value zeroToNumElesL2C = b.create<Torch::AtenWhereScalarSelfOp>(
            zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
            zeroToNumEles);
        Value upslopeDiff = b.create<Torch::AtenSubTensorOp>(
            zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
            /*alpha=*/oneConst);
        Type l2cNZFltTy = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
        Value l2cNZFlt = b.create<Torch::AtenToDtypeOp>(
            l2cNZFltTy, lowToCenterNoZero, /*dtype=*/float32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
        Value upslopeL2C0 = b.create<Torch::AtenDivTensorOp>(
            zeroToNumElesType, upslopeDiff, l2cNZFlt);
        Type maskUpslopeL2C0PosType =
            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
        Value maskUpslopeL2C0Pos = b.create<Torch::AtenGtScalarOp>(
            maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
        Value upslopeL2C0PosRanged = b.create<Torch::AtenWhereScalarOtherOp>(
            zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
        Value maskIdxL2CAfterCList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskL2CAfterC.getType()),
            ValueRange{maskL2CAfterC});
        Value zeroConstTensor = Torch::createRank0Tensor(
            rewriter, binder.getLoc(),
            Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), zeroConst);
        Value upslopeL2C1 = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
            zeroConstTensor, falseConst);
        Value maskIdxL2CZeroList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskLowToCenterZero.getType()),
            ValueRange{maskLowToCenterZero});
        Type centerFreqTensorL2CZeroType =
            inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{-1}, si32Ty);
        Value centerFreqTensorL2CZero = b.create<Torch::AtenIndexTensorOp>(
            centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
        Type maskSqueezeType =
            inputIntType.getWithSizesAndDtype(shapeNMB, i1Ty);
        Value maskLowToCenterZeroSqueeze = b.create<Torch::AtenSqueezeOp>(
            maskSqueezeType, maskLowToCenterZero);
        Type maskL2CIntTy = inputIntType.getWithSizesAndDtype(shapeNMB, si32Ty);
        Value maskLowToCenterInt = b.create<Torch::AtenToDtypeOp>(
            maskL2CIntTy, maskLowToCenterZeroSqueeze, /*dtype=*/int32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
        Value upslopeOneIdxList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(
                centerFreqTensorL2CZero.getType()),
            ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
        Value oneConstTensor = Torch::createRank0Tensor(
            rewriter, binder.getLoc(),
            Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst);
        Value upslopeL2C = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
            falseConst);

        // H2C (center-to-high) downslope computation:
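        //   downslope = (upper_edge - bin) / (upper_edge - center), kept only
        //   where positive and zeroed for bins before the center.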
        Value maskCenterToHighZero =
            b.create<Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
        Value maskH2CBeforeC = b.create<Torch::AtenLtTensorOp>(
            maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
        Value centerToHighNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
            unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
        Value c2hNZFlt = b.create<Torch::AtenToDtypeOp>(
            l2cNZFltTy, centerToHighNoZero, /*dtype=*/float32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
        Value zeroToNumElesC2H = b.create<Torch::AtenWhereScalarSelfOp>(
            zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
        Value downslopeDiff = b.create<Torch::AtenSubTensorOp>(
            zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
            /*alpha=*/oneConst);
        Value downslopeC2H0 = b.create<Torch::AtenDivTensorOp>(
            zeroToNumElesType, downslopeDiff, c2hNZFlt);
        Value maskDownslopeC2H0Pos = b.create<Torch::AtenGtScalarOp>(
            maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
        Value downslopeC2H0Pos = b.create<Torch::AtenWhereScalarOtherOp>(
            zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
        Value idxH2CBeforeCList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskH2CBeforeC.getType()),
            ValueRange{maskH2CBeforeC});
        Value downslopeC2H = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
            zeroConstTensor, falseConst);

        // Final result calculation
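        // Each matrix entry is upslope + downslope; where the downslope is
        // nonzero, the upslope contribution is zeroed first so the two
        // triangle halves do not double-count.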
        Value maskH2CNonZero = b.create<Torch::AtenNeScalarOp>(
            maskL2CAfterCType, downslopeC2H, zeroConst);
        Value idxH2CNZList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskH2CNonZero.getType()),
            ValueRange{maskH2CNonZero});
        Value upslopeL2CMasked = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
            falseConst);

        Value slopesFinal = b.create<Torch::AtenAddTensorOp>(
            zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
            /*alpha=*/oneConst);

        Value outputDTypeConst = b.create<Torch::ConstantIntOp>(
            rewriter.getType<Torch::IntType>(),
            rewriter.getI64IntegerAttr(torchDTypeInt.value()));
        Value finalOutput = b.create<Torch::AtenToDtypeOp>(
            resultType, slopesFinal, /*dtype=*/outputDTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);

        rewriter.replaceOp(binder.op, finalOutput);
        return success();
      });

  patterns.onOp(
      "Multinomial", 7,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
