Skip to content

Commit

Permalink
decode_latent_tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
johndpope committed Oct 4, 2024
1 parent 61c1440 commit ca2cec9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
18 changes: 8 additions & 10 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

# Process reference frame
reference_frame = vr[0].asnumpy()
reference_frame = Image.fromarray(reference_frame)
reference_frame = transform(reference_frame).unsqueeze(0).to(device)

with torch.no_grad():
f_r = model.dense_feature_encoder(reference_frame)
t_r = model.latent_token_encoder(reference_frame)

total_frames = len(vr)
for i in range(1, total_frames):
if i % (frame_skip + 1) != 0:
Expand All @@ -42,7 +47,8 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
current_frame = transform(current_frame).unsqueeze(0).to(device)

with torch.no_grad():
reconstructed_frame = model(current_frame, reference_frame)
t_c = model.latent_token_encoder(current_frame)
reconstructed_frame = model.decode_latent_tokens(f_r, t_r, t_c)

reconstructed_frame = reconstructed_frame.squeeze().cpu().numpy().transpose(1, 2, 0)
reconstructed_frame = (reconstructed_frame * 255).astype(np.uint8)
Expand Down Expand Up @@ -77,16 +83,8 @@ def main():
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

if config.input.video_path:
process_video(model, config.input.video_path, config.output.path, transform, device, config.input.frame_skip)
else:
current_frame = load_image(config.input.current_frame_path, transform).to(device)
reference_frame = load_image(config.input.reference_frame_path, transform).to(device)

with torch.no_grad():
reconstructed_frame = model(current_frame, reference_frame)
process_video(model, config.input.video_path, config.output.path, transform, device, config.input.frame_skip)

save_output(reconstructed_frame, config.output.path)

if __name__ == "__main__":
main()
6 changes: 6 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ def style_mixing(self, t_c, t_r):
return t_c_mixed, t_r_mixed
return t_c, t_r

def tokens(self, x_current, x_reference):
f_r = self.dense_feature_encoder(x_reference)
t_r = self.latent_token_encoder(x_reference)
t_c = self.latent_token_encoder(x_current)
return f_r,t_r,t_c

def decode_latent_tokens(self,f_r,t_r,t_c):
mix_t_c = t_c
mix_t_r = t_r
Expand Down

0 comments on commit ca2cec9

Please sign in to comment.