Fix ASAN/UBSAN issues in DimAnalysis
- Fixes a memory leak
- Fixes an integer overflow caused by a dynamic shape
- Fixes reshapes with a wrong result type in LIT tests

Signed-off-by: Rickert, Jonas <[email protected]>
jorickert committed Jan 28, 2025
1 parent cf17e0d commit 7ad9f4f
Showing 2 changed files with 26 additions and 24 deletions.
16 changes: 9 additions & 7 deletions src/Dialect/ONNX/ONNXDimAnalysis.cpp
@@ -361,11 +361,11 @@ static bool exploreSameDimsUsingShapeHelper(const DimAnalysis::DimT &dim,
   ONNXOpShapeHelper *shapeHelper =
       shape_op.getShapeHelper(op, {}, nullptr, nullptr);
   // If no shape helper, or unimplemented, just abort.
-  if (!shapeHelper || !shapeHelper->isImplemented())
+  if (!shapeHelper)
     return false;
 
   // Compute shape.
-  if (failed(shapeHelper->computeShape())) {
+  if (!shapeHelper->isImplemented() || failed(shapeHelper->computeShape())) {
     delete shapeHelper;
     return false;
   }
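For context, the reordering matters because `getShapeHelper` hands back an owning raw pointer: the old code returned on `!isImplemented()` before any `delete`, leaking the helper whenever it existed but was unimplemented. A minimal standalone sketch of the pattern and the fix (the `ShapeHelper` type and `makeShapeHelper` factory below are hypothetical stand-ins, not the onnx-mlir API):

```cpp
#include <iostream>

// Hypothetical stand-ins modeling only the ownership pattern.
struct ShapeHelper {
  bool isImplemented() const { return false; } // simulate an unimplemented helper
  bool computeShape() { return true; }
};

ShapeHelper *makeShapeHelper() { return new ShapeHelper(); } // caller owns the result

bool exploreDims() {
  ShapeHelper *helper = makeShapeHelper();
  if (!helper)
    return false; // nothing was allocated, nothing to free
  // Pre-fix shape of the check, `if (!helper || !helper->isImplemented()) return false;`,
  // returned without `delete`, leaking `helper` on the unimplemented path.
  if (!helper->isImplemented() || !helper->computeShape()) {
    delete helper; // every early exit after the allocation now frees it
    return false;
  }
  delete helper;
  return true;
}

int main() { std::cout << exploreDims() << "\n"; } // prints 0, with no leak
```

Returning a `std::unique_ptr` from the factory would make the ownership explicit and remove the manual deletes, though that would be a larger interface change than this targeted fix.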
@@ -961,12 +961,14 @@ void DimAnalysis::visitDim(
   bool outputHasOneDynamicDim =
       (llvm::count(outputType.getShape(), ShapedType::kDynamic) == 1);
   // Check if the products of static sizes in the data and output are equal.
-  // It's ok to count ShapedType::kDynamic (dynamic dimension) in the size.
   int64_t dataStaticSize = 1, outputStaticSize = 1;
-  for (int64_t i = 0; i < dataType.getRank(); ++i)
-    dataStaticSize *= dataType.getShape()[i];
-  for (int64_t i = 0; i < outputType.getRank(); ++i)
-    outputStaticSize *= outputType.getShape()[i];
+  for (int64_t i = 0; i < dataType.getRank(); ++i) {
+    dataStaticSize *= dataType.isDynamicDim(i) ? -1 : dataType.getShape()[i];
+  }
+  for (int64_t i = 0; i < outputType.getRank(); ++i) {
+    outputStaticSize *=
+        outputType.isDynamicDim(i) ? -1 : outputType.getShape()[i];
+  }
   // Conditions hold, the dynamic dimension can be from the data.
   if (dataHasOneDynamicDim && outputHasOneDynamicDim &&
       (dataStaticSize == outputStaticSize)) {
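The overflow comes from multiplying the dynamic-dimension sentinel into the running product: in recent MLIR, `ShapedType::kDynamic` is `std::numeric_limits<int64_t>::min()`, so multiplying it with almost any other dimension overflows `int64_t`, which is undefined behavior and exactly what UBSAN flags. The removed comment dated from when the sentinel was -1, and substituting -1 per dynamic dimension restores that behavior. A standalone sketch, assuming only that sentinel value:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic, which recent MLIR defines as INT64_MIN.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Product of the static sizes; each dynamic dimension contributes a -1
// factor instead of the sentinel, mirroring the fixed loops above.
int64_t staticSize(const std::vector<int64_t> &shape) {
  int64_t product = 1;
  for (int64_t dim : shape)
    // The pre-fix code multiplied `dim` directly; with dim == kDynamic that
    // is signed overflow (undefined behavior).
    product *= (dim == kDynamic) ? -1 : dim;
  return product;
}

int main() {
  // Shapes from the LIT test: 8 x ? x 16 x 4 reshaped to ? x 4 x 128.
  std::vector<int64_t> data{8, kDynamic, 16, 4}, output{kDynamic, 4, 128};
  std::cout << staticSize(data) << " " << staticSize(output) << "\n"; // -512 -512
}
```

With exactly one dynamic dimension on each side (checked just above the product loops), the single -1 factor flips the sign of both products identically, so the equality test still compares the static parts.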
34 changes: 17 additions & 17 deletions test/mlir/onnx/onnx_dim_analysis.mlir
@@ -184,38 +184,38 @@ func.func @test_matmul_batchsize(%arg0: tensor<?x8x16x16xf32>) -> tensor<?x8x16x
 
 // -----
 
-func.func @test_matmul_batchsize_diff_rank(%arg0: tensor<8x?x16x4xf32>) -> tensor<8x?x16x32xf32> {
+func.func @test_matmul_batchsize_diff_rank(%arg0: tensor<8x?x16x4xf32>) -> tensor<8x?x16x128xf32> {
   %shape = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64>
-  %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x32xf32>
-  %1 = "onnx.MatMul"(%arg0, %0) : (tensor<8x?x16x4xf32>, tensor<?x4x32xf32>) -> tensor<8x?x16x32xf32>
-  "onnx.Return"(%1) : (tensor<8x?x16x32xf32>) -> ()
+  %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x128xf32>
+  %1 = "onnx.MatMul"(%arg0, %0) : (tensor<8x?x16x4xf32>, tensor<?x4x128xf32>) -> tensor<8x?x16x128xf32>
+  "onnx.Return"(%1) : (tensor<8x?x16x128xf32>) -> ()
 
 // CHECK-LABEL: func.func @test_matmul_batchsize_diff_rank
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<8x?x16x32xf32> {
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<8x?x16x128xf32> {
 // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x4xf32>) -> ()
 // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64>
-// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x32xf32>
-// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x4x32xf32>) -> ()
-// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<8x?x16x4xf32>, tensor<?x4x32xf32>) -> tensor<8x?x16x32xf32>
-// CHECK: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x32xf32>) -> ()
-// CHECK: onnx.Return [[VAR_2_]] : tensor<8x?x16x32xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x128xf32>
+// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x4x128xf32>) -> ()
+// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<8x?x16x4xf32>, tensor<?x4x128xf32>) -> tensor<8x?x16x128xf32>
+// CHECK: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x128xf32>) -> ()
+// CHECK: onnx.Return [[VAR_2_]] : tensor<8x?x16x128xf32>
 // CHECK: }
 }
 
 // -----
 
-func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor<?x4x32xf32> {
+func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor<?x4x128xf32> {
   %shape = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64>
-  %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x32xf32>
-  "onnx.Return"(%0) : (tensor<?x4x32xf32>) -> ()
+  %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x128xf32>
+  "onnx.Return"(%0) : (tensor<?x4x128xf32>) -> ()
 
 // CHECK-LABEL: func.func @test_reshape_single_dyn_dim
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<?x4x32xf32> {
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<?x4x128xf32> {
 // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x4xf32>) -> ()
 // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64>
-// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x32xf32>
-// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x4x32xf32>) -> ()
-// CHECK: onnx.Return [[VAR_1_]] : tensor<?x4x32xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor<?x4x128xf32>
+// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x4x128xf32>) -> ()
+// CHECK: onnx.Return [[VAR_1_]] : tensor<?x4x128xf32>
 // CHECK: }
 }

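The corrected result types follow from the Reshape arithmetic: the input 8x?x16x4 has 8*d*16*4 = 512*d elements for dynamic dimension d, the target shape [-1, 4, 128] pins 4*128 = 512 of them per wildcard slice, so the -1 resolves to d and the result must be tensor<?x4x128xf32>; the old tensor<?x4x32xf32> never matched the shape constant. A small sketch of that -1 resolution, assuming ONNX Reshape semantics where a single -1 entry absorbs the remaining elements:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Resolve the single -1 entry of an ONNX Reshape target shape against a
// known total element count.
int64_t resolveWildcard(int64_t numElements, const std::vector<int64_t> &target) {
  int64_t known = 1;
  for (int64_t dim : target)
    if (dim != -1)
      known *= dim;
  assert(known != 0 && numElements % known == 0 && "shape must divide evenly");
  return numElements / known;
}

int main() {
  int64_t d = 7; // any concrete value of the dynamic dimension
  // 8 x d x 16 x 4 input, reshaped with target [-1, 4, 128]:
  std::cout << resolveWildcard(8 * d * 16 * 4, {-1, 4, 128}) << "\n"; // prints 7, i.e. d
}
```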
