Commit af5aec9

Merge branch 'NVIDIA:main' into d_256
2 parents 3c2251e + 1d903f5 commit af5aec9

File tree: 4 files changed, +52 -6 lines changed

build_tools/VERSION.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-2.4.0.dev0
+2.5.0.dev0

tests/pytorch/test_parallel_cross_entropy.py

Lines changed: 31 additions & 3 deletions

@@ -19,11 +19,12 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool):
+    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):

         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
+        ignore = random.sample(range(0, SQ - 1), 5)

         if swap_dim:
             self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
@@ -32,14 +33,27 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool):
             self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
             self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()

+        if ignore_idx:
+            for i in ignore:
+                # Ignore 5 indices
+                if swap_dim:
+                    self.tar_test[i][0] = -100
+                else:
+                    self.tar_test[0][i] = -100
+
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

     def one_iteration_test(
-        self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        label_smoothing: float,
+        reduce_loss: bool,
+        ignore_idx: bool = False,
     ):

-        self.generate_input(dtype, swap_dim)
+        self.generate_input(dtype, swap_dim, ignore_idx)

         self.input_test.requires_grad_(True)
         self.input_ref.requires_grad_(True)
@@ -57,6 +71,8 @@ def one_iteration_test(
         test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss

         torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
+        if ignore_idx:
+            print(test_loss, ref_loss)
         if reduce_loss:
             torch.testing.assert_close(
                 torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
@@ -106,3 +122,15 @@ def test_non_reduced_loss(self):
         self.one_iteration_test(
             dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
         )
+
+    def test_ignore_idx(self):
+        self.generate_iters(5)
+        self.generate_infra(False, 0)
+        for i in range(self.iters):
+            self.one_iteration_test(
+                dtype=torch.float32,
+                swap_dim=random.choice([True, False]),
+                label_smoothing=0,
+                reduce_loss=False,
+                ignore_idx=True,
+            )

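The new test marks five target positions with -100 while still comparing against the torch.nn.CrossEntropyLoss reference built in generate_infra; this works because -100 is that loss's default ignore_index. A minimal standalone sketch of the reference behavior being relied on (illustration only, not part of the diff):

    import torch

    # With reduction="none", targets equal to ignore_index (default -100) produce a
    # loss of exactly 0 and contribute nothing to the gradient of the logits.
    logits = torch.randn(4, 10, requires_grad=True)
    targets = torch.tensor([3, -100, 7, -100])

    loss = torch.nn.CrossEntropyLoss(reduction="none")(logits, targets)
    loss.sum().backward()

    print(loss)            # entries 1 and 3 are 0.0
    print(logits.grad[1])  # rows for ignored targets are all zeros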
transformer_engine/pytorch/cross_entropy.py

Lines changed: 9 additions & 2 deletions

@@ -22,7 +22,13 @@ class CrossEntropyFunction(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, _input, target, label_smoothing=0.0, reduce_loss=False, dist_process_group=None
+        ctx,
+        _input,
+        target,
+        label_smoothing=0.0,
+        reduce_loss=False,
+        dist_process_group=None,
+        ignore_idx=-100,
     ):
         """
         The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
@@ -35,12 +41,13 @@ def forward(
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
         dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device.
+        ignore_idx (int): The index for which loss and gradients are made to zero

         Returns:
             tensor: The computed loss.
         """
         loss, _input = triton_cross_entropy.cross_entropy_forward(
-            _input, target, label_smoothing, reduce_loss, dist_process_group
+            _input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
         )

         ctx.save_for_backward(_input.detach())

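A hedged usage sketch of the extended signature. Only what this diff shows is assumed: CrossEntropyFunction lives in transformer_engine/pytorch/cross_entropy.py and its forward takes (_input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx). Calling it directly through autograd's .apply and the tensor shapes are illustrative choices; any higher-level wrapper is not part of this commit.

    import torch
    from transformer_engine.pytorch.cross_entropy import CrossEntropyFunction

    logits = torch.rand(2, 128, 32000, device="cuda", requires_grad=True)
    targets = torch.randint(0, 32000, (2, 128), device="cuda")
    targets[0, :5] = -100  # positions expected to yield zero loss and zero gradient

    # Positional arguments mirror forward():
    # _input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
    loss = CrossEntropyFunction.apply(logits, targets, 0.0, False, None, -100)
    loss.sum().backward()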
transformer_engine/pytorch/triton/cross_entropy.py

Lines changed: 11 additions & 0 deletions

@@ -94,6 +94,7 @@ def cross_entropy_kernel(
     m_d_X_y_stride,
     rank,
     world_size,
+    ignore_idx,
     n_cols,
     n_non_ignore,
     label_smoothing: tl.constexpr,
@@ -113,6 +114,7 @@ def cross_entropy_kernel(
         m_d_X_y_stride: The stride of m/d/X_y tensor.
         rank (int): The rank of this device in the TP group.
         world_size (int): The size of world involved in this distributed loss calculation.
+        ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (int): The number of non-ignored elements in the batch.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -128,6 +130,13 @@ def cross_entropy_kernel(
     Y_ptr += program_id * Y_stride
     y = tl.load(Y_ptr)

+    if y == ignore_idx:
+        # set all X_ptr as 0
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        return
+
     loss_ptr += program_id * loss_stride
     m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride

@@ -247,6 +256,7 @@ def cross_entropy_forward(
     label_smoothing: float,
     reduce_loss: bool,
     dist_process_group: Union[dist.ProcessGroup, None],
+    ignore_idx: int,
 ):
     """Forward implementation of Cross Entropy kernel"""

@@ -305,6 +315,7 @@ def cross_entropy_forward(
         m_d_X_y_stride=m_d_X_y_gathered.stride(-1),
         rank=rank,
         world_size=world_size,
+        ignore_idx=ignore_idx,
         n_cols=V,
         n_non_ignore=n_rows,
         label_smoothing=label_smoothing,

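Read in plain PyTorch terms, the kernel's new early-exit branch does roughly the following for each token row. This is an illustrative sketch of the intent, not the Triton code itself; it assumes, as the comment in the branch suggests, that X doubles as the in-place gradient buffer and that the loss buffer is zero-initialized by the caller.

    import torch

    def handle_row(X_row: torch.Tensor, y: int, ignore_idx: int = -100):
        """X_row: (n_cols,) logits for one token; y: its target index."""
        if y == ignore_idx:
            X_row.zero_()  # zero the row so the eventual gradient for this token is zero
            return 0.0     # its loss entry is left untouched (assumed zero-initialized)
        # non-ignored rows continue through the existing online-softmax cross-entropy path
        ...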