Skip to content

Commit

Permalink
Fix batched inference
Browse files Browse the repository at this point in the history
  • Loading branch information
dajes committed Nov 21, 2023
1 parent 29a1bff commit 4815c4b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
# Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
# lator for multi-frame interpolation. Below, we create a constant tensor of
# shape [B]. We use the `time` tensor to infer the batch size.
backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt[:, 0])
forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt[:, 0])
backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt)
forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt)

pyramids_to_warp = [
util.concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
Expand Down Expand Up @@ -154,6 +154,5 @@ def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
'backward_flow_pyramid': backward_flow_pyramid,
}

@torch.jit.export
def forward(self, x0, x1, batch_dt) -> torch.Tensor:
return self.debug_forward(x0, x1, batch_dt)['image'][0]
2 changes: 1 addition & 1 deletion util.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def multiply_pyramid(pyramid: List[torch.Tensor],
# the batch of images from BxHxWxC-format to CxHxWxB. This can then be
# multiplied with a batch of scalars, then we transpose back to the standard
# BxHxWxC form.
return [image * scalar for image in pyramid]
return [image * scalar[..., None, None] for image in pyramid]


def flow_pyramid_synthesis(
Expand Down

0 comments on commit 4815c4b

Please sign in to comment.