
Commit 718c9ca

attempt to handle case where experts get zero tokens

Signed-off-by: zhongboz <[email protected]>
1 parent: da41b4e

File tree

2 files changed: +53 -25 lines

qa/format.sh

File mode changed: 100644 → 100755

transformer_engine/pytorch/csrc/extensions/transpose.cpp

Lines changed: 53 additions & 25 deletions
@@ -116,18 +116,32 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
         rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
     // use from_blob to construct rowwise_data_list and rowwise_scale_list
     for (int i = 0; i < num_splits; i++) {
-      rowwise_data_list.emplace_back(at::from_blob(
-          rowwise_data_ptr,
-          {static_cast<int64_t>(rowwise_data_shapes[i].first),
-           static_cast<int64_t>(rowwise_data_shapes[i].second)},
-          [rowwise_full_tensor_holder](void*) {}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-      rowwise_scale_list.emplace_back(at::from_blob(
-          rowwise_scale_ptr,
-          {static_cast<int64_t>(rowwise_scale_shapes[i].first),
-           static_cast<int64_t>(rowwise_scale_shapes[i].second)},
-          [rowwise_full_tensor_holder](void*) {}, at::device(at::kCUDA).dtype(torch::kFloat32)));
-      rowwise_data_ptr += rowwise_data_sizes[i];
-      rowwise_scale_ptr += rowwise_scale_sizes[i];
+      if (rowwise_data_sizes[i] == 0) {
+        NVTE_CHECK(rowwise_scale_sizes[i] == 0,
+                   "Rowwise scale size is not 0 when rowwise data size is 0");
+        rowwise_data_list.emplace_back(
+            at::empty({static_cast<int64_t>(rowwise_data_shapes[i].first),
+                       static_cast<int64_t>(rowwise_data_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kUInt8)));
+        rowwise_scale_list.emplace_back(
+            at::empty({static_cast<int64_t>(rowwise_scale_shapes[i].first),
+                       static_cast<int64_t>(rowwise_scale_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kFloat32)));
+      } else {
+        rowwise_data_list.emplace_back(at::from_blob(
+            rowwise_data_ptr,
+            {static_cast<int64_t>(rowwise_data_shapes[i].first),
+             static_cast<int64_t>(rowwise_data_shapes[i].second)},
+            [rowwise_full_tensor_holder](void*) {}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+        rowwise_scale_list.emplace_back(at::from_blob(
+            rowwise_scale_ptr,
+            {static_cast<int64_t>(rowwise_scale_shapes[i].first),
+             static_cast<int64_t>(rowwise_scale_shapes[i].second)},
+            [rowwise_full_tensor_holder](void*) {},
+            at::device(at::kCUDA).dtype(torch::kFloat32)));
+        rowwise_data_ptr += rowwise_data_sizes[i];
+        rowwise_scale_ptr += rowwise_scale_sizes[i];
+      }
     }
   }

@@ -139,19 +153,33 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
     uint8_t* columnwise_scale_ptr =
         columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
     for (int i = 0; i < num_splits; i++) {
-      columnwise_data_list.emplace_back(at::from_blob(
-          columnwise_data_ptr,
-          {static_cast<int64_t>(columnwise_data_shapes[i].first),
-           static_cast<int64_t>(columnwise_data_shapes[i].second)},
-          [columnwise_full_tensor_holder](void*) {}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-      columnwise_scale_list.emplace_back(at::from_blob(
-          columnwise_scale_ptr,
-          {static_cast<int64_t>(columnwise_scale_shapes[i].first),
-           static_cast<int64_t>(columnwise_scale_shapes[i].second)},
-          [columnwise_full_tensor_holder](void*) {},
-          at::device(at::kCUDA).dtype(torch::kFloat32)));
-      columnwise_data_ptr += columnwise_data_sizes[i];
-      columnwise_scale_ptr += columnwise_scale_sizes[i];
+      if (columnwise_data_sizes[i] == 0) {
+        NVTE_CHECK(columnwise_scale_sizes[i] == 0,
+                   "Columnwise scale size is not 0 when columnwise data size is 0");
+        columnwise_data_list.emplace_back(
+            at::empty({static_cast<int64_t>(columnwise_data_shapes[i].first),
+                       static_cast<int64_t>(columnwise_data_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kUInt8)));
+        columnwise_scale_list.emplace_back(
+            at::empty({static_cast<int64_t>(columnwise_scale_shapes[i].first),
+                       static_cast<int64_t>(columnwise_scale_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kFloat32)));
+      } else {
+        columnwise_data_list.emplace_back(at::from_blob(
+            columnwise_data_ptr,
+            {static_cast<int64_t>(columnwise_data_shapes[i].first),
+             static_cast<int64_t>(columnwise_data_shapes[i].second)},
+            [columnwise_full_tensor_holder](void*) {},
+            at::device(at::kCUDA).dtype(torch::kUInt8)));
+        columnwise_scale_list.emplace_back(at::from_blob(
+            columnwise_scale_ptr,
+            {static_cast<int64_t>(columnwise_scale_shapes[i].first),
+             static_cast<int64_t>(columnwise_scale_shapes[i].second)},
+            [columnwise_full_tensor_holder](void*) {},
+            at::device(at::kCUDA).dtype(torch::kFloat32)));
+        columnwise_data_ptr += columnwise_data_sizes[i];
+        columnwise_scale_ptr += columnwise_scale_sizes[i];
+      }
     }
   }

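Note on the pattern in both hunks: when an MoE expert is routed zero tokens, its split in the bulk allocation has zero bytes, so the change allocates a standalone empty tensor with at::empty instead of creating a from_blob view at the current offset, and leaves the running pointers unchanged. The sketch below is a hypothetical, minimal illustration of that pattern, not the commit's code: the names make_split_views and split_sizes are invented, it uses CPU tensors and a flat 1-D byte layout for portability, and it omits the holder capture that the real fused_bulk_alloc_outputs uses to tie the views' lifetime to the bulk CUDA tensor.

// Hypothetical sketch only -- not part of the commit. Empty splits get a fresh
// at::empty tensor; non-empty splits remain from_blob views into the bulk buffer.
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<at::Tensor> make_split_views(at::Tensor bulk,
                                         const std::vector<int64_t>& split_sizes) {
  std::vector<at::Tensor> views;
  uint8_t* ptr = bulk.data_ptr<uint8_t>();
  for (int64_t size : split_sizes) {
    if (size == 0) {
      // Zero-token split: it owns no bytes of the bulk buffer, so allocate a
      // standalone empty tensor instead of aliasing the current offset.
      views.push_back(at::empty({0}, at::device(at::kCPU).dtype(torch::kUInt8)));
    } else {
      // Non-empty split: keep it as a view into the shared allocation. The no-op
      // deleter means `bulk` must outlive these views here; the real code instead
      // captures a holder of the bulk tensor in the deleter to guarantee that.
      views.push_back(at::from_blob(ptr, {size}, [](void*) {},
                                    at::device(at::kCPU).dtype(torch::kUInt8)));
      ptr += size;
    }
  }
  return views;
}

int main() {
  at::Tensor bulk = at::zeros({16}, at::device(at::kCPU).dtype(torch::kUInt8));
  // Toy routing: the middle "expert" receives zero tokens.
  for (const at::Tensor& v : make_split_views(bulk, {8, 0, 8})) {
    std::cout << "numel = " << v.numel() << "\n";
  }
  return 0;
}

Allocating a fresh tensor for the zero-sized split avoids handing out a view into bytes the split does not own and keeps the running-pointer arithmetic unaffected when a size is zero.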