Skip to content

Commit 940bee4

Browse files
authored
Enable cross entropy loss for xla autocast with FP32 precision (#7992) (#8094)
1 parent 0f645b2 commit 940bee4

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

test/test_bf16_autocast.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
import os
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):
  """Checks that ``torch.autocast("xla")`` runs cross-entropy loss in FP32.

  With bf16 inputs, the autocast policy registered for XLA
  (KERNEL_XLA(cross_entropy_loss, fp32)) should insert bf16->f32 converts
  and keep the numerically sensitive ops (exponential, log) in f32. We
  verify this by inspecting the HLO text emitted for the loss tensor.
  """

  def test_cross_entropy_loss(self):
    # Inputs are deliberately bf16 so autocast must upcast them to f32.
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(16, 10).to(torch.bfloat16).to(device)
    with torch.autocast("xla"):
      loss = torch.nn.CrossEntropyLoss()(data, target)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
      # bf16 operands must be converted to f32 under the fp32 autocast policy.
      # assertRegex uses re.search, so no leading ".*" anchors are needed;
      # plain raw strings replace the original no-op "rf" prefixes.
      self.assertRegex(hlo, r"convert.*f32.*convert.*bf16")
      # exp/log inside softmax must execute in f32 for numerical stability.
      self.assertRegex(hlo, r"exponential.*f32.*exponential.*f32")
      self.assertRegex(hlo, r"log.*f32.*log.*f32")


if __name__ == "__main__":
  unittest.main()

torch_xla/csrc/autocast_mode.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
9595
KERNEL_XLA(hinge_embedding_loss, fp32)
9696
// KERNEL_XLA(poisson_nll_loss, fp32)
9797
KERNEL_XLA(smooth_l1_loss, fp32)
98-
// KERNEL_XLA(cross_entropy_loss, fp32)
98+
KERNEL_XLA(cross_entropy_loss, fp32)
9999
KERNEL_XLA(l1_loss, fp32)
100100
// KERNEL_XLA(huber_loss, fp32)
101101
KERNEL_XLA(margin_ranking_loss, fp32)

0 commit comments

Comments
 (0)