Can't use torch.compile #12

Open
shikhartuli opened this issue Jul 3, 2024 · 3 comments

Comments

@shikhartuli

When I compile the model, I get the following error. Any idea how to fix this?

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_18849/856931662.py", line 14, in <module>
    model(input_ids=tokens)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1905, in forward
    outputs = self.model(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1693, in forward
    logger.warning_once(
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1703, in resume_in_forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1726, in resume_in_forward
    layer_outputs = decoder_layer(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1392, in forward
    hidden_states, router_logits = self.hydra(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1221, in forward
    return self.cuda_kernels_forward(hidden_states, cache_params)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1025, in cuda_kernels_forward
    projected_states, _, in_proj_router_logits = self.in_proj(hidden_states)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 892, in forward
    final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts).to(hidden_states.dtype)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/mlp.py", line 124, in forward
    padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts)
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/mlp.py", line 126, in resume_in_forward
    h = self.experts(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/parallel_experts.py", line 143, in forward
    results = ParallelLinear.apply(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/parallel_experts.py", line 14, in forward
    output = kernels.ops.scatter2scatter(
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/kernels/ops.py", line 139, in scatter2scatter
    with torch.cuda.device(X.device):
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/kernels/ops.py", line 140, in resume_in_scatter2scatter
    _scatter2scatter[grid](
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 127, in run
    self.nargs = dict(zip(self.arg_names, args))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
    self.nargs = dict(zip(self.arg_names, args))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
    self.nargs = dict(zip(self.arg_names, args))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 128, in resume_in_run
    if len(self.configs) > 1:
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 151, in resume_in_run
    config = self.configs[0]
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 152, in resume_in_run
    self.best_config = config
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in resume_in_run
    ret = self.fn.run(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 665, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 775, in call_method
    return self.clone(grid=grid).call_function(tx, args, kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 743, in call_function
    "kwargs": meta.as_proxy(),
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 33, in as_proxy
    return {k: v.as_proxy() for k, v in self.items.items()}
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 33, in <dictcomp>
    return {k: v.as_proxy() for k, v in self.items.items()}
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 90, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 274, in as_proxy
    raise NotImplementedError(str(self))
torch._dynamo.exc.InternalTorchDynamoError: UserDefinedObjectVariable(dtype)

from user code:
   File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 305, in run
    return self.fn.run(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2168, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1454, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1345, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1192, in structured_traceback
    formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1082, in format_exception_as_a_whole
    self.get_records(etb, number_of_lines_of_context, tb_offset) if etb else []
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1179, in get_records
    res = list(stack_data.FrameInfo.stack_data(etb, options=options))[tb_offset:]
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/stack_data/core.py", line 597, in stack_data
    yield from collapse_repeated(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/stack_data/utils.py", line 77, in collapse_repeated
    for is_highlighted, group in itertools.groupby(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/stack_data/utils.py", line 45, in highlight_unique
    counts = Counter(lst)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/collections/__init__.py", line 577, in __init__
    self.update(iterable, **kwds)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/collections/__init__.py", line 670, in update
    _count_elements(self, iterable)
TypeError: unhashable type: 'dict'
@mayank31398
Contributor

mayank31398 commented Sep 22, 2024

@shikhartuli, not sure if you are still working on this, but I have added a compilable version of scattermoe here:
https://github.com/mayank31398/kernel-hyperdrive/blob/04e1dd2c6eb0154eab519cc91f5a2c9a3321c105/khd/scattermoe/triton_implementation/__init__.py#L37

Also, you seem familiar 🤔

I am still testing it but it seems to work without any graph breaks

@shikhartuli
Author

@mayank31398 I was not able to get any speedup with your version when training a 1.5B MoE model on H200s. Could you share your profiling implementation?

Also, remember me from IIT?

@mayank31398
Contributor

@shikhartuli the speedup shows up mostly for the full MoE.
The barebones kernel is not giving me a speedup either.
Also, compile doesn't trace through the MLIR generated from Triton, and most of the code is the kernel.

this is the repo I used for training: https://github.com/IBM/dolomite-engine
this is the sample config: https://github.com/IBM/dolomite-engine/blob/main/configs/pretraining-examples/moe/moe.yml
you can enable/disable compile in the config.
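
To compare compile on and off outside a full training run, a small timing harness can help. This is a minimal sketch, not the profiling setup used in dolomite-engine: `benchmark` and its parameters are hypothetical names, and it only measures wall-clock time per call.

```python
import time
from statistics import median
from typing import Callable, Optional


def benchmark(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock time in seconds of one call to fn."""
    # Warmup calls absorb one-time costs such as compilation or autotuning.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        # Wait for asynchronous work to finish before stopping the clock.
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return median(samples)


# Usage with a cheap CPU function; no synchronization is needed here.
seconds = benchmark(lambda: sum(range(10_000)), warmup=1, iters=5)
```

When timing GPU code, pass sync=torch.cuda.synchronize, because CUDA kernel launches are asynchronous and the timer would otherwise stop before the work finishes. Running it once on the eager model and once on the compiled model gives a like-for-like comparison.
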

PS: yeah I remember LOL
