Skip to content

Commit a9d25dc

Browse files
authored
Add missing tensor data types (unsigned int 16, 32, 64) to PopulateTensorBuffer (#9090)
1 parent 01b1170 commit a9d25dc

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

test/cpp/test_tensor.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ TEST_F(TensorTest, TestConversions) {
6262
at::TensorOptions(at::kInt));
6363
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Long));
6464
}
65+
{
66+
at::Tensor a = at::randint(std::numeric_limits<uint16_t>::min(),
67+
std::numeric_limits<uint16_t>::max(), {2, 2},
68+
at::TensorOptions(at::kUInt16));
69+
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::UInt16));
70+
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::UInt32));
71+
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::UInt64));
72+
}
73+
{
74+
at::Tensor a = at::randint(std::numeric_limits<uint32_t>::min(),
75+
std::numeric_limits<uint32_t>::max(), {2, 2},
76+
at::TensorOptions(at::kUInt32));
77+
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::UInt32));
78+
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::UInt64));
79+
}
80+
{
81+
// The range of uint64_t is too large for randint to generate.
82+
at::Tensor a = at::randint(std::numeric_limits<uint32_t>::min(),
83+
std::numeric_limits<uint32_t>::max(), {2, 2},
84+
at::TensorOptions(at::kUInt64));
85+
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::UInt64));
86+
}
6587
{
6688
at::Tensor a = at::randint(0, 1, {2, 2}, at::TensorOptions(at::kByte));
6789
EXPECT_TRUE(CheckBidirectionalConversion(a, at::ScalarType::Byte,

torch_xla/csrc/tensor_util.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,18 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
701701
TensorToBufferSType<c10::complex<double>>(tensor, dest_shape, dest_buffer,
702702
dest_buffer_size, device);
703703
break;
704+
case at::ScalarType::UInt16:
705+
TensorToBufferSType<uint16_t>(tensor, dest_shape, dest_buffer,
706+
dest_buffer_size, device);
707+
break;
708+
case at::ScalarType::UInt32:
709+
TensorToBufferSType<uint32_t>(tensor, dest_shape, dest_buffer,
710+
dest_buffer_size, device);
711+
break;
712+
case at::ScalarType::UInt64:
713+
TensorToBufferSType<uint64_t>(tensor, dest_shape, dest_buffer,
714+
dest_buffer_size, device);
715+
break;
704716
default:
705717
XLA_ERROR() << "Tensor type not supported: " << tensor.type();
706718
}

0 commit comments

Comments
 (0)