diff --git a/README.md b/README.md
index a04b2fa..cedd53d 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,11 @@
+**This is a clone of the [TAPIR repository](https://github.com/google-deepmind/tapnet) that makes the standard TAPIR model compatible with TorchScript; see [PR#85](https://github.com/google-deepmind/tapnet/pull/85). Only _tapir_model.py_, _nets.py_, and _utils.py_ in the _torch_ directory are updated.**
+
+**It is not yet aligned with version 2 of the model, [bootstapir_checkpoint_v2.pt](https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.pt), only with the original version, [bootstapir_checkpoint.pt](https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint.pt).**
+
+**Online TAPIR is not yet supported.**
+
+---
+
 # Tracking Any Point (TAP)
 [[`TAP-Vid`](https://tapvid.github.io/)] [[`TAPIR`](https://deepmind-tapir.github.io/)] [[`RoboTAP`](https://robotap.github.io/)] [[`Blog Post`](https://deepmind-tapir.github.io/blogpost.html)] [[`BootsTAP`](https://arxiv.org/abs/2402.00847)]
 
diff --git a/torch/nets.py b/torch/nets.py
index 4650982..71b7c8d 100644
--- a/torch/nets.py
+++ b/torch/nets.py
@@ -57,7 +57,7 @@ def forward(self, x):
     x = x.permute(0, 3, 1, 2)
     prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
     next_frame = torch.cat([x[1:], x[-1:]], dim=0)
-    resid = torch.cat([x, prev_frame, next_frame], axis=1)
+    resid = torch.cat([x, prev_frame, next_frame], dim=1)
     resid = self.conv(resid)
     resid = F.gelu(resid, approximate='tanh')
     x += self.conv_1(resid)
@@ -198,10 +198,20 @@ def forward(self, x):
     x = self.linear_1(x)
     return x
 
+class DummyModel:
+
+  def __init__(self):
+    pass
+
+  def forward(self):
+    return torch.tensor(0)
+
+  def __call__(self, input):
+    return self.forward()
 
 class BlockV2(nn.Module):
   """ResNet V2 block."""
-
+
   def __init__(
       self,
       channels_in: int,
@@ -223,14 +233,16 @@ def __init__(
 
     self.use_projection = use_projection
     if self.use_projection:
-      self.proj_conv = nn.Conv2d(
+        self.proj_conv = nn.Conv2d(
          in_channels=channels_in,
          out_channels=channels_out,
          kernel_size=1,
          stride=stride,
          padding=0,
          bias=False,
-      )
+        )
+    else:
+        self.proj_conv = DummyModel()
 
     self.bn_0 = nn.InstanceNorm2d(
        num_features=channels_in,
diff --git a/torch/tapir_model.py b/torch/tapir_model.py
index 327f264..c972500 100644
--- a/torch/tapir_model.py
+++ b/torch/tapir_model.py
@@ -39,9 +39,10 @@ class FeatureGrids(NamedTuple):
       resolution.
   """
 
-  lowres: Sequence[torch.Tensor]
-  hires: Sequence[torch.Tensor]
-  resolutions: Sequence[Tuple[int, int]]
+  # see https://pytorch.org/docs/stable/jit_language_reference.html#supported-type for TorchScript-supported types; Sequence is not supported
+  lowres: list[torch.Tensor]
+  hires: list[torch.Tensor]
+  resolutions: list[Tuple[int, int]]
 
 
 class QueryFeatures(NamedTuple):
@@ -61,10 +62,14 @@ class QueryFeatures(NamedTuple):
      resolution.
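> Note: the `DummyModel` placeholder and the `else: self.proj_conv = DummyModel()` branch above exist because TorchScript is happier when an attribute is always a callable rather than sometimes `None`. Below is a minimal sketch of the same idea, not code from this repo: the hypothetical `MaybeProjection` module uses `nn.Identity` as the no-op stand-in, playing the role that `DummyModel` plays in the diff.

```python
import torch
import torch.nn as nn


class MaybeProjection(nn.Module):
  """Toy block: `self.proj` is always a Module, never None, so scripting
  sees one stable attribute type in both branches."""

  def __init__(self, channels_in: int, channels_out: int, use_projection: bool):
    super().__init__()
    self.use_projection = use_projection
    if use_projection:
      self.proj = nn.Conv2d(channels_in, channels_out, kernel_size=1, bias=False)
    else:
      # No-op placeholder instead of None; the diff's DummyModel serves this purpose.
      self.proj = nn.Identity()

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    shortcut = x
    if self.use_projection:
      shortcut = self.proj(shortcut)
    return shortcut


scripted = torch.jit.script(MaybeProjection(8, 16, use_projection=False))
print(scripted(torch.randn(1, 8, 4, 4)).shape)  # torch.Size([1, 8, 4, 4])
```

> The key point is that the attribute has a single stable type in both branches, so the scripted call site type-checks even when the branch is never taken at runtime.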
""" - lowres: Sequence[torch.Tensor] - hires: Sequence[torch.Tensor] - resolutions: Sequence[Tuple[int, int]] + lowres: list[torch.Tensor] # not Sequence[torch.Tensor] + hires: list[torch.Tensor] # not Sequence[torch.Tensor] + resolutions: list[Tuple[int, int]] # not Sequence[Tuple[int, int]] +class OutputAll(NamedTuple): + occlusion: list[torch.Tensor] + tracks: list[torch.Tensor] + expected_dist: list[torch.Tensor] class TAPIR(nn.Module): """TAPIR model.""" @@ -83,7 +88,7 @@ def __init__( initial_resolution: Tuple[int, int] = (256, 256), blocks_per_group: Sequence[int] = (2, 2, 2, 2), feature_extractor_chunk_size: int = 10, - extra_convs: bool = True, + extra_convs_b: bool = True, ): super().__init__() @@ -125,20 +130,22 @@ def __init__( input_dim = dim + (self.pyramid_level + 2) * 49 self.torch_pips_mixer = nets.PIPSMLPMixer(input_dim, dim) - if extra_convs: + self.extra_convs_b = extra_convs_b + if extra_convs_b: self.extra_convs = nets.ExtraConvs() else: - self.extra_convs = None + self.extra_convs = nets.DummyModel() def forward( self, video: torch.Tensor, query_points: torch.Tensor, is_training: bool = False, - query_chunk_size: Optional[int] = 64, + query_chunk_size: int = 64, get_query_feats: bool = False, refinement_resolutions: Optional[List[Tuple[int, int]]] = None, - ) -> Mapping[str, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # -> Mapping[str, torch.Tensor]: - not friendly with torch.jit.script(), fails with Unknown type constructor Mapping """Runs a forward pass of the model. Args: @@ -165,6 +172,12 @@ def forward( expected_dist: uncertainty estimate logits, of shape [batch, num_queries, num_frames], where higher indicates more likely to be far from the correct answer. + + Usage: + outputs = model(frames, query_points) + occlusions = outputs[0][0] + tracks = outputs[1][0] + expected_dist = outputs[2][0] """ if get_query_feats: raise ValueError('Get query feats not supported in TAPIR.') @@ -193,18 +206,12 @@ def forward( ) p = self.num_pips_iter - out = dict( - occlusion=torch.mean( - torch.stack(trajectories['occlusion'][p::p]), dim=0 - ), - tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0), - expected_dist=torch.mean( - torch.stack(trajectories['expected_dist'][p::p]), dim=0 - ), - unrefined_occlusion=trajectories['occlusion'][:-1], - unrefined_tracks=trajectories['tracks'][:-1], - unrefined_expected_dist=trajectories['expected_dist'][:-1], - ) + + # change the code below to make it torch.jit.trace/torch.jit.script friendly + out = (torch.mean(torch.stack(trajectories.occlusion[p::p]), dim=0), + torch.mean(torch.stack(trajectories.tracks[p::p]), dim=0), + torch.mean(torch.stack(trajectories.expected_dist[p::p]), dim=0) + ) return out @@ -215,7 +222,7 @@ def get_query_features( query_points: torch.Tensor, feature_grids: Optional[FeatureGrids] = None, refinement_resolutions: Optional[List[Tuple[int, int]]] = None, - ) -> QueryFeatures: + ) -> QueryFeatures: """Computes query features, which can be used for estimate_trajectories. 
Args: @@ -259,14 +266,14 @@ def get_query_features( continue position_in_grid = utils.convert_grid_coordinates( query_points, - shape[1:4], - feature_grid[i].shape[1:4], + torch.tensor(shape[1:4], device=query_points.device), + torch.tensor(feature_grid[i].shape[1:4], device=query_points.device), coordinate_format='tyx', ) position_in_grid_hires = utils.convert_grid_coordinates( query_points, - shape[1:4], - hires_feats[i].shape[1:4], + torch.tensor(shape[1:4], device=query_points.device), + torch.tensor(hires_feats[i].shape[1:4], device=query_points.device), coordinate_format='tyx', ) @@ -281,7 +288,7 @@ def get_query_features( query_feats.append(interp_features) return QueryFeatures( - tuple(query_feats), tuple(hires_query_feats), tuple(resize_im_shape) + query_feats, hires_query_feats, resize_im_shape ) def get_feature_grids( @@ -308,27 +315,26 @@ def get_feature_grids( """ del is_training if refinement_resolutions is None: - refinement_resolutions = utils.generate_default_resolutions( - video.shape[2:4], self.initial_resolution - ) + refinement_resolutions = utils.generate_default_resolutions((video.shape[2], video.shape[3]), self.initial_resolution) all_required_resolutions = [self.initial_resolution] all_required_resolutions.extend(refinement_resolutions) feature_grid = [] hires_feats = [] - resize_im_shape = [] + resize_im_shape = [(int(0), int(0))] # in the torch.jit.script() context, annotated assignments without assigned value aren't supported + resize_im_shape.clear() curr_resolution = (-1, -1) - latent = None - hires = None - video_resize = None + latent = torch.empty((0)) + hires = torch.empty((0)) + video_resize = torch.empty((0)) for resolution in all_required_resolutions: if resolution[0] % 8 != 0 or resolution[1] % 8 != 0: raise ValueError('Image resolution must be a multiple of 8.') if not utils.is_same_res(curr_resolution, resolution): - if utils.is_same_res(curr_resolution, video.shape[-3:-1]): + if utils.is_same_res(curr_resolution, (video.shape[-3], video.shape[-2])): video_resize = video else: video_resize = utils.bilinear(video, resolution) @@ -358,39 +364,55 @@ def get_feature_grids( latent = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1).detach() hires = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1).detach() - if self.extra_convs: + if self.extra_convs_b: latent = self.extra_convs(latent) + s1 = torch.square(latent) + tmp1 = torch.sum(s1, dim=-1) #, keepdims=True) # https://github.com/pytorch/pytorch/issues/47955 - keepdims is no longer supported by JIT + tmp1 = torch.unsqueeze(tmp1, -1) + + s2 = torch.square(hires) + tmp2 = torch.sum(s2, dim=-1) # , keepdims=True) + tmp2 = torch.unsqueeze(tmp2, -1) + latent = latent / torch.sqrt( torch.maximum( - torch.sum(torch.square(latent), axis=-1, keepdims=True), + tmp1, torch.tensor(1e-12, device=latent.device), ) ) hires = hires / torch.sqrt( torch.maximum( - torch.sum(torch.square(hires), axis=-1, keepdims=True), + tmp2, torch.tensor(1e-12, device=hires.device), ) ) feature_grid.append(latent[None, ...]) hires_feats.append(hires[None, ...]) - resize_im_shape.append(video_resize.shape[2:4]) + resize_im_shape.append((video_resize.shape[2], video_resize.shape[3])) return FeatureGrids( - tuple(feature_grid), tuple(hires_feats), tuple(resize_im_shape) + feature_grid, hires_feats, resize_im_shape + ) + + def train2orig(self, x, video_size: list[int])-> torch.Tensor: + return utils.convert_grid_coordinates( + x, + torch.tensor(self.initial_resolution[::-1], device=x.device),#self.initial_resolution[::-1], + 
torch.tensor(video_size[::-1], device=x.device),#video_size[::-1], + coordinate_format='xy', ) def estimate_trajectories( self, - video_size: Tuple[int, int], + video_size: list[int], is_training: bool, feature_grids: FeatureGrids, query_features: QueryFeatures, query_points_in_video: Optional[torch.Tensor], - query_chunk_size: Optional[int] = None, - ) -> Mapping[str, Any]: + query_chunk_size: int = 64, + ) -> OutputAll: """Estimates trajectories given features for a video and query features. Args: @@ -418,30 +440,16 @@ def estimate_trajectories( """ del is_training - def train2orig(x): - return utils.convert_grid_coordinates( - x, - self.initial_resolution[::-1], - video_size[::-1], - coordinate_format='xy', - ) - - occ_iters = [] - pts_iters = [] - expd_iters = [] + occ_iters = [[torch.empty((0, 0), dtype=torch.float32)]] #[[torch.tensor([])]] + pts_iters = [[torch.empty((0, 0), dtype=torch.float32)]] #[[torch.tensor([])]] + expd_iters = [[torch.empty((0, 0), dtype=torch.float32)]] #[[torch.tensor([])]] + occ_iters.clear(); pts_iters.clear(); expd_iters.clear() num_iters = self.num_pips_iter * (len(feature_grids.lowres) - 1) for _ in range(num_iters + 1): occ_iters.append([]) pts_iters.append([]) expd_iters.append([]) - infer = functools.partial( - self.tracks_from_cost_volume, - im_shp=feature_grids.lowres[0].shape[0:2] - + self.initial_resolution - + (3,), - ) - num_queries = query_features.lowres[0].shape[1] perm = torch.randperm(num_queries) inv_perm = torch.zeros_like(perm) @@ -458,23 +466,26 @@ def train2orig(x): num_frames = feature_grids.lowres[0].shape[1] infer_query_points = utils.convert_grid_coordinates( infer_query_points, - (num_frames,) + video_size, - (num_frames,) + self.initial_resolution, + torch.tensor((num_frames,) + video_size, device=infer_query_points.device), + torch.tensor((num_frames,) + self.initial_resolution, device=infer_query_points.device), coordinate_format='tyx', ) else: infer_query_points = None - points, occlusion, expected_dist = infer( + points, occlusion, expected_dist = self.tracks_from_cost_volume( chunk, feature_grids.lowres[0], infer_query_points, + im_shp=list(feature_grids.lowres[0].shape[0:2] + self.initial_resolution + (3,)) ) - pts_iters[0].append(train2orig(points)) + + pts_iters[0].append(self.train2orig(points, video_size)) occ_iters[0].append(occlusion) expd_iters[0].append(expected_dist) - mixer_feats = None + mixer_feats_none = True + mixer_feats = torch.empty((0, 0), dtype=torch.float32) for i in range(num_iters): feature_level = i // self.num_pips_iter + 1 queries = [ @@ -504,19 +515,22 @@ def train2orig(x): points, occlusion, expected_dist, - orig_hw=self.initial_resolution, - last_iter=mixer_feats, + orig_hw=list(self.initial_resolution), + last_iter = None if mixer_feats_none else mixer_feats, mixer_iter=i, - resize_hw=feature_grids.resolutions[feature_level], + resize_hw=list(feature_grids.resolutions[feature_level]), ) points, occlusion, expected_dist, mixer_feats = refined - pts_iters[i + 1].append(train2orig(points)) + pts_iters[i + 1].append(self.train2orig(points, video_size)) occ_iters[i + 1].append(occlusion) expd_iters[i + 1].append(expected_dist) if (i + 1) % self.num_pips_iter == 0: - mixer_feats = None + mixer_feats_none = True + mixer_feats = torch.empty((0, 0), dtype=torch.float32) expected_dist = expd_iters[0][-1] occlusion = occ_iters[0][-1] + else: + mixer_feats_none = False occlusion = [] points = [] @@ -526,26 +540,21 @@ def train2orig(x): points.append(torch.cat(pts_iters[i], dim=1)[:, inv_perm]) 
expd.append(torch.cat(expd_iters[i], dim=1)[:, inv_perm]) - out = dict( - occlusion=occlusion, - tracks=points, - expected_dist=expd, - ) - return out + return OutputAll(occlusion, points, expd) def refine_pips( self, - target_feature, - frame_features, - pyramid, - pos_guess, - occ_guess, - expd_guess, - orig_hw, - last_iter=None, - mixer_iter=0.0, - resize_hw=None, - ): + target_feature : list[torch.Tensor], + frame_features : Optional[torch.Tensor], + pyramid : list[torch.Tensor], + pos_guess : torch.Tensor, + occ_guess : torch.Tensor, + expd_guess : torch.Tensor, + orig_hw : list[int], + last_iter : Optional[torch.Tensor], + mixer_iter : int, + resize_hw : list[int], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: del frame_features del mixer_iter orig_h, orig_w = orig_hw @@ -555,10 +564,12 @@ def refine_pips( for pyridx, (query, grid) in enumerate(zip(target_feature, pyramid)): # note: interp needs [y,x] coords = utils.convert_grid_coordinates( - pos_guess, (orig_w, orig_h), grid.shape[-2:-4:-1] + pos_guess, + torch.tensor((orig_w, orig_h), device=pos_guess.device), + torch.tensor(grid.shape[-2:-4:-1], device=pos_guess.device), ) coords = torch.flip(coords, dims=(-1,)) - last_iter_query = None + last_iter_query = torch.empty((0, 0), dtype=torch.float32) if last_iter is not None: if pyridx == 0: last_iter_query = last_iter[..., : self.highres_dim] @@ -574,7 +585,7 @@ def refine_pips( neighborhood = utils.map_coordinates_2d(grid, coords2) # s is spatial context size - if last_iter_query is None: + if last_iter is None: patches = torch.einsum('bnfsc,bnc->bnfs', neighborhood, query) else: patches = torch.einsum( @@ -591,7 +602,7 @@ def refine_pips( # mlp_input is batch, num_points, num_chunks, frames_per_chunk, channels if last_iter is None: - both_feature = torch.cat([target_feature[0], target_feature[1]], axis=-1) + both_feature = torch.cat([target_feature[0], target_feature[1]], dim=-1) mlp_input_features = torch.tile( both_feature.unsqueeze(2), (1, 1, corrs_chunked.shape[-2], 1) ) @@ -608,16 +619,20 @@ def refine_pips( mlp_input_features, corrs_chunked, ], - axis=-1, + dim=-1, ) - x = utils.einshape('bnfc->(bn)fc', mlp_input) + + #x = utils.einshape('bnfc->(bn)fc', mlp_input) + x = torch.reshape(mlp_input, (mlp_input.shape[0] * mlp_input.shape[1], mlp_input.shape[2], mlp_input.shape[3])) res = self.torch_pips_mixer(x.float()) - res = utils.einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0]) + #res = utils.einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0]) + res = torch.reshape(res, (mlp_input.shape[0], int(res.shape[0] / mlp_input.shape[0]), res.shape[1], res.shape[2])) + t = res[..., :2].detach() pos_update = utils.convert_grid_coordinates( - res[..., :2].detach(), - (resized_w, resized_h), - (orig_w, orig_h), + t, + torch.tensor((resized_w, resized_h), device=t.device), + torch.tensor((orig_w, orig_h), device=t.device), ) return ( pos_update + pos_guess, @@ -631,7 +646,7 @@ def tracks_from_cost_volume( interp_feature: torch.Tensor, feature_grid: torch.Tensor, query_points: Optional[torch.Tensor], - im_shp=None, + im_shp: List[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Converts features into tracks by computing a cost volume. 
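> Note: the `last_iter: Optional[torch.Tensor]` annotation above follows TorchScript's rule that an `Optional` value must be narrowed by an explicit `is None` check before it can be used as a `Tensor` (which is also why the caller tracks `mixer_feats_none` separately). A standalone sketch of the pattern, using a hypothetical function rather than repo code:

```python
import torch
from typing import Optional


@torch.jit.script
def add_last_iter(x: torch.Tensor, last_iter: Optional[torch.Tensor]) -> torch.Tensor:
  # TorchScript refines Optional[Tensor] to Tensor only after an explicit `is None` check.
  if last_iter is None:
    return x
  return x + last_iter


print(add_last_iter(torch.ones(2), None))           # tensor([1., 1.])
print(add_last_iter(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])
```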
@@ -667,7 +682,9 @@ def tracks_from_cost_volume( shape = cost_volume.shape batch_size, num_points = cost_volume.shape[1:3] - cost_volume = utils.einshape('tbnhw->(tbn)hw1', cost_volume) + #cost_volume = utils.einshape('tbnhw->(tbn)hw1', cost_volume) + cost_volume = torch.reshape(cost_volume, (shape[0] * shape[1] * shape[2], shape[3], shape[4])) + cost_volume = torch.unsqueeze(cost_volume, -1) cost_volume = cost_volume.permute(0, 3, 1, 2) occlusion = mods['hid1'](cost_volume) @@ -675,11 +692,14 @@ def tracks_from_cost_volume( pos = mods['hid2'](occlusion) pos = pos.permute(0, 2, 3, 1) - pos_rshp = utils.einshape('(tb)hw1->t(b)hw1', pos, t=shape[0]) - - pos = utils.einshape( - 't(bn)hw1->bnthw', pos_rshp, b=batch_size, n=num_points - ) + #pos_rshp = utils.einshape('(tb)hw1->t(b)hw1', pos, t=shape[0]) + pos_rshp = torch.reshape(pos, (shape[0], int(pos.shape[0] / shape[0]), pos.shape[1], pos.shape[2], 1)) + + #pos = utils.einshape( + # 't(bn)hw1->bnthw', pos_rshp, b=batch_size, n=num_points + #) + pos = pos_rshp.squeeze(-1).permute(1, 0, 2, 3) + pos = torch.reshape(pos, (batch_size, int(pos.shape[0]/batch_size), pos.shape[1], pos.shape[2], pos.shape[3])) pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1) softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1) pos = softmaxed.view_as(pos) @@ -694,10 +714,17 @@ def tracks_from_cost_volume( occlusion = torch.nn.functional.relu(occlusion) occlusion = mods['occ_out'](occlusion) - expected_dist = utils.einshape( - '(tbn)1->bnt', occlusion[..., 1:2], n=shape[2], t=shape[0] - ) - occlusion = utils.einshape( - '(tbn)1->bnt', occlusion[..., 0:1], n=shape[2], t=shape[0] - ) + #expected_dist = utils.einshape( + # '(tbn)1->bnt', occlusion[..., 1:2], n=shape[2], t=shape[0] + #) + occlusion1 = occlusion[..., 1:2].squeeze(-1) + expected_dist = torch.reshape(occlusion1, (shape[0], int(torch.numel(occlusion1) / (shape[2]*shape[0])), shape[2])) + expected_dist = expected_dist.permute(1, 2, 0) + + #occlusion = utils.einshape( + # '(tbn)1->bnt', occlusion[..., 0:1], n=shape[2], t=shape[0] + #) + occlusion0 = occlusion[..., 0:1].squeeze(-1) + occlusion = torch.reshape(occlusion0, (shape[0], int(torch.numel(occlusion0) / (shape[2]*shape[0])), shape[2])) + occlusion = occlusion.permute(1, 2, 0) return points, occlusion, expected_dist diff --git a/torch/utils.py b/torch/utils.py index f518a33..41809ed 100644 --- a/torch/utils.py +++ b/torch/utils.py @@ -15,13 +15,13 @@ """Pytorch model utilities.""" -from typing import Any, Sequence, Union -from einshape.src import abstract_ops -from einshape.src import backend +from typing import Any, Union, Optional, List +#from einshape.src import abstract_ops +#from einshape.src import backend import numpy as np import torch import torch.nn.functional as F - +import math def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor: """Resizes a 5D tensor using bilinear interpolation. 
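> Note: since the `einshape` calls are replaced above by explicit `reshape`/`permute`/`squeeze` operations, a quick layout check helps confirm the rewrite preserves the intended index order. The sketch below (plain PyTorch, not repo code) verifies the `'(tbn)1->bnt'` mapping used for the occlusion and expected-distance outputs:

```python
import torch

t, b, n = 4, 2, 3
vals = torch.arange(t * b * n, dtype=torch.float32).reshape(t, b, n)
flat = vals.reshape(t * b * n, 1)  # the flattened '(tbn)1' layout produced upstream

# The rewrite: squeeze the trailing 1, reshape back to (t, b, n), then move t last.
out = torch.reshape(flat.squeeze(-1), (t, b, n)).permute(1, 2, 0)

ok = all(
    out[bi, ni, ti] == vals[ti, bi, ni]
    for ti in range(t)
    for bi in range(b)
    for ni in range(n)
)
print(ok)  # True
```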
@@ -113,7 +113,7 @@ def map_coordinates_2d( return out -def soft_argmax_heatmap_batched(softmax_val, threshold=5): +def soft_argmax_heatmap_batched(softmax_val, threshold:float=5): """Test if two image resolutions are the same.""" b, h, w, d1, d2 = softmax_val.shape y, x = torch.meshgrid( @@ -126,16 +126,10 @@ def soft_argmax_heatmap_batched(softmax_val, threshold=5): argmax_pos = torch.argmax(softmax_val_flat, dim=-1) pos = coords.reshape(-1, 2)[argmax_pos] - valid = ( - torch.sum( - torch.square( - coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :] - ), - dim=-1, - keepdims=True, - ) - < threshold**2 - ) + + tmp1 = torch.square(coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :]) + tmp2 = torch.unsqueeze(torch.sum(tmp1, dim=-1), -1) + valid = (tmp2 < threshold**2) weighted_sum = torch.sum( coords[None, None, None, :, :, :] @@ -152,9 +146,9 @@ def soft_argmax_heatmap_batched(softmax_val, threshold=5): def heatmaps_to_points( all_pairs_softmax, - image_shape, - threshold=5, - query_points=None, + image_shape:List[int], + threshold:float=5, + query_points:Optional[torch.Tensor]=None, ): """Convert heatmaps to points using soft argmax.""" @@ -162,18 +156,20 @@ def heatmaps_to_points( feature_grid_shape = all_pairs_softmax.shape[1:] # Note: out_points is now [x, y]; we need to divide by [width, height]. # image_shape[3] is width and image_shape[2] is height. + t1 = out_points.detach() out_points = convert_grid_coordinates( - out_points.detach(), - feature_grid_shape[3:1:-1], - image_shape[3:1:-1], + t1, + torch.tensor(feature_grid_shape[3:1:-1], device=t1.device), + torch.tensor(image_shape[3:1:-1], device=t1.device), ) assert feature_grid_shape[1] == image_shape[1] if query_points is not None: # The [..., 0:1] is because we only care about the frame index. 
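> Note: the `torch.unsqueeze(torch.sum(...), -1)` rewrite above sidesteps the `keepdims` alias, which TorchScript does not accept (per the PyTorch issue referenced in the diff); the documented `keepdim=True` spelling is numerically equivalent. A quick standalone check:

```python
import torch

x = torch.randn(2, 3, 4, 5, 2)
a = torch.sum(torch.square(x), dim=-1, keepdim=True)         # documented kwarg spelling
b = torch.unsqueeze(torch.sum(torch.square(x), dim=-1), -1)  # rewrite used in the diff
print(torch.equal(a, b))  # True
```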
+ t2 = query_points.detach() query_frame = convert_grid_coordinates( - query_points.detach(), - image_shape[1:4], - feature_grid_shape[1:4], + t2, + torch.tensor(image_shape[1:4], device=t2.device), + torch.tensor(feature_grid_shape[1:4], device=t2.device), coordinate_format='tyx', )[..., 0:1] @@ -193,15 +189,16 @@ def heatmaps_to_points( return out_points -def is_same_res(r1, r2): +def is_same_res(r1 : tuple[int,int], r2 : tuple[int,int]): """Test if two image resolutions are the same.""" - return all([x == y for x, y in zip(r1, r2)]) + #return all([x == y for x, y in zip(r1, r2)]) + return (r1[0] == r2[0]) and (r1[1] == r2[1]) def convert_grid_coordinates( coords: torch.Tensor, - input_grid_size: Sequence[int], - output_grid_size: Sequence[int], + input_grid_size, # Sequence not supported by TorchScript, caller must convert torch.size or a tuple to torch.tensor + output_grid_size,# Sequence not supported by TorchScript, caller must convert torch.size or a tuple to torch.tensor coordinate_format: str = 'xy', ) -> torch.Tensor: """Convert grid coordinates to correct format.""" @@ -209,6 +206,10 @@ def convert_grid_coordinates( input_grid_size = torch.tensor(input_grid_size, device=coords.device) if isinstance(output_grid_size, tuple): output_grid_size = torch.tensor(output_grid_size, device=coords.device) + if not isinstance(input_grid_size, torch.Tensor): + raise ValueError(f"input_grid_size must be a torch.tensor, not {type(input_grid_size)}") + if not isinstance(output_grid_size, torch.Tensor): + raise ValueError(f"output_grid_size must be a torch.tensor, not {type(output_grid_size)}") if coordinate_format == 'xy': if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: @@ -230,49 +231,7 @@ def convert_grid_coordinates( return position_in_grid - -class _JaxBackend(backend.Backend[torch.Tensor]): - """Einshape implementation for PyTorch.""" - - # https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py - - def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor: - return x.reshape(op.shape) - - def transpose( - self, x: torch.Tensor, op: abstract_ops.Transpose - ) -> torch.Tensor: - return x.permute(op.perm) - - def broadcast( - self, x: torch.Tensor, op: abstract_ops.Broadcast - ) -> torch.Tensor: - shape = op.transform_shape(x.shape) - for axis_position in sorted(op.axis_sizes.keys()): - x = x.unsqueeze(axis_position) - return x.expand(shape) - - -def einshape( - equation: str, value: Union[torch.Tensor, Any], **index_sizes: int -) -> torch.Tensor: - """Reshapes `value` according to the given Shape Equation. - - Args: - equation: The Shape Equation specifying the index regrouping and reordering. - value: Input tensor, or tensor-like object. - **index_sizes: Sizes of indices, where they cannot be inferred from - `input_shape`. - - Returns: - Tensor derived from `value` by reshaping as specified by `equation`. - """ - if not isinstance(value, torch.Tensor): - value = torch.tensor(value) - return _JaxBackend().exec(equation, value, value.shape, **index_sizes) - - -def generate_default_resolutions(full_size, train_size, num_levels=None): +def generate_default_resolutions(full_size : tuple[int,int], train_size : tuple[int,int]): #, num_levels : int = None): """Generate a list of logarithmically-spaced resolutions. 
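> Note: with the new `convert_grid_coordinates` contract above, callers pass grid sizes as `torch.Tensor` (plain tuples are still converted internally). A usage sketch, assuming `torch/utils.py` is importable as `utils`:

```python
import torch
import utils  # torch/utils.py from this repo

coords = torch.tensor([[128.0, 64.0]])  # one (x, y) point in a 256x256 grid
src = torch.tensor([256, 256])          # input grid size
dst = torch.tensor([64, 64])            # output grid size

out = utils.convert_grid_coordinates(coords, src, dst, coordinate_format='xy')
print(out)  # coordinates rescaled by 64/256 -> tensor([[32., 16.]])
```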
Generated resolutions are between train_size and full_size, inclusive, with @@ -292,9 +251,13 @@ def generate_default_resolutions(full_size, train_size, num_levels=None): if all([x == y for x, y in zip(train_size, full_size)]): return [train_size] - if num_levels is None: - size_ratio = np.array(full_size) / np.array(train_size) - num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1) + ratio0 = full_size[0] / train_size[0] + ratio1 = full_size[1] / train_size[1] + if ratio1 > ratio0: + ratio0 = ratio1 + tr = torch.tensor([ratio0]) + tr = torch.ceil(torch.log2(tr)) + num_levels = int(tr.item() + 1) if num_levels <= 1: return [train_size] @@ -305,13 +268,16 @@ def generate_default_resolutions(full_size, train_size, num_levels=None): 'Warning: output size is not a multiple of 8. Final layer ' + 'will round size down.' ) - ll_h, ll_w = train_size[0:2] + ll_h = int(train_size[0]) + ll_w = int(train_size[1]) - sizes = [] + sizes = [(int(0), int(0))] # in the torch.jit.script() context, annotated assignments without assigned value aren't supported + sizes.clear() for i in range(num_levels): size = ( - int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8, - int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8, + int(round( float(ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8 )) * 8, + int(round( float(ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8 )) * 8, ) sizes.append(size) return sizes +
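> Note: the rewritten `generate_default_resolutions` derives the number of levels from `ceil(log2(max(full/train ratio))) + 1` and spaces the intermediate resolutions logarithmically, rounding down to multiples of 8. A quick check of the schedule it produces, assuming `torch/utils.py` is importable as `utils`:

```python
import utils  # torch/utils.py from this repo

levels = utils.generate_default_resolutions(full_size=(720, 1280), train_size=(256, 256))
print(levels)
# For these inputs the formula gives four levels, e.g.
# [(256, 256), (360, 432), (504, 744), (720, 1280)]
```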