Skip to content

Commit

Permalink
typing and order
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 12, 2023
1 parent 218c1ed commit e942c16
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
return x.clone().detach()


def _bf16_to_numpy_dtype(x: torch.Tensor) -> torch.Tensor:
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x


def load_pretrained_params(
model: nn.Module,
url: Optional[str] = None,
Expand Down Expand Up @@ -157,8 +162,3 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
)
logging.info(f"Model exported to {model_name}.onnx")
return f"{model_name}.onnx"


def _bf16_to_numpy_dtype(x):
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x

0 comments on commit e942c16

Please sign in to comment.