diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 66729d4445142..93622529a5d33 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2749,8 +2749,69 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.op, "unimplemented: cubic coeff must be -0.75");
         }
 
+          Value inputTensor = operands[0];
+          Torch::ValueTensorType inputTensor_type =
+              cast<Torch::ValueTensorType>(inputTensor.getType());
+          ArrayRef<int64_t> inputTensor_sizes = inputTensor_type.getSizes();
+          ArrayRef<int64_t> outputTensor_sizes = outputTensor_type.getSizes();
+
+          int64_t const batchDimension = 0;
+          int64_t const channelDimension = 1;
+          int64_t nonScalableDimensions[] = {
+              batchDimension,
+              channelDimension,
+          };
+
+          auto errorMessageForScaling = [](int64_t givenDimension) {
+            switch (givenDimension) {
+            case batchDimension:
+              return "Unexpected intent to scale the batch dimension";
+            case channelDimension:
+              return "Unexpected intent to scale the channel dimension";
+            default:
+              return "Scalable dimension treated as non-scalable";
+            }
+          };
+
+          auto unknownSize = Torch::kUnknownSize;
+
+          // Compile-time check for dimensions of static size
+          for (auto eachDimension : nonScalableDimensions) {
+            auto eachInputSize = inputTensor_sizes[eachDimension];
+            auto eachOutputSize = outputTensor_sizes[eachDimension];
+
+            if (eachInputSize == unknownSize || eachOutputSize == unknownSize) {
+              continue;
+            } else if (eachInputSize == eachOutputSize) {
+              continue;
+            }
+
+            return rewriter.notifyMatchFailure(
+                binder.op, errorMessageForScaling(eachDimension));
+          }
+
           auto binderLocation = binder.getLoc();
+
+          // Run-time check for dimensions of dynamic size
+          for (auto eachDimension : nonScalableDimensions) {
+            auto eachDimensionAsValue = rewriter.create<Torch::ConstantIntOp>(
+                binderLocation, rewriter.getI64IntegerAttr(eachDimension));
+
+            Value eachInputSizeAsValue = rewriter.create<Torch::AtenSizeIntOp>(
+                binderLocation, inputTensor, eachDimensionAsValue);
+
+            int64_t eachOutputSize = outputTensor_sizes[eachDimension];
+            Value eachOutputSizeAsValue = rewriter.create<Torch::ConstantIntOp>(
+                binderLocation, rewriter.getI64IntegerAttr(eachOutputSize));
+
+            Value eachSizeComparison = rewriter.create<Torch::AtenEqIntOp>(
+                binderLocation, eachInputSizeAsValue, eachOutputSizeAsValue);
+
+            rewriter.create<Torch::RuntimeAssertOp>(
+                binderLocation, eachSizeComparison,
+                rewriter.getStringAttr(errorMessageForScaling(eachDimension)));
+          };
+
           Value cstFalse =
               rewriter.create<Torch::ConstantBoolOp>(binderLocation, false);
           Value cstTrue =
@@ -2770,10 +2831,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                 rewriter.create<Torch::ConstantStrOp>(binderLocation, modeStr);
           }
 
-          Value inputTensor = operands[0];
-          Torch::ValueTensorType inputTensor_type =
-              cast<Torch::ValueTensorType>(inputTensor.getType());
-          ArrayRef<int64_t> inputTensor_sizes = inputTensor_type.getSizes();
           unsigned inputTensor_rank = inputTensor_sizes.size();
 
           // supported modes:
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 16c86218dbc8b..b4803be6ed3b9 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 // CHECK-LABEL: func.func @test_resize_sizes_nearest
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
   // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %[[STR]], %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
     torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }