diff --git a/thunder/__init__.py b/thunder/__init__.py index 1fcc346471..2ba53cdfc9 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -191,16 +191,18 @@ cudnn_executor: None | extend.Executor = extend.get_executor("cudnn") sdpa_executor: None | extend.Executor = extend.get_executor("sdpa") torchcompile_cat_executor: None | extend.Executor = extend.get_executor("torchcompile_cat") +torchcompile_xentropy_executor: None | extend.Executor = extend.get_executor("torchcompile_xentropy") apex_executor: None | extend.Executor = extend.get_executor("apex") nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser") pytorch_executor: None | extend.Executor = extend.get_executor("torch") -# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> nvfuser -> torch -> python] +# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python] # Note that add_default_executor inserts executor at start of list, hence the reverse order below. if nvfuser_executor: add_default_executor(nvfuser_executor) if torchcompile_cat_executor and pytorch._dynamo.is_inductor_supported(): + add_default_executor(torchcompile_xentropy_executor) add_default_executor(torchcompile_cat_executor) if sdpa_executor: diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 5aaba1e64f..5f4532dbce 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -243,6 +243,41 @@ def cuda_device_checker(*args, **kwargs): op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops } +# Similar to torchcomile_cat, this executor is meant to be used with nvfuser_executor to allow +# inductor to claim cross_entropy computation. +required_ops = { + "nll_loss_backward", + "log_softmax_backward", + "torch.log_softmax", + "torch.nn.functional.nll_loss", + "torch.nn.functional.cross_entropy", +} +torch_compile_xentropy_ex = TorchCompileExecutor(name="torchcompile_xentropy", required_ops=required_ops) +register_executor(torch_compile_xentropy_ex) + +supported_ops = { + prims.broadcast_in_dim.id, + prims.convert_element_type.id, + prims.div.id, + prims.ne.id, + prims.neg.id, + prims.pad.id, + prims.reshape.id, + prims.slice_prim.id, + prims.where.id, + "nll_loss_backward", + "log_softmax_backward", + "torch.log_softmax", + "torch.nn.functional.nll_loss", + "torch.sum", + "torch.take_along_dim", + "torch.Tensor.contiguous", +} + +torch_compile_xentropy_ex._implmap = { + op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops +} + torch_compile_ex = TorchCompileExecutor(name="torchcompile") register_executor(torch_compile_ex) diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index f02d3c3c3b..8f458576f0 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -208,6 +208,23 @@ def version(self): return torch.__version__ +class TorchCompileXentropyTestExecutor(TestExecutor): + name = "torchcompile_xentropy" + supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA) + supported_dtypes = (datatypes.dtype,) + + def is_available(self) -> bool: + return not IS_WINDOWS + + def executors_list(self) -> list[extend.Executor]: + from thunder.executors.torch_compile import torch_compile_cat_ex + + return [torch_compile_cat_ex] + + def version(self): + return torch.__version__ + + class TorchCompileCatTestExecutor(TestExecutor): name = "torchcompile_cat" supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA) @@ -261,6 +278,7 @@ def make_callable(self, fn, **kwargs): # TODO Refactor these executors into the actual executor (sub)modules TorchExecutor: TorchTestExecutor = TorchTestExecutor() TorchCompileCatExecutor: TorchCompileCatTestExecutor = TorchCompileCatTestExecutor() +TorchCompileXentropyExecutor: TorchCompileXentropyTestExecutor = TorchCompileXentropyTestExecutor() TorchCompileExecutor: TorchCompileTestExecutor = TorchCompileTestExecutor() DynamoThunderExecutor: DynamoThunderTestExecutor = DynamoThunderTestExecutor() nvFuserExecutor: None | nvFuserTestExecutor = None @@ -368,7 +386,7 @@ def __init__( self.supported_executors = ( set(supported_executors) if supported_executors is not None - else set(_all_test_executors() + [TorchCompileCatExecutor]) + else set(_all_test_executors() + [TorchCompileCatExecutor, TorchCompileXentropyExecutor]) ) for ex in self.supported_executors: assert isinstance(ex, TestExecutor) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index c5a45ffd83..7e4a4b04b2 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -132,6 +132,7 @@ def test_get_all_executors_includes_all_native_executors(): "sdpa", "torchcompile", "torchcompile_cat", + "torchcompile_xentropy", "python", "transformer_engine", } diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index c388dd6347..ce66d75d3d 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -4,7 +4,12 @@ from torch._dynamo import is_inductor_supported import thunder -from thunder.executors.torch_compile import supported_ops, torch_compile_ex, torch_compile_cat_ex +from thunder.executors.torch_compile import ( + supported_ops, + torch_compile_ex, + torch_compile_cat_ex, + torch_compile_xentropy_ex, +) from thunder.executors.torchex import ex as pytorch_ex from thunder.executors.nvfuserex import nvfuserex from thunder.tests.bf16 import device_supports_bf16 @@ -122,3 +127,20 @@ def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor: out_jitted = forward_and_loss_jitted(model, input_ids) assert_close(out, out_jitted) + + +@requiresCUDA +def test_torch_compile_xentropy_loss(): + from transformers.loss.loss_utils import ForCausalLMLoss + + logits = torch.randn(1, 2, 6, device="cuda", requires_grad=True) + labels = torch.randint(0, 6, (1, 2), device="cuda") + vocab_size = 6 + + closs_fn = thunder.jit(ForCausalLMLoss, executors=[torch_compile_xentropy_ex]) + _ = closs_fn(logits, labels, vocab_size, ignore_index=-1) + forward_trace = thunder.last_traces(closs_fn)[-1].python() + + # make a single torch.compile region + assert "TorchCompile0" in forward_trace + assert "TorchCompile1" not in forward_trace