Commit 147c226: fix naming, comments
mvpant committed Dec 23, 2024
1 parent 3ac0c23 commit 147c226
Showing 2 changed files with 43 additions and 37 deletions.
10 changes: 7 additions & 3 deletions stablehlo/testdata/bn_conv_fuse_float32.mlir
@@ -6,13 +6,17 @@ module @jit_main attributes {torch.debug_module_name = "ResNet"} {
%cst_1 = stablehlo.constant dense<[1.01694489, 3.71674347, 5.81334356E-11, 3.28254271, 1.71074404E-13, 0.658226967, 4.37006235, 6.60045282E-12, 0.915522992, 1.93175254E-9, 4.12558556, 2.74399233, 2.8390913, 4.79658588E-8, 11.0722713, 0.500745952, 2.23128176, 4.82570696, 2.69861364, 9.36995506, 3.73391747, 5.48429585, 5.7126689, 0.445444882, 0.436275303, 7.15633583, 13.7179089, 5.25117493, 6.81737518, 1.67235756, 1.65343034, 1.23245978, 4.90762854, 3.07305121, 4.23838568, 4.99363518, 1.44646307E-12, 1.52116203, 1.03519833E-13, 0.351344079, 0.17024748, 1.42054474, 1.90848303, 2.15124035, 2.66084933, 4.84443378, 1.92971194, 1.49994361, 2.94806145E-13, 1.53064024, 0.365027189, 2.93755412, 5.46641159, 0.707924544, 3.33150721, 0.771802961, 2.40678358, 6.5213666, 4.12625027, 1.05063522, 2.95303202, 11.3656216, 4.76904678, 1.65587807]> : tensor<64xf32>
%cst_2 = stablehlo.constant dense<[0.234872743, 0.266257942, -5.10959595E-8, 0.518699706, 3.44040196E-9, 0.222385287, 0.422887057, 1.31532403E-7, 0.25093165, 1.5152026E-6, 0.316871643, 0.250491828, 0.378926098, 1.08618351E-5, 2.752640e-01, 0.236741036, 0.242021769, 0.395314813, 0.469346285, 0.2908957, 0.272684187, 0.27802828, 0.290692091, 0.206927493, 0.258990377, 0.278710574, 0.291149527, 0.316013753, 0.388891488, 0.304111898, 0.267757207, 0.210925162, 0.287084132, 0.332426429, 0.42672804, 0.373260558, 7.48037578E-8, 0.19067812, 1.47401256E-8, 0.223029822, 0.179079413, 0.248600766, 0.27399528, 0.259228647, 0.294202209, 0.299236417, 0.223688841, 0.262799472, 2.20011476E-8, 0.266098082, 0.220890298, 0.284285516, 0.330723315, 0.226809531, 0.365380913, 0.21229881, 0.239653021, 0.24949576, 0.525830686, 0.248247579, 0.295652747, 0.258776665, 0.4832564, 0.26670444]> : tensor<64xf32>
%cst_3 = stablehlo.constant dense<[0.230717152, 0.253822476, -1.05429808E-6, -0.664388895, -1.65705547E-8, 0.161521927, 0.454503953, -4.301950e-07, 0.300513744, -8.005240e-06, 0.349418074, 0.311480612, -0.249529764, -3.474890e-05, 0.107726313, 0.218970656, 0.381412596, -0.529882133, -0.628644109, 0.571398079, 0.299846917, 0.584303737, 0.48202154, 0.328526348, 0.196717009, 0.194961801, 0.152145416, 0.085522361, 0.513142824, 0.0152367353, 0.166441768, 0.332394391, 0.249211237, 0.443366677, -0.280169278, -0.0203848016, -2.45068748E-7, 0.321340501, -4.9151744E-8, 0.237767309, 0.232907727, 0.315274626, 0.427762389, 0.293127537, 0.263794243, 0.675975859, 0.429100394, 0.345662743, -8.69090186E-8, 0.247294366, 0.303160846, 0.615772783, 0.39834857, 0.332067341, -0.412187815, 0.378069043, 0.178953409, 0.25747788, -0.449079722, 0.213058949, 0.569339037, 5.727430e-01, -0.402383476, 0.23406373]> : tensor<64xf32>

+ // Inputs/expected represent the input and output of the first Conv operation in the ResNet model,
+ // obtained by passing a random image through the ONNX Runtime compiled with debug flags
+ // to capture intermediate tensor shapes and data.
%0 = call @inputs() : () -> tensor<1x3x224x224xf32>
%1 = call @expected() : () -> tensor<1x64x112x112xf32>

- // Slicing the weight to reduce CPU cycles spent in the interpreter.
+ // Slicing the kernel to reduce CPU cycles spent in the interpreter.
// Calculating just a couple of layers already takes ~10s to complete.
- %weight_slice = stablehlo.slice %cst [30:32, 0:3, 0:7, 0:7] : (tensor<64x3x7x7xf32>) -> tensor<2x3x7x7xf32>
- %2 = stablehlo.convolution(%0, %weight_slice)
+ %kernel_slice = stablehlo.slice %cst [30:32, 0:3, 0:7, 0:7] : (tensor<64x3x7x7xf32>) -> tensor<2x3x7x7xf32>
+ %2 = stablehlo.convolution(%0, %kernel_slice)
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
window = {stride = [2, 2], pad = [[3, 3], [3, 3]], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
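For context on the window attributes above: with a 7x7 kernel, stride 2, padding 3 on each side, and rhs_dilate = 1, the 224x224 input maps to 112x112, which is why @expected returns tensor<1x64x112x112xf32>. A minimal sketch of that size arithmetic (plain C++, not part of the commit):

#include <cstdint>
#include <cstdio>

// out = (in + padLow + padHigh - kernel) / stride + 1   (floor division; rhs_dilate = 1)
std::int64_t convOutDim(std::int64_t in, std::int64_t kernel, std::int64_t stride,
                        std::int64_t padLow, std::int64_t padHigh) {
  return (in + padLow + padHigh - kernel) / stride + 1;
}

int main() {
  // 224 input, 7x7 kernel, stride 2, pad [3, 3] -> 112 per spatial dimension.
  std::printf("%lld\n", static_cast<long long>(convOutDim(224, 7, 2, 3, 3)));
  return 0;
}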
70 changes: 36 additions & 34 deletions stablehlo/transforms/StablehloAggressiveSimplification.cpp
@@ -1468,12 +1468,12 @@ struct ReorderElementwiseAndShapeOp final
}
};

- // Fuses batch normalization operation with convolution weight:
- // X = conv(input, weight)
+ // Fuses batch normalization operation with convolution kernel:
+ // X = conv(input, kernel.old)
// Y = batch_norm_inference(X, ...)
// into ->
- // X = conv(input, weight(new))
- // Y = add(X, broadcast_in_dim(Bias(new)))
+ // X = conv(input, kernel.new)
+ // Y = add(X, broadcast_in_dim(bias.new))
//
struct FuseConvolutionBatchNormalization final
: OpRewritePattern<BatchNormInferenceOp> {
@@ -1489,55 +1489,57 @@ struct FuseConvolutionBatchNormalization final
auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
if (!convOp) return failure();

- auto convWeight = convOp.getRhs();
- auto convWeightType = convWeight.getType();
- auto convWeightShape = convWeightType.getShape();
+ auto convKernel = convOp.getRhs();
+ auto convKernelType = convKernel.getType();
+ auto convKernelShape = convKernelType.getShape();

auto dimNumbers = convOp.getDimensionNumbers();
if (dimNumbers.getInputBatchDimension() != 0 ||
dimNumbers.getInputFeatureDimension() != 1 ||
dimNumbers.getOutputBatchDimension() != 0 ||
dimNumbers.getOutputFeatureDimension() != 1 ||
dimNumbers.getKernelOutputFeatureDimension() != 0 ||
- dimNumbers.getKernelInputFeatureDimension() != 1)
- return rewriter.notifyMatchFailure(convOp,
- "Only [b, f, ...]x[o, i, ...]->[b, f, "
- "...] configuration is supported");
+ dimNumbers.getKernelInputFeatureDimension() != 1) {
+ constexpr StringLiteral msg =
+ "Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
+ "supported";
+ return rewriter.notifyMatchFailure(convOp, msg);
+ }

if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
return rewriter.notifyMatchFailure(
convOp, "feature or batch grouping is not supported");

- if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+ if (bnOperandShape[bnFeatureIndex] != convKernelShape.front())
return failure();

- DenseFPElementsAttr convWeightElems;
+ DenseFPElementsAttr convKernelElems;
DenseFPElementsAttr scaleElems;
DenseFPElementsAttr offsetElems;
DenseFPElementsAttr meanElems;
DenseFPElementsAttr varianceElems;

- auto epsilon = op.getEpsilon();
+ const auto epsilon = op.getEpsilon();

- if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+ if (!matchPattern(convKernel, m_Constant(&convKernelElems)))
return rewriter.notifyMatchFailure(
- op, "expected constant convolution weight");
+ op, "expected constant convolution kernel");

if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
!matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
!matchPattern(op.getMean(), m_Constant(&meanElems)) ||
!matchPattern(op.getVariance(), m_Constant(&varianceElems)))
return failure();

- const auto &convWeightSemantics =
- cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+ const auto &convKernelSemantics =
+ cast<FloatType>(convKernelType.getElementType()).getFloatSemantics();

- // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
- // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + betta
+ // K.new = K.old * gamma * rsqrt(variance + epsilon)
+ // B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
// where: gamma - scaling factor
- // betta - shifting factor
+ // beta - shifting factor
// rsqrt - reciprocal square root function
- // W - weight
+ // K - kernel(a.k.a weight)
// B - bias
//
const SmallVector<double> multipliers = llvm::map_to_vector(
@@ -1549,22 +1551,22 @@ struct FuseConvolutionBatchNormalization final
return rsqrt * scale.convertToDouble();
});

- SmallVector<APFloat> newWeight;
- newWeight.reserve(convWeightType.getNumElements());
+ SmallVector<APFloat> newKernel;
+ newKernel.reserve(convKernelType.getNumElements());

const size_t outFeatureTileSize =
- computeProduct(convWeightShape.drop_front());
- auto it = convWeightElems.begin();
+ computeProduct(convKernelShape.drop_front());
+ auto it = convKernelElems.begin();
for (const auto &multiplier : multipliers) {
for (size_t i = 0; i < outFeatureTileSize; ++i) {
double v = (*it).convertToDouble() * multiplier;
APFloat result(v);
bool losesInfo;
if (APFloat::opStatus::opInvalidOp ==
- result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+ result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
&losesInfo))
return failure();
- newWeight.push_back(result);
+ newKernel.push_back(result);
++it;
}
}
@@ -1582,26 +1584,26 @@ struct FuseConvolutionBatchNormalization final

bool losesInfo;
if (APFloat::opStatus::opInvalidOp ==
- result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+ result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
&losesInfo))
return failure();

biasValues.push_back(result);
}

rewriter.setInsertionPoint(op);
- auto newConvWeight = rewriter.create<ConstantOp>(
- convWeight.getLoc(), convWeightType,
- DenseFPElementsAttr::get(convWeightType, newWeight));
+ auto newConvKernel = rewriter.create<ConstantOp>(
+ convKernel.getLoc(), convKernelType,
+ DenseFPElementsAttr::get(convKernelType, newKernel));

// Keep old convolution as it might have other users
auto newConvOp = rewriter.create<ConvolutionOp>(
convOp.getLoc(), convOp->getResultTypes(),
- ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+ ValueRange{convOp.getLhs(), newConvKernel}, convOp->getAttrs());

SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
auto biasType =
- convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+ convKernelType.cloneWith(biasShape, convKernelType.getElementType());
auto bias = rewriter.create<ConstantOp>(
op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));

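To make the arithmetic of the new pattern concrete, here is a minimal standalone sketch of the per-output-channel fold (plain C++, not part of the commit; the function name foldBatchNormIntoConv and the flattened [O, I, H, W] kernel layout are illustrative assumptions, and the convolution is assumed to have no pre-existing bias, as in the matched IR):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// K.new[o, ...] = K.old[o, ...] * gamma[o] * rsqrt(var[o] + eps)
// B.new[o]      = (0 - mean[o]) * rsqrt(var[o] + eps) * gamma[o] + beta[o]
void foldBatchNormIntoConv(std::vector<float> &kernel,        // O*I*H*W elements
                           std::size_t outChannels,
                           const std::vector<float> &gamma,   // BN scale
                           const std::vector<float> &beta,    // BN offset
                           const std::vector<float> &mean,
                           const std::vector<float> &var,
                           float eps,
                           std::vector<float> &biasOut) {
  const std::size_t tile = kernel.size() / outChannels;       // I*H*W per output channel
  biasOut.assign(outChannels, 0.0f);
  for (std::size_t o = 0; o < outChannels; ++o) {
    const float m = gamma[o] / std::sqrt(var[o] + eps);       // gamma * rsqrt(var + eps)
    for (std::size_t i = 0; i < tile; ++i) kernel[o * tile + i] *= m;
    biasOut[o] = (0.0f - mean[o]) * m + beta[o];               // B.old is zero here
  }
}

int main() {
  // One output channel, 1x1x1 kernel: gamma = 2, var = 3, eps = 1, mean = 0.5, beta = 0.1.
  // m = 2 / sqrt(4) = 1, so the kernel value is unchanged and the bias is -0.5 + 0.1 = -0.4.
  std::vector<float> kernel{1.5f}, bias;
  foldBatchNormIntoConv(kernel, 1, {2.0f}, {0.1f}, {0.5f}, {3.0f}, 1.0f, bias);
  std::printf("kernel = %g, bias = %g\n", kernel[0], bias[0]);
  return 0;
}

The pass performs the same per-output-feature multiply over outFeatureTileSize elements of the constant kernel, then materializes B.new as a constant that is fed through broadcast_in_dim and added after the rewritten convolution; the old convolution is kept in case it has other users, per the comment in the diff.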
