Skip to content

Commit

Permalink
pnnx: do not fold tensors with dynamic shape; use fp32 module by default
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Oct 22, 2024
1 parent e7602a2 commit 6077adc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2390,6 +2390,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
fprintf(pyfp, "def export_torchscript():\n");
fprintf(pyfp, " net = Model()\n");
fprintf(pyfp, " net.float()\n");
fprintf(pyfp, " net.eval()\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " torch.manual_seed(0)\n");
Expand Down Expand Up @@ -2455,6 +2456,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
fprintf(pyfp, "def export_onnx():\n");
fprintf(pyfp, " net = Model()\n");
fprintf(pyfp, " net.float()\n");
fprintf(pyfp, " net.eval()\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " torch.manual_seed(0)\n");
Expand Down Expand Up @@ -2576,6 +2578,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
fprintf(pyfp, "def test_inference():\n");
fprintf(pyfp, " net = Model()\n");
fprintf(pyfp, " net.float()\n");
fprintf(pyfp, " net.eval()\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " torch.manual_seed(0)\n");
Expand Down
4 changes: 3 additions & 1 deletion tools/pnnx/src/pass_level0/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,20 +418,22 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
std::vector<c10::ShapeSymbol> sizes1 = type1->symbolic_sizes().sizes().value();
std::vector<c10::ShapeSymbol> sizes2 = type2->symbolic_sizes().sizes().value();

bool is_shape_static = true;
for (size_t i = 0; i < sizes1.size(); i++)
{
if (sizes1[i] == sizes2[i])
continue;

sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1);
is_shape_static = false;
}

auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1));

v->setType(finaltype);

// check if value that does not depend on inputs
if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs))
if (is_shape_static && value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs))
{
// fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
foldable_constants.insert(v->debugName());
Expand Down

0 comments on commit 6077adc

Please sign in to comment.