Skip to content

Commit 78fa757

Browse files
authored
Re-introduce "XLA_USE_32BIT_LONG" flag (#8571)
1 parent 18efb6d commit 78fa757

File tree

7 files changed

+75
-57
lines changed

7 files changed

+75
-57
lines changed

test/neuron/run_tests.sh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,7 @@ function run_xla_op_tests1 {
160160
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
161161
run_test "$CDIR/dynamo/test_dynamo_config.py"
162162
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
163-
#run_test "$CDIR/test_data_type.py"
164-
run_use_bf16 "$CDIR/test_data_type.py"
165-
run_downcast_bf16 "$CDIR/test_data_type.py"
163+
run_test "$CDIR/test_data_type.py"
166164
#run_test "$CDIR/test_fp8.py"
167165
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
168166
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"

test/neuron/test_neuron_data_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def test_datatypes(self):
2727
(torch.double, "f32", torch.floor_divide),
2828
(torch.int16, "s32", torch.add),
2929
(torch.int32, "s32", torch.add),
30-
(torch.int64, "s32", torch.add),
30+
(torch.int64, "s64", torch.add),
3131
(torch.uint16, "u32", torch.add),
3232
(torch.uint32, "u32", torch.add),
33-
(torch.uint64, "u32", torch.add)]
33+
(torch.uint64, "u64", torch.add)]
3434

3535
for dtype, op_xla_dtype, op in test_cases:
3636
with self.subTest(dtype=dtype, op_xla_dtype=op_xla_dtype, op=op):

test/run_tests.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,6 @@ function run_xla_op_tests1 {
178178
run_test "$CDIR/dynamo/test_dynamo_config.py"
179179
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
180180
run_test "$CDIR/test_data_type.py"
181-
run_use_bf16 "$CDIR/test_data_type.py"
182-
run_downcast_bf16 "$CDIR/test_data_type.py"
183181
run_test "$CDIR/test_fp8.py"
184182
run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"
185183
run_xla_hlo_debug run_test "$CDIR/test_env_var_mapper.py"

test/test_data_type.py

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,82 @@
11
import os
2+
import sys
3+
import unittest
24

35
import torch
46
import torch_xla
57
import torch_xla.core.xla_model as xm
68
import torch_xla.utils.utils as xu
7-
import unittest
89

910

10-
def check_env_flag(name, default=''):
11-
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
11+
class XlaDataTypeTest(unittest.TestCase):
1212

13+
def setUp(cls):
14+
cls.original_env = {
15+
'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'),
16+
'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'),
17+
'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG')
18+
}
1319

14-
class XlaDataTypeTest(unittest.TestCase):
20+
def tearDown(self):
21+
for key, value in self.original_env.items():
22+
if value is None:
23+
os.environ.pop(key, None)
24+
else:
25+
os.environ[key] = value
1526

16-
def test_datatype_f32(self):
17-
t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
18-
t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
19-
t3 = torch.div(t1, t2, rounding_mode='floor')
20-
assert t3.dtype == torch.float
27+
def _set_env(self, **kwargs):
28+
for key, value in kwargs.items():
29+
os.environ[key] = value
2130

22-
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
23-
device_data_hlo = hlo_text.split('\n')[1]
24-
assert 'xla::device_data' in device_data_hlo, device_data_hlo
25-
if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
26-
assert 'bf16' in device_data_hlo, device_data_hlo
27-
elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
28-
assert 'f16' in device_data_hlo, device_data_hlo
29-
else:
30-
assert 'f32' in device_data_hlo, device_data_hlo
31-
32-
def test_datatype_f64(self):
33-
t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
34-
t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
35-
t3 = torch.div(t1, t2, rounding_mode='floor')
36-
assert t3.dtype == torch.double
31+
def _test_datatype(self, dtype, expected_type, op):
32+
t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
33+
t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
34+
t3 = op(t1, t2)
35+
self.assertEqual(t3.dtype, dtype)
3736

