diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 68d5f0809..41b654bc9 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -142,6 +142,10 @@ def _get_attrs(target, attrs): if isinstance(attr, nn.Module): module_dict[node.target] = nn.Module() special_nodes.append(node) + #; the original design fails to + #; trace any Tensor object + elif isinstance(attr, torch.Tensor): + module_dict[node.target] = attr elif node.op == 'call_method': for special_node in special_nodes: if special_node in node.args or \