diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index 4e15fa628a..2179ba572b 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -25,6 +25,11 @@ def _copy_tensor(x: torch.Tensor) -> torch.Tensor: return x.clone().detach() +def _bf16_to_numpy_dtype(x: torch.Tensor) -> torch.Tensor: + # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + return x.float() if x.dtype == torch.bfloat16 else x + + def load_pretrained_params( model: nn.Module, url: Optional[str] = None, @@ -157,8 +162,3 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T ) logging.info(f"Model exported to {model_name}.onnx") return f"{model_name}.onnx" - - -def _bf16_to_numpy_dtype(x): - # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype - return x.float() if x.dtype == torch.bfloat16 else x