Skip to content

Commit

Permalink
map_coordinates also cannot handle np.float16
Browse files Browse the repository at this point in the history
  • Loading branch information
BrianPugh committed Jan 4, 2025
1 parent b6e59bc commit d814cde
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions py360convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,28 +376,29 @@ def __init__(
self._order = order

def __call__(self, img: NDArray[DType]) -> NDArray[DType]:
padded = self._pad(img)
if self._use_cv2:
if padded.dtype == np.float16:
source_dtype = np.float16
else:
source_dtype = None
if img.dtype == np.float16:
source_dtype = np.float16
else:
source_dtype = None

if source_dtype:
padded = padded.astype(np.float32)
if source_dtype:
img = img.astype(np.float32) # pyright: ignore

padded = self._pad(img)
if self._use_cv2:
# cv2.remap can handle uint8, float32, float64
out = cv2.remap(padded, self._coor_x, self._coor_y, interpolation=self._order) # pyright: ignore

if source_dtype:
out = out.astype(source_dtype)
else:
# map_coordinates can handle uint8, float32, float64
out = map_coordinates(
padded,
(self._coor_y, self._coor_x),
order=self._order,
)[..., 0]

if source_dtype:
out = out.astype(source_dtype)

return out # pyright: ignore[reportReturnType]

def _pad(self, img: NDArray[DType]) -> NDArray[DType]:
Expand Down Expand Up @@ -548,26 +549,28 @@ def __call__(self, cube_faces: NDArray[DType]) -> NDArray[DType]:
if w != self._w:
raise ValueError("Input width {w} doesn't match expected height {self._w}.")

padded = self._pad(cube_faces)
if self._use_cv2:
if padded.dtype == np.float16:
source_dtype = np.float16
else:
source_dtype = None
if cube_faces.dtype == np.float16:
source_dtype = np.float16
else:
source_dtype = None

if source_dtype:
padded = padded.astype(np.float32)
if source_dtype:
cube_faces = cube_faces.astype(np.float32) # pyright: ignore

padded = self._pad(cube_faces)
if self._use_cv2:
w = padded.shape[-1]
v_img = padded.reshape(-1, w)

# cv2.remap can handle uint8, float32, float64
out = cv2.remap(v_img, self._coor_x, self._coor_y, interpolation=self._order) # pyright: ignore

if source_dtype:
out = out.astype(source_dtype)
else:
# map_coordinates can handle uint8, float32, float64
out = map_coordinates(padded, (self._tp, self._coor_y, self._coor_x), order=self._order)

if source_dtype:
out = out.astype(source_dtype)

return out # pyright: ignore[reportReturnType]

def _pad(self, cube_faces: NDArray[DType]) -> NDArray[DType]:
Expand Down

0 comments on commit d814cde

Please sign in to comment.