Skip to content

Commit

Permalink
Implement CascadePSP
Browse files Browse the repository at this point in the history
  • Loading branch information
ooe1123 committed Feb 16, 2022
1 parent f6e6fc2 commit a38d62a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions background_removal/cascade_psp/cascade_psp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def resize(img, size=None, out_shape=None, method='bilinear'):
oh = int(ratio * h)
ow = int(ratio * w)

f = False
if f:
use_pytorch = False
if not use_pytorch:
inp = {
'bilinear': cv2.INTER_LINEAR,
'bicubic': cv2.INTER_CUBIC,
Expand Down Expand Up @@ -102,10 +102,6 @@ def preprocess(img, gray=False):
return img


def post_process(img):
return img


def safe_forward(net, img, seg, inter_s8=None):
_, _, ph, pw = seg.shape

Expand Down Expand Up @@ -228,9 +224,39 @@ def predict(net, net_s8, img, seg):
if (seg_part_norm.mean() > high_thres) or (seg_part_norm.mean() < low_thres):
continue

print("---", x_idx, y_idx)
grid_images = safe_forward(net, im_part, seg_224_part, seg_56_part)
grid_pred_224 = grid_images[0]

# Padding
pred_sx = pred_sy = 0
pred_ex = step_len
pred_ey = step_len

if start_x != 0:
start_x += padding
pred_sx += padding
if start_y != 0:
start_y += padding
pred_sy += padding
if end_x != w:
end_x -= padding
pred_ex -= padding
if end_y != h:
end_y -= padding
pred_ey -= padding

return None
combined_224[:, :, start_y:end_y, start_x:end_x] += grid_pred_224[:, :, pred_sy:pred_ey, pred_sx:pred_ex]

del grid_pred_224

# Used for averaging
combined_weight[:, :, start_y:end_y, start_x:end_x] += 1

# Final full resolution output
seg_norm = (r_pred_224 / 2 + 0.5)
pred_224 = np.divide(combined_224, combined_weight, out=seg_norm, where=combined_weight != 0)

return pred_224[0, 0]


def recognize_from_image(net, net_s8):
Expand Down Expand Up @@ -269,13 +295,12 @@ def recognize_from_image(net, net_s8):
else:
output = predict(net, net_s8, img, mask_img)

# # postprocessing
# res_img = post_process(*output, a=True)
# res_img = cv2.cvtColor(res_img, cv2.COLOR_RGBA2BGRA)
#
# savepath = get_savepath(args.savepath, image_path, ext='.png')
# logger.info(f'saved at : {savepath}')
# cv2.imwrite(savepath, res_img)
# postprocessing
res_img = (output * 255).astype(np.uint8)

savepath = get_savepath(args.savepath, image_path, ext='.png')
logger.info(f'saved at : {savepath}')
cv2.imwrite(savepath, res_img)

logger.info('Script finished successfully.')

Expand All @@ -290,8 +315,12 @@ def main():
env_id = args.env_id

# net initialize
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=env_id)
net_s8 = ailia.Net(MODEL_INTER_S8_PATH, WEIGHT_INTER_S8_PATH, env_id=env_id)
logger.info("This model requires 10GB or more memory.")
memory_mode = ailia.get_memory_mode(
reduce_constant=True, ignore_input_with_initializer=True,
reduce_interstage=False, reuse_interstage=True)
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=env_id, memory_mode=memory_mode)
net_s8 = ailia.Net(MODEL_INTER_S8_PATH, WEIGHT_INTER_S8_PATH, env_id=env_id, memory_mode=memory_mode)

recognize_from_image(net, net_s8)

Expand Down
Binary file added background_removal/cascade_psp/output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit a38d62a

Please sign in to comment.