diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 173590b2a63..4b7e9663b17 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -51,6 +51,7 @@
     prepare_pt2e,
     prepare_qat_pt2e,
 )
+from unittest.mock import patch
 
 
 def generate_context_binary(
@@ -516,9 +517,20 @@ def get_qdq_module(
         block_size_map: Dict[str, Tuple] = None,
         submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
+        from executorch.backends.qualcomm.utils.utils import draw_graph
+        with patch.object(
+            torch._utils_internal,
+            "export_training_ir_rollout_check",
+            return_value=False,
+        ):
+            m_with_patch = torch.export.export(
+                module, inputs, dynamic_shapes=dynamic_shapes, strict=False
+            ).module()
+        draw_graph("export_with_patch", ".", m_with_patch)
         m = torch.export.export(
-            module, inputs, dynamic_shapes=dynamic_shapes, strict=True
-        ).module()
+            module, inputs, dynamic_shapes=dynamic_shapes, strict=False
+        ).module()
+        draw_graph("export", ".", m)
 
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
@@ -529,6 +541,9 @@ def get_qdq_module(
         )
         if block_size_map is not None:
             quantizer.set_block_size_map(block_size_map)
+        prepared_with_patch = prepare_pt2e(m_with_patch, quantizer)
+        prepared_with_patch(*inputs)
+        quantized_module_with_patch = convert_pt2e(prepared_with_patch)
         prepared = prepare_pt2e(m, quantizer)
         prepared(*inputs)
         quantized_module = convert_pt2e(prepared)
@@ -543,6 +558,8 @@ def get_qdq_module(
         }
         if not bypass_check:
             self.assertTrue(nodes.intersection(q_and_dq))
+        draw_graph("convert_pt2e_with_patch", ".", quantized_module_with_patch)
+        draw_graph("convert_pt2e", ".", quantized_module)
         return quantized_module
 
     def get_prepared_qat_module(