Skip to content

Commit

Permalink
Only copy non-writeable image when CPU upscaling (#2860)
Browse files Browse the repository at this point in the history
* Only copy non-writeable image when CPU upscaling

* wrap with try/finally

* use original writeable value instead
  • Loading branch information
joeyballentine authored May 18, 2024
1 parent 61392ef commit 91b80c7
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions backend/src/nodes/impl/pytorch/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ def _rgb_to_bgr(t: torch.Tensor) -> torch.Tensor:
return t


def _into_tensor(
img: np.ndarray, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
img = np.ascontiguousarray(img)
writeable = img.flags.writeable
try:
if not writeable and device == torch.device("cpu"):
img = np.copy(img)
else:
# since we are going to copy the image to the GPU, we can skip the copy here
img.flags.writeable = True
input_tensor = torch.from_numpy(img).to(device, dtype)
return input_tensor
finally:
img.flags.writeable = writeable


@torch.inference_mode()
def pytorch_auto_split(
img: np.ndarray,
Expand All @@ -72,10 +89,7 @@ def upscale(img: np.ndarray, _: object):
input_tensor = None
try:
# convert to tensor
img = np.ascontiguousarray(img)
if not img.flags.writeable:
img = np.copy(img)
input_tensor = torch.from_numpy(img).to(device, dtype)
input_tensor = _into_tensor(img, device, dtype)
input_tensor = _rgb_to_bgr(input_tensor)
input_tensor = _into_batched_form(input_tensor)

Expand Down

0 comments on commit 91b80c7

Please sign in to comment.