❓ [Question] dynamo conversion failing w/ TRTInterpreter #3124
Comments
@patrick-botco Are you able to share a repro of this issue?
yea let me get one
@narendasan @apbose this is a stripped down portion of Meta's SAM2 (original at https://github.com/facebookresearch/segment-anything-2/blob/main/sam2/modeling/sam/prompt_encoder.py), with minor modifications (fill in `CHECKPOINT_PATH` below):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, Any

import torch
from torch import nn
import numpy as np


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
            embed_dim (int): The prompts' embedding dimension
            image_embedding_size (tuple(int, int)): The spatial size of the
                image embedding, as (H, W).
            input_image_size (int): The padded size of the image as input
                to the image encoder, as (H, W).
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [
            nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        point_embedding = torch.where(
            labels[:, :, None] == -1,
            self.not_a_point_embed.weight,
            point_embedding + torch.where(
                labels[:, :, None] == 0,
                self.point_embeddings[0].weight,
                torch.where(
                    labels[:, :, None] == 1,
                    self.point_embeddings[1].weight,
                    torch.where(
                        labels[:, :, None] == 2,
                        self.point_embeddings[2].weight,
                        self.point_embeddings[3].weight,
                    ),
                ),
            ),
        )
        return point_embedding

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
            points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
                and labels to embed.

        Returns:
            torch.Tensor: sparse embeddings for the points and boxes, with shape
                BxNx(embed_dim), where N is determined by the number of input points
                and boxes.
        """
        sparse_embeddings = torch.empty(
            (1, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=True)
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        return sparse_embeddings


class PositionalEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=256,
            image_embedding_size=(16, 16),
            input_image_size=(256, 256),
        )

    def forward(self, sam_point_coords: torch.Tensor, sam_point_labels: torch.Tensor) -> torch.Tensor:
        sparse_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=None,
        )
        return sparse_embeddings


CHECKPOINT_PATH = "/path/to/checkpoint.pt"

if __name__ == "__main__":
    model = PositionalEncoder().to("cuda")
    model.eval()
    sam_point_coords = torch.randn([1, 1, 2], dtype=torch.float32, device="cuda")
    sam_point_labels = torch.randn([1, 1], dtype=torch.float32, device="cuda")
    inputs = (sam_point_coords, sam_point_labels)
    # reference output from original model
    ref_out = model(*inputs)
    # export, serialize, deserialize
    ep = torch.export.export(model, inputs)
    torch.export.save(ep, CHECKPOINT_PATH)
    reloaded_model = torch.export.load(CHECKPOINT_PATH)
    # compare outputs
    trace_out = reloaded_model.module()(*inputs)
    assert torch.allclose(ref_out, trace_out)
```

load the ckpt on a Jetson (this may be unnecessary to repro), and attempt to build the engine:

```python
import torch
import torch_tensorrt

ep = torch.export.load("/path/to/checkpoint.pt")
example_inputs = ep.example_inputs[0]
model = ep.module().to("cuda")
# reference output from traced model
ref_out = model(*example_inputs)

optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=example_inputs,
    enabled_precisions={torch.float, torch.half},
    workspace_size=4 << 30,
    min_block_size=7,
    torch_executed_ops={},
)
opt_out = optimized_model(*example_inputs)
assert torch.allclose(ref_out, opt_out)
```
you should see this fx Graph dump as part of the stack trace
@narendasan @apbose does the above repro for you guys?
Reproed. Looking into this.
thanks @apbose, lmk if I can help in any way
Seems like the cat converter is receiving an empty tensor. In the above code, `sparse_embeddings` is initialized as `torch.empty((1, 0, self.embed_dim), device=self._get_device())`, which is an empty tensor since dim1 = 0. I do not see the above cat failing when I give a non-zero size for that dimension; instead it then fails because the torch reference, the traced output, and the Torch-TensorRT output do not match. @patrick-botco would you know why the traced and torch output won't match?
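For illustration only (not a fix adopted in this thread), here is a minimal sketch of how the repro's `PromptEncoder.forward` could avoid the zero-size concatenation entirely, assuming the converter failure really is triggered by the empty `sparse_embeddings` tensor:

```python
# Hypothetical drop-in replacement for PromptEncoder.forward from the repro
# above (a sketch, not the fix used in this issue): return the point
# embeddings directly so no torch.cat with a zero-size tensor reaches the
# converter; the empty tensor is only created when there are no points.
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> torch.Tensor:
    if points is not None:
        coords, labels = points
        return self._embed_points(coords, labels, pad=True)  # B x N x embed_dim
    return torch.empty((1, 0, self.embed_dim), device=self._get_device())
```

Functionally this matches the original forward, since concatenating a `(1, 0, embed_dim)` tensor with the point embeddings along dim=1 is a no-op.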
❓ Question
im able to `torch.export` and generate an ExportedProgram with no issues for my model. upon compiling with `torch_tensorrt`, i run into this error:
im currently able to cleanly generate an `ExportedProgram` via `torch.export`, and outputs from the trace match the original PyTorch model. in particular, it's unclear to me why `!weights.values == !weights.count` would be an `API Usage Error`, and why there is a discrepancy between torch.compile and how torch_tensorrt interprets / performs the op conversion (torch.compile on the ExportedProgram module works fine).

What you have already tried
i've narrowed the issue down to a single module that does positional encoding. the output of this module is then concat'd with another tensor, which is where the error above comes from. without this module, everything works as expected, and i'm able to see about a 5x speedup.
the only unique thing about this module is that it has a buffer and some in-place operations; however, i've dumped and manually inspected the fx Graph and the trace looks correct (buffer lifted as a constant input). other things i've done are: re-writing the forward so that there are no in-place operations, to make graph capture easier.
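As an aside, a minimal sketch of how that graph inspection can be done (using the placeholder checkpoint path from the repro above, not a real file):

```python
# Sketch of inspecting the exported graph to confirm the buffer is lifted.
import torch

ep = torch.export.load("/path/to/checkpoint.pt")
# The graph signature shows how parameters/buffers are lifted to graph inputs;
# positional_encoding_gaussian_matrix should appear here rather than as a
# get_attr node inside the graph.
print(ep.graph_signature)
# Readable dump of the lifted FX graph itself.
ep.graph_module.print_readable()
```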
Environment
How you installed PyTorch (conda, pip, libtorch, source): pip

Additional context
cc @narendasan not sure if you have any insight here. thanks!