Add torchcompile_xentropy executor #1655

Merged: 22 commits from torchcompilecat-add-xentropy into main on Feb 9, 2025

Commits (22)

6c5b0e1
add cross_entropy to torchcompile_cat executor
riccardofelluga Jan 17, 2025
3982cb7
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
riccardofelluga Jan 28, 2025
ecc1f7c
Revert "add cross_entropy to torchcompile_cat executor"
riccardofelluga Jan 28, 2025
0927d74
move xentropy in its own executor
riccardofelluga Jan 28, 2025
e11ddab
reshape not always required
riccardofelluga Jan 28, 2025
d815892
add torchcompile_xentropy to the tests
riccardofelluga Jan 28, 2025
39530d3
add torchcompile_xentropy as default executor
riccardofelluga Jan 28, 2025
284e7d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2025
c4d98e9
remove torch.nn.functional.cross_entropy from supported as it is in r…
riccardofelluga Jan 29, 2025
9338118
bump version (#1708)
ali-alshaar7 Jan 28, 2025
c875621
nvfuser: add option to allow shape only region (#1702)
kshitij12345 Jan 28, 2025
7e857f5
Backward transform dependency fix (#1693)
jjsjann123 Jan 28, 2025
fcec023
pin check schema (#1709)
t-vi Jan 28, 2025
9a40bb0
bump transformers version (#1698)
riccardofelluga Jan 29, 2025
1df5929
Reduce `test_thunderfx_mistral_nemo_small` model size (#1701)
riccardofelluga Jan 29, 2025
71bd489
fix naming
riccardofelluga Jan 29, 2025
d94a420
add pad prims
riccardofelluga Jan 29, 2025
35c4606
add test for the new executor
riccardofelluga Jan 29, 2025
e0e32e0
Merge branch 'main' into torchcompilecat-add-xentropy
riccardofelluga Jan 29, 2025
3edc0a3
add missing import
riccardofelluga Jan 29, 2025
3cebfe3
Merge branch 'main' into torchcompilecat-add-xentropy
riccardofelluga Jan 29, 2025
52294ec
Merge branch 'main' into torchcompilecat-add-xentropy
riccardofelluga Feb 4, 2025
thunder/__init__.py (3 additions, 1 deletion)

@@ -191,16 +191,18 @@
 cudnn_executor: None | extend.Executor = extend.get_executor("cudnn")
 sdpa_executor: None | extend.Executor = extend.get_executor("sdpa")
 torchcompile_cat_executor: None | extend.Executor = extend.get_executor("torchcompile_cat")
+torchcompile_xentropy_executor: None | extend.Executor = extend.get_executor("torchcompile_xentropy")
 apex_executor: None | extend.Executor = extend.get_executor("apex")
 nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
 pytorch_executor: None | extend.Executor = extend.get_executor("torch")

-# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> nvfuser -> torch -> python]
+# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python]
 # Note that add_default_executor inserts executor at start of list, hence the reverse order below.
 if nvfuser_executor:
     add_default_executor(nvfuser_executor)

 if torchcompile_cat_executor and pytorch._dynamo.is_inductor_supported():
+    add_default_executor(torchcompile_xentropy_executor)
     add_default_executor(torchcompile_cat_executor)

 if sdpa_executor:
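Because add_default_executor pushes onto the front of the list, adding torchcompile_xentropy before torchcompile_cat leaves torchcompile_cat ahead of it in the final order, matching the updated comment. A quick way to verify the resulting order, assuming thunder.extend also exposes a get_default_executors counterpart to the add_default_executor used above:

from thunder import extend

# Expected to list torchcompile_cat before torchcompile_xentropy, with
# nvfuser after both, per the default-list comment in the diff.
print([ex.name for ex in extend.get_default_executors()])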
thunder/executors/torch_compile.py (35 additions, 0 deletions)

@@ -243,6 +243,41 @@ def cuda_device_checker(*args, **kwargs):
     op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
 }

+# Similar to torchcompile_cat, this executor is meant to be used with nvfuser_executor to allow
+# inductor to claim the cross_entropy computation.
+required_ops = {
+    "nll_loss_backward",
+    "log_softmax_backward",
+    "torch.log_softmax",
+    "torch.nn.functional.nll_loss",
+    "torch.nn.functional.cross_entropy",
+}
+torch_compile_xentropy_ex = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops)
+register_executor(torch_compile_xentropy_ex)
+
+supported_ops = {
+    prims.broadcast_in_dim.id,
+    prims.convert_element_type.id,
+    prims.div.id,
+    prims.ne.id,
+    prims.neg.id,
+    prims.pad.id,
+    prims.reshape.id,
+    prims.slice_prim.id,
+    prims.where.id,
+    "nll_loss_backward",
+    "log_softmax_backward",
+    "torch.log_softmax",
+    "torch.nn.functional.nll_loss",
+    "torch.sum",
+    "torch.take_along_dim",
+    "torch.Tensor.contiguous",
+}
+
+torch_compile_xentropy_ex._implmap = {
+    op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
+}
+

 torch_compile_ex = TorchCompileExecutor(name="torchcompile")
 register_executor(torch_compile_ex)
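The required_ops set gates when this executor claims a region: fusion is attempted only for traces that actually contain a cross-entropy computation, so the executor stays out of the way elsewhere. A rough illustration of that gating, under the assumption that a claimed region needs at least one required op and may contain only supported ops (a sketch, not the executor's actual claiming code):

# Sketch only: approximates how required_ops and supported_ops restrict
# which trace regions the torchcompile_xentropy executor takes.
def may_claim(region_ops: set[str], required_ops: set[str], supported_ops: set[str]) -> bool:
    has_anchor = bool(region_ops & required_ops)  # e.g. "torch.nn.functional.cross_entropy"
    all_supported = region_ops <= supported_ops  # every op needs an implementation
    return has_anchor and all_supported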
thunder/tests/framework.py (19 additions, 1 deletion)

@@ -208,6 +208,23 @@ def version(self):
         return torch.__version__


+class TorchCompileXentropyTestExecutor(TestExecutor):
+    name = "torchcompile_xentropy"
+    supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
+    supported_dtypes = (datatypes.dtype,)
+
+    def is_available(self) -> bool:
+        return not IS_WINDOWS
+
+    def executors_list(self) -> list[extend.Executor]:
+        from thunder.executors.torch_compile import torch_compile_xentropy_ex
+
+        return [torch_compile_xentropy_ex]
+
+    def version(self):
+        return torch.__version__
+
+
 class TorchCompileCatTestExecutor(TestExecutor):
     name = "torchcompile_cat"
     supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
@@ -261,6 +278,7 @@ def make_callable(self, fn, **kwargs):
 # TODO Refactor these executors into the actual executor (sub)modules
 TorchExecutor: TorchTestExecutor = TorchTestExecutor()
 TorchCompileCatExecutor: TorchCompileCatTestExecutor = TorchCompileCatTestExecutor()
+TorchCompileXentropyExecutor: TorchCompileXentropyTestExecutor = TorchCompileXentropyTestExecutor()
 TorchCompileExecutor: TorchCompileTestExecutor = TorchCompileTestExecutor()
 DynamoThunderExecutor: DynamoThunderTestExecutor = DynamoThunderTestExecutor()
 nvFuserExecutor: None | nvFuserTestExecutor = None

@@ -368,7 +386,7 @@ def __init__(
         self.supported_executors = (
             set(supported_executors)
             if supported_executors is not None
-            else set(_all_test_executors() + [TorchCompileCatExecutor])
+            else set(_all_test_executors() + [TorchCompileCatExecutor, TorchCompileXentropyExecutor])
         )
         for ex in self.supported_executors:
             assert isinstance(ex, TestExecutor)
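A hedged sketch of how a test might opt into the new test executor, assuming the framework's instantiate decorator accepts an executors argument as other Thunder tests use it (the test name and body here are hypothetical):

from thunder.tests.framework import instantiate, TorchCompileXentropyExecutor

@instantiate(executors=(TorchCompileXentropyExecutor,))
def test_xentropy_smoke(executor, device, dtype):
    # Hypothetical body: build a jitted callable via executor.make_callable
    # and compare its output against eager execution.
    ...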
thunder/tests/test_extend.py (1 addition, 0 deletions)

@@ -132,6 +132,7 @@ def test_get_all_executors_includes_all_native_executors():
         "sdpa",
         "torchcompile",
         "torchcompile_cat",
+        "torchcompile_xentropy",
         "python",
         "transformer_engine",
     }
thunder/tests/test_torch_compile_executor.py (23 additions, 1 deletion)

@@ -4,7 +4,12 @@
 from torch._dynamo import is_inductor_supported

 import thunder
-from thunder.executors.torch_compile import supported_ops, torch_compile_ex, torch_compile_cat_ex
+from thunder.executors.torch_compile import (
+    supported_ops,
+    torch_compile_ex,
+    torch_compile_cat_ex,
+    torch_compile_xentropy_ex,
+)
 from thunder.executors.torchex import ex as pytorch_ex
 from thunder.executors.nvfuserex import nvfuserex
 from thunder.tests.bf16 import device_supports_bf16

@@ -122,3 +127,20 @@ def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
     out_jitted = forward_and_loss_jitted(model, input_ids)

     assert_close(out, out_jitted)
+
+
+@requiresCUDA
+def test_torch_compile_xentropy_loss():
+    from transformers.loss.loss_utils import ForCausalLMLoss
+
+    logits = torch.randn(1, 2, 6, device="cuda", requires_grad=True)
+    labels = torch.randint(0, 6, (1, 2), device="cuda")
+    vocab_size = 6
+
+    closs_fn = thunder.jit(ForCausalLMLoss, executors=[torch_compile_xentropy_ex])
+    _ = closs_fn(logits, labels, vocab_size, ignore_index=-1)
+    forward_trace = thunder.last_traces(closs_fn)[-1].python()
+
+    # make a single torch.compile region
+    assert "TorchCompile0" in forward_trace
+    assert "TorchCompile1" not in forward_trace