Skip to content

Commit

Permalink
Fix custom tile size type (#2870)
Browse files Browse the repository at this point in the history
* Fix custom tile size type

* None -> 0
  • Loading branch information
joeyballentine authored May 18, 2024
1 parent a971d74 commit e01f8bc
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def upscale_image_node(
img: np.ndarray,
model: NcnnModelWrapper,
tile_size: TileSize,
custom_tile_size: int | None,
custom_tile_size: int,
separate_alpha: bool,
) -> np.ndarray:
settings = get_settings(context)
Expand All @@ -219,9 +219,7 @@ def upscale(i: np.ndarray) -> np.ndarray:
model,
model.model.layers[0].outputs[0],
model.model.layers[-1].outputs[0],
TileSize(custom_tile_size)
if tile_size == CUSTOM and custom_tile_size is not None
else tile_size,
TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size,
)
if ic == 3:
i = cv2.cvtColor(i, cv2.COLOR_RGB2BGR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def check_will_upscale(context: NodeContext, interp: NcnnModelWrapper):
fake_img = np.ones((3, 3, 3), dtype=np.float32, order="F")
result = upscale_image_node(context, fake_img, interp, NO_TILING, None, False)
result = upscale_image_node(context, fake_img, interp, NO_TILING, 0, False)

mean_color = np.mean(result)
del result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def upscale_image_node(
img: np.ndarray,
model: OnnxModel,
tile_size: TileSize,
custom_tile_size: int | None,
custom_tile_size: int,
separate_alpha: bool,
) -> np.ndarray:
settings = get_settings(context)
Expand Down Expand Up @@ -141,9 +141,7 @@ def upscale_image_node(
lambda i: upscale(
i,
session,
TileSize(custom_tile_size)
if tile_size == CUSTOM and custom_tile_size is not None
else tile_size,
TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size,
change_shape,
exact_size,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def perform_interp(

def check_will_upscale(context: NodeContext, model: OnnxModel):
fake_img = np.ones((3, 3, 3), dtype=np.float32, order="F")
result = upscale_image_node(context, fake_img, model, NO_TILING, None, False)
result = upscale_image_node(context, fake_img, model, NO_TILING, 0, False)

mean_color = np.mean(result)
del result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def upscale_image_node(
use_custom_scale: bool,
custom_scale: int,
tile_size: TileSize,
custom_tile_size: int | None,
custom_tile_size: int,
separate_alpha: bool,
) -> np.ndarray:
exec_options = get_settings(context)
Expand All @@ -286,9 +286,7 @@ def inner_upscale(img: np.ndarray) -> np.ndarray:
lambda i: upscale(
i,
model,
TileSize(custom_tile_size)
if tile_size == CUSTOM and custom_tile_size is not None
else tile_size,
TileSize(custom_tile_size) if tile_size == CUSTOM else tile_size,
exec_options,
context,
),
Expand Down

0 comments on commit e01f8bc

Please sign in to comment.