 import os
+import sys
+import unittest
 
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.utils.utils as xu
-import unittest
 
 
-def check_env_flag(name, default=''):
-  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
+class XlaDataTypeTest(unittest.TestCase):
 
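+  # setUp/tearDown snapshot and restore the dtype-override environment
+  # variables, so each test can set them without leaking into the next.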
+  def setUp(self):
+    self.original_env = {
+        'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'),
+        'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'),
+        'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG')
+    }
 
-class XlaDataTypeTest(unittest.TestCase):
+  def tearDown(self):
+    for key, value in self.original_env.items():
+      if value is None:
+        os.environ.pop(key, None)
+      else:
+        os.environ[key] = value
 
-  def test_datatype_f32(self):
-    t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
-    t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
-    t3 = torch.div(t1, t2, rounding_mode='floor')
-    assert t3.dtype == torch.float
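+  # Override the given environment variables for the current test; the
+  # original values captured in setUp are restored by tearDown.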
+  def _set_env(self, **kwargs):
+    for key, value in kwargs.items():
+      os.environ[key] = value
 
-    hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
-    device_data_hlo = hlo_text.split('\n')[1]
-    assert 'xla::device_data' in device_data_hlo, device_data_hlo
-    if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
-      assert 'bf16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
-      assert 'f16' in device_data_hlo, device_data_hlo
-    else:
-      assert 'f32' in device_data_hlo, device_data_hlo
-
-  def test_datatype_f64(self):
-    t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
-    t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
-    t3 = torch.div(t1, t2, rounding_mode='floor')
-    assert t3.dtype == torch.double
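+  # Shared body for the dtype tests: build two tensors of `dtype`, apply
+  # `op`, and check both the torch-level dtype and the XLA element type
+  # that appears in the dumped graph.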
+  def _test_datatype(self, dtype, expected_type, op):
+    t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
+    t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
+    t3 = op(t1, t2)
+    self.assertEqual(t3.dtype, dtype)
 
     hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
-    device_data_hlo = hlo_text.split('\n')[1]
-    assert 'xla::device_data' in device_data_hlo, device_data_hlo
-    if check_env_flag('XLA_USE_BF16'):
-      assert 'bf16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_USE_FP16'):
-      assert 'f16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
-        'XLA_DOWNCAST_FP16'):
-      assert 'f32' in device_data_hlo, device_data_hlo
-    else:
-      assert 'f64' in device_data_hlo, device_data_hlo
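+    # Line 2 of the dumped graph text is expected to hold an input
+    # parameter (an `xla::device_data` node) whose element type reflects
+    # the active flags.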
+    device_data_hlo = hlo_text.split('\n')[2]
+    self.assertIn('xla::device_data', device_data_hlo)
+    self.assertIn(expected_type, device_data_hlo)
+
+  def test_datatype_use_bf16(self):
+    self._set_env(XLA_USE_BF16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'bf16', torch.floor_divide)
+
+  def test_datatype_downcast_bf16(self):
+    self._set_env(XLA_DOWNCAST_BF16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'bf16', torch.floor_divide)
+
+  def test_datatype_use_32bit_long(self):
+    self._set_env(XLA_USE_32BIT_LONG='1')
+    self._test_datatype(torch.int64, 's32', torch.add)
+    self._test_datatype(torch.uint64, 'u32', torch.add)
 
   def test_module_to_dtype(self):
     device = torch_xla.device()
     linear = torch.nn.Linear(
         5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
-    input = torch.randn(
-        10,
-        5,
-    ).to(device).to(torch.bfloat16)
+    input = torch.randn(10, 5).to(device).to(torch.bfloat16)
     xm.mark_step()
     res = linear(input)
 
     hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
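+    # The dump ends with a closing brace and a trailing newline, so index
+    # -3 should be the last IR node, i.e. the output of the linear layer.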
     res_hlo = hlo_text.split('\n')[-3]
-    assert 'bf16' in res_hlo, res_hlo
+    self.assertIn('bf16', res_hlo)
 
     linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
                                                               ]).split('\n')[-3]
-    assert 'bf16' in linear_weight_hlo, linear_weight_hlo
+    self.assertIn('bf16', linear_weight_hlo)
 
 
 if __name__ == '__main__':
-  test = unittest.main()
-  sys.exit(0 if test.result.wasSuccessful() else 1)
+  # Run the tests explicitly, in a fixed order and failing fast, so the
+  # process exit code reflects the overall result.
+  suite = unittest.TestSuite()
+  suite.addTest(XlaDataTypeTest("test_datatype_use_bf16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long"))
+  suite.addTest(XlaDataTypeTest("test_module_to_dtype"))
+  runner = unittest.TextTestRunner(failfast=True)
+  result = runner.run(suite)
+  sys.exit(0 if result.wasSuccessful() else 1)
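
For reference, a minimal sketch of what the bf16 override does at runtime,
assuming torch_xla is installed and an XLA device is available (behavior as
exercised by the tests above):

    import os
    os.environ['XLA_USE_BF16'] = '1'  # set before any XLA tensor is created

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    t = torch.tensor([2.0, 3.0], device=xm.xla_device())
    print(t.dtype)  # torch.float32: the torch-level dtype is unchanged
    # The device-side element type in the dumped graph is bf16:
    print(torch_xla._XLAC._get_xla_tensors_text([t]))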