diff --git a/pytorch_blade/tests/disc/test_dynamo.py b/pytorch_blade/tests/disc/test_dynamo.py index b684082eb31..92c62637e1f 100644 --- a/pytorch_blade/tests/disc/test_dynamo.py +++ b/pytorch_blade/tests/disc/test_dynamo.py @@ -26,7 +26,10 @@ def test_capture(self): import torch._dynamo as dynamo import torch_blade.dynamo explain_out = dynamo.explain(func1, b=torch.rand([2])) - self.assertEqual(explain_out.graph_count, 1) + if type(explain_out) is tuple: + self.assertEqual(len(explain_out[2]), 1) + else: + self.assertEqual(explain_out.graph_count, 1) if __name__ == '__main__': unittest.main()