module with buffer requires wrapper module to avoid jit_ext error #1637

Closed

ali-alshaar7 opened this issue Jan 10, 2025 · 1 comment

ali-alshaar7 (Contributor) commented Jan 10, 2025


🐛 Bug

Jitting a module that reassigns its own registered buffers in forward (self.k = self.k.to(k.dtype)) fails inside thunder's jit_ext with an IndexError, unless the module is first wrapped in a parent module.

To Reproduce

import torch
import torch.nn as nn
from typing import Tuple, Optional
import thunder

class cast(nn.Module):
    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
        self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)

    def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # move the buffer to the activation dtype for when AMP is used
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        # update the cache
        return self.k, self.v

with torch.device("cpu"):
    k_shape = (2, 3, 4, 5)
    v_shape = (2, 3, 4, 5)
    dtype = torch.float32
    model = cast(k_shape, v_shape, dtype=dtype).eval().requires_grad_(False)

k = torch.randn(2, 3, 4, 5, dtype=torch.half)
v = torch.randn(2, 3, 4, 5, dtype=torch.half)
cast_jit = thunder.jit(model)
output_k, output_v = cast_jit(k, v)

fails with:

  File "/teamspace/studios/this_studio/lightning-thunder/thunder/core/jit_ext.py", line 1875, in thunder_general_jit
    process_recorded_modifications(ctx, epilogue_trace)
  File "/teamspace/studios/this_studio/lightning-thunder/thunder/core/jit_ext.py", line 1757, in process_recorded_modifications
    typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
  File "/teamspace/studios/this_studio/lightning-thunder/thunder/core/jit_ext.py", line 1414, in get_parameter_or_buffer_or_submodule_name_and_root
    assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR
IndexError: list index out of range
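
For contrast, the unjitted module runs fine; assigning to a registered buffer name goes through nn.Module.__setattr__ and updates the buffer in place. A quick check, assuming the tensors from the repro above:

# Eager execution of the same module succeeds; both buffers end up torch.half.
eager_k, eager_v = model(k, v)
assert eager_k.dtype == torch.half and eager_v.dtype == torch.half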

Wrapping the module in a parent module makes the repro pass:

class ParentModule(nn.Module):
    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)

    def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.cast_module(k, v)
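
A minimal sketch of the workaround in use, assuming the same setup as the repro above; presumably it works because the buffers then sit one attribute level below the jitted root:

# Jitting the wrapper succeeds where jitting `cast` directly raised the IndexError.
wrapped = ParentModule(k_shape, v_shape, dtype=dtype).eval().requires_grad_(False)
wrapped_jit = thunder.jit(wrapped)
output_k, output_v = wrapped_jit(k, v)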


ali-alshaar7 self-assigned this Jan 10, 2025
t-vi (Collaborator) commented Jan 13, 2025

I think this is "setattr" doing it for us.
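
For what it's worth, one reading of the setattr remark (an illustration, not confirmed in the thread): the buffer reassignment in forward is a setattr on the jitted root module itself, and that is the recorded modification jit_ext then fails to map back to a name on the root.

# Illustration only: reassigning a registered buffer name is a setattr on the
# module, routed through nn.Module.__setattr__ into the buffer registry.
m = cast((2, 3, 4, 5), (2, 3, 4, 5), dtype=torch.float32)
setattr(m, "k", m.k.to(torch.half))  # same effect as `self.k = self.k.to(...)` in forward
assert m.k.dtype == torch.half and "k" in m._buffers  # _buffers is private API; illustration only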
