ValueError: converting frame count is not supported. #81

Open
Aryan9101 opened this issue Feb 20, 2024 · 12 comments
@Aryan9101

I want to track points across many short horizon videos of, say, dimension (8, 256, 256, 3). Suppose I have B such videos and I want to track the same N points in each video. Then, the input that I pass into TAPIR is of dimensions (B, 8, 256, 256, 3) for the frames and (B, N, 3) for the queries. Suppose B = 4 and N = 32 for the sake of an example.

However, this always fails with ValueError: converting frame count is not supported. Any idea what might be happening here?
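
For concreteness, here is a minimal sketch of how I construct the inputs (dummy data only; the (t, y, x) query layout follows the TAPIR demos):

import torch

B, T, H, W, N = 4, 8, 256, 256, 32
video = torch.rand(B, T, H, W, 3) * 2 - 1       # frames normalized to [-1, 1]
t = torch.randint(0, T, (B, N, 1)).float()      # frame index within each clip
yx = torch.rand(B, N, 2) * (H - 1)              # pixel coordinates in [0, 255]
queries = torch.cat([t, yx], dim=-1)            # (B, N, 3) as (t, y, x)
print(video.shape, queries.shape)               # (4, 8, 256, 256, 3) and (4, 32, 3)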

@Aryan9101
Author

On second thought, is this because batched inference might not be supported? I am using the PyTorch model btw.

@cdoersch
Collaborator

Batched inference is supported.

The error you're getting, however, suggests that you're taking a query point from one video and using transforms.convert_grid_coordinates to convert it to a different framerate for use with the same video. This seldom makes sense, and it's easy to get it wrong, which is why we prevent the library function from doing it and force you to do it manually (are you converting framerate or cropping the video? The required operation will be different.)

In other words, I think I need more details about why you're converting across framerates in order to be able to help you.
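
For reference, here is a rough sketch of the usual spatial-only use of the helper, assuming the (coords, input_grid_size, output_grid_size, coordinate_format) signature from tapnet/torch/utils.py (import path inferred from the stack traces below); note that the frame count stays identical between the two grids:

import torch
from tapnet.torch import utils

queries_tyx = torch.tensor([[3.0, 100.0, 100.0]])   # (t, y, x) on an 8-frame, 224x224 grid
queries_256 = utils.convert_grid_coordinates(
    queries_tyx,
    (8, 224, 224),          # input grid size:  (frames, height, width)
    (8, 256, 256),          # output grid size: same frame count, new spatial size
    coordinate_format='tyx',
)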

@Aryan9101
Author

Aryan9101 commented Mar 1, 2024 via email

@cdoersch
Collaborator

cdoersch commented Mar 1, 2024

Could you post a full stack trace? I don't see anything about this setup that shouldn't work.

@Aryan9101
Author

Aryan9101 commented Mar 3, 2024

Sure! Here is what the stack trace looks like (I wrapped the model with DataParallel to allow for faster multi-GPU batched inference):

  File "/home/aryanjain/miniconda3/envs/fvm/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
  File "/home/aryanjain/miniconda3/envs/fvm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aryanjain/miniconda3/envs/fvm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryanjain/tapnet/torch/tapir_model.py", line 178, in forward
    query_features = self.get_query_features(
  File "/home/aryanjain/tapnet/torch/tapir_model.py", line 260, in get_query_features
    position_in_grid = utils.convert_grid_coordinates(
  File "/home/aryanjain/tapnet/torch/utils.py", line 224, in convert_grid_coordinates
    raise ValueError('converting frame count is not supported.')
ValueError: converting frame count is not supported.

Here is what the code for calling TAPIR looks like:

import math

import torch
import torchvision.transforms as T
from einops import rearrange

def preprocess_frames(frames):
  """Preprocess frames to model inputs.

  Args:
    frames: [num_frames, height, width, 3] torch tensor, values in [0, 255]

  Returns:
    frames: [num_frames, height, width, 3] torch.float32, values in [-1, 1]
  """
  frames = frames.float()
  frames = frames / 255 * 2 - 1
  return frames

def forward_tapir(tapir, video, queries, track_cfg):
    """
    video: b t c h w
    queries: b n 3
    """
    b, _, _, _, _ = video.shape

    video = rearrange(video, 'b t c h w -> (b t) c h w')
    video = T.Resize((256, 256), antialias=True)(video)

    # Preprocess video to match model input format
    video = rearrange(video, 'd c h w -> d h w c')
    video = preprocess_frames(video)
    video = rearrange(video, '(b t) h w c -> b t h w c', b=b)
    queries = queries.float()
    # Rescale only the spatial (y, x) coordinates from the 224 x 224 grid to 256 x 256;
    # the frame index t in (t, y, x) must stay unchanged.
    queries = torch.cat([queries[..., :1], queries[..., 1:] * (256.0 / 224.0)], dim=-1)
    print(video.shape, video.dtype, video.min(), video.max())
    print(queries.shape, queries.dtype, queries.min(), queries.max())

    tracks = []
    batch_size = track_cfg.tracker_batch_size * track_cfg.tracker_gpus
    for i in range(math.ceil(b / batch_size)):
        video_batch = video[i*batch_size:(i+1)*batch_size]
        queries_batch = queries[i*batch_size:(i+1)*batch_size]
        outputs = tapir(video_batch, queries_batch) # b t n 2
        tracks_batch = outputs['tracks']
        tracks.append(tracks_batch)
    tracks = torch.cat(tracks, dim=0)
    tracks = tracks.permute(0, 2, 1, 3)
    return tracks

@Aryan9101
Author

Aryan9101 commented Mar 11, 2024

@cdoersch sorry for the ping! Do you have any insight into what might be going wrong here? Let me know if there is any other information I can provide.

@justachetan

I think the issue is that TAPIR assumes the frame index t in each [t, y, x] query refers to the original video, while the new video is a shorter slice. You need to remap the t value of each query point to the new slice before passing it to the model.
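
Something along these lines is what I mean (a minimal sketch with made-up names; clip_starts and the 8-frame clip length are hypothetical):

import torch

# Hypothetical setup: queries is (B, N, 3) as (t, y, x) with t indexed in the original
# long video; clip_starts (B,) holds the first original frame of each 8-frame clip.
B, N = 4, 32
clip_starts = torch.tensor([0.0, 8.0, 16.0, 24.0])
queries = torch.cat(
    [clip_starts[:, None, None] + torch.randint(0, 8, (B, N, 1)),   # t in original frames
     torch.rand(B, N, 2) * 255],                                    # (y, x) in pixels
    dim=-1)

clip_queries = queries.clone()
clip_queries[..., 0] -= clip_starts[:, None]   # t is now relative to each clip, in [0, 7]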

@justachetan

But @cdoersch, I did a simpler experiment where I just pass 4 videos from Kubric (tensor dimensions [4, 256, 256, 3]) to the model as a batch, with the corresponding query points (tensor dimensions [4, 256, 3]), and I am still running into the same error.

Here is the full stack trace:

Traceback (most recent call last):
  File "/home/ac2538/work/tapnet.torch/train.py", line 41, in <module>
    trainer.train()
  File "/home/ac2538/work/tapnet.torch/trainers/base_trainer.py", line 91, in train
    self.run_epoch(epoch_idx, phase='train')
  File "/home/ac2538/work/tapnet.torch/trainers/base_trainer.py", line 159, in run_epoch
    outputs, losses = self.process_batch(all_inputs)
  File "/home/ac2538/work/tapnet.torch/trainers/base_trainer.py", line 224, in process_batch
    outputs[dset_name] = self.model(inputs[dset_name])
  File "/home/ac2538/.conda/envs/tapnet2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ac2538/.conda/envs/tapnet2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ac2538/work/tapnet.torch/models/network_module.py", line 42, in forward
    return self.model(batch.video, batch.query_points)
  File "/home/ac2538/.conda/envs/tapnet2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ac2538/.conda/envs/tapnet2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ac2538/work/tapnet.torch/models/tapnet_torch/tapir_model.py", line 179, in forward
    query_features = self.get_query_features(
  File "/home/ac2538/work/tapnet.torch/models/tapnet_torch/tapir_model.py", line 261, in get_query_features
    position_in_grid = utils.convert_grid_coordinates(
  File "/home/ac2538/work/tapnet.torch/models/tapnet_torch/utils.py", line 209, in convert_grid_coordinates
    if isinstance(input_grid_size, tuple):
ValueError: converting frame count is not supported.

Specifically,

File "/home/ac2538/work/tapnet.torch/models/network_module.py", line 42, in forward
    return self.model(batch.video, batch.query_points)

is where I am making a forward pass with a batch.

I think the issue is here:

video_resize = video_resize.view(n*f, h, w, c).permute(0, 3, 1, 2)

You flatten the video batch along the first two dimensions but never unflatten it. Hence, if the batch size is 1 it works, but with a batch size > 1 the model thinks it has received a single video with batch_size * T frames (where T is the number of frames per video).
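
Here is a standalone sketch of that shape problem, with a dummy feature extractor standing in for the ResNet (names are hypothetical, not the library code):

import torch

# Hypothetical shapes: a batch of n clips with f frames each.
n, f, h, w, c = 4, 8, 256, 256, 3
video = torch.rand(n, f, h, w, c)

flat = video.view(n * f, h, w, c).permute(0, 3, 1, 2)   # (n*f, c, h, w), as fed to the CNN
feats = flat.mean(dim=1, keepdim=True)                  # stand-in for the ResNet features

# Without this step, downstream code sees a single "video" with n*f frames.
feats = feats.permute(0, 2, 3, 1)                       # back to channels-last
feats = torch.unflatten(feats, dim=0, sizes=(n, f))     # restore the (n, f, ...) layout
print(feats.shape)                                      # torch.Size([4, 8, 256, 256, 1])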

@justachetan

@cdoersch, just bumping this again: could you please confirm whether this is an actual bug in the PyTorch version of the code that prevents batched forward passes? Thanks!

@Aryan9101
Author

@cdoersch bumping the bump above!

@justachetan

justachetan commented Sep 2, 2024

@cdoersch here is a corrected version of get_feature_grids. Could you please take a look and confirm whether it is correct?

def get_feature_grids(
      self,
      video: torch.Tensor,
      is_training: bool,
      refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
  ) -> FeatureGrids:
    """Computes feature grids.

    Args:
      video: A 5-D tensor representing a batch of sequences of images.
      is_training: Whether we are training.
      refinement_resolutions: A list of (height, width) tuples. Refinement will
        be repeated at each specified resolution, to achieve high accuracy on
        resolutions higher than what TAPIR was trained on. If None, reasonable
        refinement resolutions will be inferred from the input video size.

    Returns:
      A FeatureGrids object containing the required features for every
      required resolution. Note that there will be one more feature grid
      than there are refinement_resolutions, because there is always a
      feature grid computed for TAP-Net initialization.
    """
    del is_training
    if refinement_resolutions is None:
      refinement_resolutions = utils.generate_default_resolutions(
          video.shape[2:4], self.initial_resolution
      )

    all_required_resolutions = [self.initial_resolution]
    all_required_resolutions.extend(refinement_resolutions)

    feature_grid = []
    hires_feats = []
    resize_im_shape = []
    curr_resolution = (-1, -1)

    latent = None
    hires = None
    video_resize = None
    for resolution in all_required_resolutions:

      if resolution[0] % 8 != 0 or resolution[1] % 8 != 0:
        raise ValueError('Image resolution must be a multiple of 8.')

      if not utils.is_same_res(curr_resolution, resolution):
        if utils.is_same_res(curr_resolution, video.shape[-3:-1]):
          video_resize = video
        else:
          video_resize = utils.bilinear(video, resolution)

        curr_resolution = resolution
        n, f, h, w, c = video_resize.shape
        video_resize = video_resize.view(n*f, h, w, c).permute(0, 3, 1, 2)

        if self.feature_extractor_chunk_size > 0:
          latent_list = []
          hires_list = []
          chunk_size = self.feature_extractor_chunk_size
          for start_idx in range(0, video_resize.shape[0], chunk_size):
            video_chunk = video_resize[start_idx:start_idx + chunk_size]
            resnet_out = self.resnet_torch(video_chunk)
            u3 = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1)
            latent_list.append(u3)
            u1 = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1)
            hires_list.append(u1)
          latent = torch.cat(latent_list, dim=0)
          hires = torch.cat(hires_list, dim=0)

        else:
          resnet_out = self.resnet_torch(video_resize)
          latent = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1)
          hires = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1)
        
        if self.extra_convs:
          latent = self.extra_convs(latent)

        latent = latent / torch.sqrt(
            torch.maximum(
                torch.sum(torch.square(latent), axis=-1, keepdims=True),
                torch.tensor(1e-12, device=latent.device),
            )
        )
        hires = hires / torch.sqrt(
            torch.maximum(
                torch.sum(torch.square(hires), axis=-1, keepdims=True),
                torch.tensor(1e-12, device=hires.device),
            )
        )
      
      # import ipdb; ipdb.set_trace()
      if latent is not None and latent.dim() == 4:
        latent = torch.unflatten(latent, dim=0, sizes=(n, f))
      if hires is not None and hires.dim() == 4:
        hires = torch.unflatten(hires, dim=0, sizes=(n, f))
      
      feature_grid.append(latent)
      hires_feats.append(hires)
      
      # NOTE (ac): old code from Deepmind
      # feature_grid.append(latent[None, ...])
      # hires_feats.append(hires[None, ...])
      resize_im_shape.append(video_resize.shape[2:4])

    return FeatureGrids(
        tuple(feature_grid), tuple(hires_feats), tuple(resize_im_shape)
    )
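
As a rough sanity check of the intended behavior (a sketch only; it assumes a torch TAPIR instance named model has already been constructed, which is not shown here), the leading dimensions of the returned feature grids should now be (batch, frames) rather than (batch * frames,):

import torch

video = torch.rand(4, 8, 256, 256, 3) * 2 - 1   # (n, f, h, w, c), values in [-1, 1]
grids = model.get_feature_grids(video, is_training=False)
# First low-resolution feature grid: leading dims should be (4, 8, ...) instead of (32, ...).
print(grids[0][0].shape)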

@yangyi02
Collaborator

yangyi02 commented Sep 3, 2024

Hi all,

I tried to reproduce your error according to the following description:
"I have a long video of length, say, (100, 224, 224, 3). I reshape it to (100, 256, 256, 3) and randomly sample a bunch of 8-frame clips from it of shape (8, 256, 256, 3) — suppose I sample, say, 16 clips per video. As for my query points, I pick N points that are spaced around a grid and use the same set of N query points for each video. Therefore, my batched input to the model is a video of dimension (16, 8, 256, 256, 3) and queries of dimension (16, N, 3). The video pixels are in the range (0, 255) and normalized to (-1, 1) and the query points are in the range (0, 255) (since the height and width of each image is 256)."

but I could not reproduce it.

Here is what I adjusted in https://colab.sandbox.google.com/github/deepmind/tapnet/blob/master/colabs/torch_tapir_demo.ipynb:

frames = media.resize_video(video, (resize_height, resize_width))
query_points = sample_random_points(0, frames.shape[1], frames.shape[2], num_points)
print('video shape before clipping:', frames.shape)
print('query_points shape before clipping:', query_points.shape)
frames = torch.tensor(frames).to(device)
query_points = torch.tensor(query_points).to(device)

frames = frames.reshape(5, 10, 256, 256, 3)
query_points = query_points[None].repeat(5, 1, 1)
print('video shape after clipping:', frames.shape)
print('query_points shape after clipping:', query_points.shape)

# Preprocess video to match model inputs format
frames = preprocess_frames(frames)
query_points = query_points.float()

# Model inference
outputs = model(frames, query_points)
print('inference done')

The output looks reasonable to me.

video shape before clipping: (50, 256, 256, 3)
query_points shape before clipping: (50, 3)
video shape after clipping: torch.Size([5, 10, 256, 256, 3])
query_points shape after clipping: torch.Size([5, 50, 3])
inference done
