
Get dynamic shapes to work with Phi-3-mini-128k-instruct #1579

Open
tfogal opened this issue Dec 20, 2024 · 4 comments

tfogal commented Dec 20, 2024

🚀 Feature

The program below fails due to the use of cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES.

Traceback (most recent call last):
  File "/home/tfogal/dev/ak-bench/bench_targets/llm_peft/sample.py", line 146, in <module>
    cmodel(**d)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 569, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/phi3/modeling_phi3.py", line 1193, in forward
    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 738, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.14", line 5, in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/module.py", line 80, in forward
    res = self._forward_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 742, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 777, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 724, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 236, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 630, in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/executors/torch_autograd.py", line 156, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 3053, in forward_and_backward_from_trace
    forward_trace, result, env = augmented_forward_pass_trace(trace, *trace.args, **trace.kwargs)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2669, in augmented_forward_pass_trace
    trace, result, env = interpret_trace_to_trace(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/trace_interpreter.py", line 169, in interpret_trace_to_trace
    result = prim_func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2582, in _vjp_impl
    out_primal, out_residuals = vjp_impl(*primals, **kwargs)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2357, in decomposed_fn_aug_fwd_rule
    saved_for_backward = deconstruct_forward_env_for_backward(trace, env)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2323, in deconstruct_forward_env_for_backward
    saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2323, in <genexpr>
    saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'name'

Motivation

With NeMo, we are starting to test fine-tuning with varying sequence lengths, so the tensor sizes change every step.

Pitch

We do not give an error :-)

Alternatives

The alternative is probably to pad the tensor up to a power of two and compile that.
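For concreteness, here is a minimal sketch of that padding alternative; next_pow2 and pad_batch_to_pow2 are hypothetical helpers, not something the repro or Thunder provides:

import torch
import torch.nn.functional as F

def next_pow2(n: int) -> int:
    # Smallest power of two >= n, e.g. next_pow2(300) == 512.
    return 1 << (n - 1).bit_length()

def pad_batch_to_pow2(d: dict[str, torch.Tensor], pad_token_id: int = 0) -> dict[str, torch.Tensor]:
    # Right-pad every (batch, seq) tensor so the sequence length is a power of
    # two; the compiler then sees at most ~log2(max_len) distinct shapes.
    # A real version would pad 'labels' with the loss-ignore index (-100).
    seq_len = d["input_ids"].shape[-1]
    pad = next_pow2(seq_len) - seq_len
    return {k: F.pad(v, (0, pad), value=pad_token_id) for k, v in d.items()}

In the loop below, calling d = pad_batch_to_pow2(d) just before cmodel(**d) would bound the number of recompiles instead of triggering one per sequence length.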

Additional context

import math
import datasets
import torch
import thunder
from thunder.dynamo import thunderfx
import transformers
import nemo
import nvtx

nvtx.push_range("startup") # force nvtx initialization
nvtx.pop_range()

m_id = "microsoft/Phi-3-mini-128k-instruct"
nvtx.push_range("loading")
cfg = transformers.AutoConfig.from_pretrained(
  m_id,
  torch_dtype=torch.bfloat16,
  num_hidden_layers=2, # scale down for testing
)
cfg.hidden_size = cfg.num_attention_heads # shrink hidden size for testing (head_dim becomes 1)
with torch.device("cuda"):
  model = transformers.AutoModelForCausalLM.from_config(cfg).to(torch.bfloat16)

tokenizer = transformers.AutoTokenizer.from_pretrained(
  m_id,
  torch_dtype='auto',
  trust_remote_code=False,
)

symbolic = thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES
nvtx.pop_range() # loading
cmodel = thunderfx(model, cache=symbolic)

def argument_details(args: list[torch.Tensor]):
  for a in args:
    if isinstance(a, torch.Tensor):
      print(f"arg {a.shape=}")
    else:
      print(f"arg {a=}")

@nvtx.annotate()
def make_squad_hf_dataset(tokenizer):
    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

    def formatting_prompts_func(examples):
        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
        ### Instruction:
        {}

        ### Input:
        {}

        ### Response:
        {}"""
        print("-- FORMATTING PROMPTS!")
        instruction = examples["context"]
        input = examples["question"]
        output = examples["answers"]['text']
        if isinstance(output, list):
            output = output[0]
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        ans = tokenizer(text)
        ans['labels'] = list(ans['input_ids'][1:])
        ans['input_ids'] = ans['input_ids'][:-1]
        ans['attention_mask'] = ans['attention_mask'][:-1]
        print("answer:", ans)
        return ans

    datamodule = datasets.load_dataset("rajpurkar/squad", split="train[:100]")
    return datamodule.map(
        formatting_prompts_func,
        batched=False,
        batch_size=2,
        remove_columns=["id", "title", "context", "question", 'answers'],
    )

def collate_fn(batch, pad_token_id=0):
    def batchify(tensor):
        if tensor.ndim == 1:
            return tensor.unsqueeze_(0)
        return tensor

    def extract_key_from_dicts(batch, key):
        return list(map(lambda x: x[key], batch))

    def pad_within_micro(batch, pad_token_id):
        max_len = max(map(len, batch))
        return [item + [pad_token_id] * (max_len - len(item)) for item in batch]


    return {
        key: batchify(
            torch.LongTensor(
                pad_within_micro(
                    extract_key_from_dicts(batch, key),
                    pad_token_id,
                )
            )
        )
        for key in batch[0].keys()
    }

#ds = datasets.load_dataset("rajpurkar/squad")
dm = make_squad_hf_dataset(tokenizer)
loader = torch.utils.data.DataLoader(dm, collate_fn=collate_fn, num_workers=1, pin_memory=True)

print(dm)
counter = 0
for d in loader:
  #print(d)
  nvtx.push_range("moving to GPU")
  d = {k: v.cuda() for k,v in d.items()}
  nvtx.pop_range()
  for k in d:
    print(f"{k=}:", d[k].shape)
  nvtx.push_range(f"run model {d['input_ids'].shape}")
  cmodel(**d)
  nvtx.pop_range()
  #print(thunder.last_traces(cmodel)[-1])

  counter = counter + 1
  if counter >= 15:
    break

cc @tfogal

tfogal added the enhancement (New feature or request) and nemo (Issues needed to support NVIDIA NeMo models) labels on Dec 20, 2024
jjsjann123 commented Jan 3, 2025

Sorry for the slow turn-around. I got busy with some other work these last couple weeks.

Here's a smaller repro on this issue.

import torch
import thunder

def foo(a):
    return a[slice(None, None, None), None, None, slice(None, None, None)]

dtype = torch.float32
a = torch.randn(1, 128, device="cuda").to(dtype=dtype)

a.requires_grad_()
jfoo = thunder.jit(foo, cache="symbolic values")

out = jfoo(a)

This one comes from the repeated shape queries

(i0, i1) = prims.shape(x)
(i0, i1) = prims.shape(x)
(i0, i1) = prims.shape(x)

being rewritten by DCE into

(i0, i1) = prims.shape(x)
(_, _) = prims.shape(x)
(_, _) = prims.shape(x)

A quick workaround (WAR) is to run CSE before the DCE pass in the grad transform.
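
To illustrate the idea, here is a toy CSE pass over (outputs, op, args) tuples; this is a sketch of the principle only, not Thunder's actual trace or pass API:

def cse(bound_symbols):
    # Drop bound symbols whose (op, args) computation was already emitted.
    # Because the duplicate prims.shape calls bind the same output names
    # (i0, i1), dropping them is enough: DCE then never gets the chance to
    # rewrite a later duplicate's outputs to `_`.
    seen = set()
    kept = []
    for outputs, op, args in bound_symbols:
        key = (op, args)
        if key in seen:
            continue
        seen.add(key)
        kept.append((outputs, op, args))
    return kept

trace = [
    (("i0", "i1"), "prims.shape", ("x",)),
    (("i0", "i1"), "prims.shape", ("x",)),
    (("i0", "i1"), "prims.shape", ("x",)),
]
assert cse(trace) == [(("i0", "i1"), "prims.shape", ("x",))]

With only the first prims.shape left, the lookup by output name in deconstruct_forward_env_for_backward (the line raising the AttributeError above) stays intact.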

@jjsjann123
Collaborator

import torch
import thunder

def foo(a, w):
    return torch.nn.functional.linear(a, w)

dtype = torch.float32
a = torch.randn(2, 8, 128, device="cuda").to(dtype=dtype)
b = torch.randn(64, 128, device="cuda").to(dtype=dtype)

a.requires_grad_()
b.requires_grad_()
jfoo = thunder.jit(foo, cache="symbolic values")

out = jfoo(a, b)

I'm now seeing this one error out as well; it looks like we are missing a prims.shape call in the grad transform.

tfogal commented Jan 15, 2025

I created a histogram of all the sequence lengths we see in one example dataset:

[Figure: histogram of observed sequence lengths; most fall in a lower region, with two clusters near 400 and ~480.]

Of course, the ideal is to handle dynamic shapes generally, but a temporary solution might be bucketing. The two clusters at the high end (near 400 and at ~480) could each be grouped into a bucket; for the lower region, we might want several more buckets.

Given vectorization widths etc., my recommendation would be to bucket every 16 elements.
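
For illustration, a minimal variant of the repro's pad_within_micro that pads to a bucket boundary; round_up and the multiple parameter are mine, not part of the repro:

def round_up(n: int, multiple: int = 16) -> int:
    # Round a length up to the next multiple of `multiple`.
    return ((n + multiple - 1) // multiple) * multiple

def pad_within_micro(batch, pad_token_id, multiple=16):
    # Same as the collate_fn helper in the repro, but pad to a 16-element
    # bucket boundary rather than the exact micro-batch maximum, so nearby
    # sequence lengths map to the same compiled shape.
    max_len = round_up(max(map(len, batch)), multiple)
    return [item + [pad_token_id] * (max_len - len(item)) for item in batch]

assert len(pad_within_micro([[1] * 37, [1] * 21], 0)[0]) == 48

This caps the number of distinct sequence lengths the compiler can ever see at max_len / 16.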

@csarofeen
Collaborator

Thanks for the distribution! At minimum, bucket intervals should be 128-bit aligned, so yes, that makes a lot of sense (with bf16 elements that is a multiple of 8, so 16-element buckets qualify). At the smaller sizes we might do well with larger buckets than at the larger sizes, since the perf delta won't be as large for small sizes. CC @IvanYashchuk since I was chatting with him this morning about bucketing and padding approaches.