3837
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
39-
device_data_hlo = hlo_text.split('\n')[1]
40-
assert 'xla::device_data' in device_data_hlo, device_data_hlo
41-
if check_env_flag('XLA_USE_BF16'):
42-
assert 'bf16' in device_data_hlo, device_data_hlo
43-
elif check_env_flag('XLA_USE_FP16'):
44-
assert 'f16' in device_data_hlo, device_data_hlo
45-
elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
46-
'XLA_DOWNCAST_FP16'):
47-
assert 'f32' in device_data_hlo, device_data_hlo
48-
else:
49-
assert 'f64' in device_data_hlo, device_data_hlo
38+
device_data_hlo = hlo_text.split('\n')[2]
39+
self.assertIn('xla::device_data', device_data_hlo)
40+
self.assertIn(expected_type, device_data_hlo)
41+
42+
def test_datatype_use_bf16(self):
43+
self._set_env(XLA_USE_BF16='1')
44+
self._test_datatype(torch.double, 'bf16', torch.floor_divide)
45+
self._test_datatype(torch.float, 'bf16', torch.floor_divide)
46+
47+
def test_datatype_downcast_bf16(self):
48+
self._set_env(XLA_DOWNCAST_BF16='1')
49+
self._test_datatype(torch.double, 'bf16', torch.floor_divide)
50+
self._test_datatype(torch.float, 'bf16', torch.floor_divide)
51+
52+
def test_datatype_use_32bit_long(self):
53+
self._set_env(XLA_USE_32BIT_LONG='1')
54+
self._test_datatype(torch.int64, 's32', torch.add)
55+
self._test_datatype(torch.uint64, 'u32', torch.add)
5056

5157
def test_module_to_dtype(self):
5258
device = torch_xla.device()
5359
linear = torch.nn.Linear(
5460
5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
55-
input = torch.randn(
56-
10,
57-
5,
58-
).to(device).to(torch.bfloat16)
61+
input = torch.randn(10, 5).to(device).to(torch.bfloat16)
5962
xm.mark_step()
6063
res = linear(input)
6164

6265
hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
6366
res_hlo = hlo_text.split('\n')[-3]
64-
assert 'bf16' in res_hlo, res_hlo
67+
self.assertIn('bf16', res_hlo)
6568

6669
linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
6770
]).split('\n')[-3]
68-
assert 'bf16' in linear_weight_hlo, linear_weight_hlo
71+
self.assertIn('bf16', linear_weight_hlo)
6972

7073

7174
if __name__ == '__main__':
72-
test = unittest.main()
73-
sys.exit(0 if test.result.wasSuccessful() else 1)
75+
suite = unittest.TestSuite()
76+
suite.addTest(XlaDataTypeTest("test_datatype_use_bf16"))
77+
suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16"))
78+
suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long"))
79+
suite.addTest(XlaDataTypeTest("test_module_to_dtype"))
80+
runner = unittest.TextTestRunner(failfast=True)
81+
result = runner.run(suite)
82+
sys.exit(0 if result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xl
4747
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
4848
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
4949
run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"
50+
python3 "$TEST_CDIR/test_data_type.py"
5051

5152
# run examples, each test should take <2 minutes
5253
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"

torch_xla/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,7 @@ def _setup_tpu_vm_library_path() -> bool:
178178

179179

180180
def _check_deprecated_env_var():
181-
deprecated_env_vars = [
182-
'XLA_USE_FP16', 'XLA_DOWNCAST_FP16', 'XLA_USE_32BIT_LONG'
183-
]
181+
deprecated_env_vars = ['XLA_USE_FP16', 'XLA_DOWNCAST_FP16']
184182
for env_var in deprecated_env_vars:
185183
if os.environ.get(env_var):
186184
warnings.warn(f"The environment variable '{env_var}' is deprecated "

torch_xla/csrc/dtype.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ bool ShouldDowncastToBF16() {
3030
return downcast_bf16;
3131
}
3232

33+
bool ShouldUse32BitLong() {
34+
bool use_32bit_long =
35+
runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
36+
if (use_32bit_long) {
37+
std::cout
38+
<< "XLA_USE_32BIT_LONG will be deprecated after the 2.6 release\n";
39+
TF_LOG(INFO) << "Using 32bit integers for kLong values";
40+
}
41+
return use_32bit_long;
42+
}
43+
3344
bool UseBF16() {
3445
static bool use_bf16 = ShouldUseBF16();
3546
return use_bf16;
@@ -40,6 +51,11 @@ bool DowncastBF16() {
4051
return downcast_bf16;
4152
}
4253

54+
bool Use32BitLong() {
55+
static bool use_32bit_long = ShouldUse32BitLong();
56+
return use_32bit_long;
57+
}
58+
4359
} // namespace
4460

4561
at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -143,11 +159,9 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
143159
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
144160
: xla::PrimitiveType::S16;
145161
case xla::PrimitiveType::S64:
146-
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
147-
: xla::PrimitiveType::S64;
162+
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
148163
case xla::PrimitiveType::U64:
149-
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
150-
: xla::PrimitiveType::U64;
164+
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
151165
case xla::PrimitiveType::C128:
152166
return xla::PrimitiveType::C128;
153167
default:

0 commit comments

Comments
 (0)