From 6077adc6bc08ae89d3f41c817cec5e9cd6882117 Mon Sep 17 00:00:00 2001
From: nihui <shuizhuyuanluo@126.com>
Date: Tue, 22 Oct 2024 18:31:22 +0800
Subject: [PATCH] pnnx do not fold tensor with dynamic shape, use fp32 module
 by default (#5755)

---
 tools/pnnx/src/ir.cpp                          | 3 +++
 tools/pnnx/src/pass_level0/shape_inference.cpp | 4 +++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp
index 8b2b6dfd2d7..a0eb8d692bf 100644
--- a/tools/pnnx/src/ir.cpp
+++ b/tools/pnnx/src/ir.cpp
@@ -2390,6 +2390,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
     {
         fprintf(pyfp, "def export_torchscript():\n");
         fprintf(pyfp, "    net = Model()\n");
+        fprintf(pyfp, "    net.float()\n");
         fprintf(pyfp, "    net.eval()\n");
         fprintf(pyfp, "\n");
         fprintf(pyfp, "    torch.manual_seed(0)\n");
@@ -2455,6 +2456,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
     {
         fprintf(pyfp, "def export_onnx():\n");
         fprintf(pyfp, "    net = Model()\n");
+        fprintf(pyfp, "    net.float()\n");
         fprintf(pyfp, "    net.eval()\n");
         fprintf(pyfp, "\n");
         fprintf(pyfp, "    torch.manual_seed(0)\n");
@@ -2576,6 +2578,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
     {
         fprintf(pyfp, "def test_inference():\n");
         fprintf(pyfp, "    net = Model()\n");
+        fprintf(pyfp, "    net.float()\n");
         fprintf(pyfp, "    net.eval()\n");
         fprintf(pyfp, "\n");
         fprintf(pyfp, "    torch.manual_seed(0)\n");
diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp
index a273dd79df8..5865390bdfa 100644
--- a/tools/pnnx/src/pass_level0/shape_inference.cpp
+++ b/tools/pnnx/src/pass_level0/shape_inference.cpp
@@ -418,12 +418,14 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph
         std::vector<c10::ShapeSymbol> sizes1 = type1->symbolic_sizes().sizes().value();
         std::vector<c10::ShapeSymbol> sizes2 = type2->symbolic_sizes().sizes().value();
 
+        bool is_shape_static = true;
         for (size_t i = 0; i < sizes1.size(); i++)
         {
             if (sizes1[i] == sizes2[i])
                 continue;

             sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1);
+            is_shape_static = false;
         }

         auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1));
@@ -431,7 +433,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph
         v->setType(finaltype);

         // check if value that does not depend on inputs
-        if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs))
+        if (is_shape_static && value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs))
         {
             // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
             foldable_constants.insert(v->debugName());