@@ -116,18 +116,32 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
         rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
     // use from_blob to construct rowwise_data_list and rowwise_scale_list
     for (int i = 0; i < num_splits; i++) {
-      rowwise_data_list.emplace_back(at::from_blob(
-          rowwise_data_ptr,
-          {static_cast<int64_t>(rowwise_data_shapes[i].first),
-           static_cast<int64_t>(rowwise_data_shapes[i].second)},
-          [rowwise_full_tensor_holder](void *) {}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-      rowwise_scale_list.emplace_back(at::from_blob(
-          rowwise_scale_ptr,
-          {static_cast<int64_t>(rowwise_scale_shapes[i].first),
-           static_cast<int64_t>(rowwise_scale_shapes[i].second)},
-          [rowwise_full_tensor_holder](void *) {}, at::device(at::kCUDA).dtype(torch::kFloat32)));
-      rowwise_data_ptr += rowwise_data_sizes[i];
-      rowwise_scale_ptr += rowwise_scale_sizes[i];
+      if (rowwise_data_sizes[i] == 0) {
+        NVTE_CHECK(rowwise_scale_sizes[i] == 0,
+                   "Rowwise scale size is not 0 when rowwise data size is 0");
+        rowwise_data_list.emplace_back(
+            at::empty({static_cast<int64_t>(rowwise_data_shapes[i].first),
+                       static_cast<int64_t>(rowwise_data_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kUInt8)));
+        rowwise_scale_list.emplace_back(
+            at::empty({static_cast<int64_t>(rowwise_scale_shapes[i].first),
+                       static_cast<int64_t>(rowwise_scale_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kFloat32)));
+      } else {
+        rowwise_data_list.emplace_back(at::from_blob(
+            rowwise_data_ptr,
+            {static_cast<int64_t>(rowwise_data_shapes[i].first),
+             static_cast<int64_t>(rowwise_data_shapes[i].second)},
+            [rowwise_full_tensor_holder](void *) {}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+        rowwise_scale_list.emplace_back(at::from_blob(
+            rowwise_scale_ptr,
+            {static_cast<int64_t>(rowwise_scale_shapes[i].first),
+             static_cast<int64_t>(rowwise_scale_shapes[i].second)},
+            [rowwise_full_tensor_holder](void *) {},
+            at::device(at::kCUDA).dtype(torch::kFloat32)));
+        rowwise_data_ptr += rowwise_data_sizes[i];
+        rowwise_scale_ptr += rowwise_scale_sizes[i];
+      }
     }
   }

@@ -139,19 +153,33 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
     uint8_t *columnwise_scale_ptr =
         columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
     for (int i = 0; i < num_splits; i++) {
-      columnwise_data_list.emplace_back(at::from_blob(
-          columnwise_data_ptr,
-          {static_cast<int64_t>(columnwise_data_shapes[i].first),
-           static_cast<int64_t>(columnwise_data_shapes[i].second)},
-          [columnwise_full_tensor_holder](void *) {}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-      columnwise_scale_list.emplace_back(at::from_blob(
-          columnwise_scale_ptr,
-          {static_cast<int64_t>(columnwise_scale_shapes[i].first),
-           static_cast<int64_t>(columnwise_scale_shapes[i].second)},
-          [columnwise_full_tensor_holder](void *) {},
-          at::device(at::kCUDA).dtype(torch::kFloat32)));
-      columnwise_data_ptr += columnwise_data_sizes[i];
-      columnwise_scale_ptr += columnwise_scale_sizes[i];
+      if (columnwise_data_sizes[i] == 0) {
+        NVTE_CHECK(columnwise_scale_sizes[i] == 0,
+                   "Columnwise scale size is not 0 when columnwise data size is 0");
+        columnwise_data_list.emplace_back(
+            at::empty({static_cast<int64_t>(columnwise_data_shapes[i].first),
+                       static_cast<int64_t>(columnwise_data_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kUInt8)));
+        columnwise_scale_list.emplace_back(
+            at::empty({static_cast<int64_t>(columnwise_scale_shapes[i].first),
+                       static_cast<int64_t>(columnwise_scale_shapes[i].second)},
+                      at::device(at::kCUDA).dtype(torch::kFloat32)));
+      } else {
+        columnwise_data_list.emplace_back(at::from_blob(
+            columnwise_data_ptr,
+            {static_cast<int64_t>(columnwise_data_shapes[i].first),
+             static_cast<int64_t>(columnwise_data_shapes[i].second)},
+            [columnwise_full_tensor_holder](void *) {},
+            at::device(at::kCUDA).dtype(torch::kUInt8)));
+        columnwise_scale_list.emplace_back(at::from_blob(
+            columnwise_scale_ptr,
+            {static_cast<int64_t>(columnwise_scale_shapes[i].first),
+             static_cast<int64_t>(columnwise_scale_shapes[i].second)},
+            [columnwise_full_tensor_holder](void *) {},
+            at::device(at::kCUDA).dtype(torch::kFloat32)));
+        columnwise_data_ptr += columnwise_data_sizes[i];
+        columnwise_scale_ptr += columnwise_scale_sizes[i];
+      }
    }
  }

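For context, here is a minimal self-contained sketch of the bulk-allocation pattern this change hardens. It is not the Transformer Engine code: the helper name `make_split_views`, the CPU device, and the flat uint8 layout are assumptions made for a small runnable example. The idea mirrors the diff above: per-split views are carved out of one bulk buffer with `at::from_blob`, kept alive by a no-op deleter capturing a shared holder, and a zero-sized split falls back to a standalone `at::empty` tensor instead of aliasing the bulk buffer.

```cpp
// Minimal sketch of the bulk-allocation pattern (NOT the Transformer Engine
// implementation). Names, device, and layout are illustrative assumptions.
#include <torch/torch.h>

#include <memory>
#include <vector>

std::vector<at::Tensor> make_split_views(const std::vector<int64_t> &rows, int64_t cols) {
  // One bulk buffer that backs every non-empty split.
  int64_t total = 0;
  for (int64_t r : rows) total += r * cols;
  at::Tensor full = at::empty({total}, at::device(at::kCPU).dtype(torch::kUInt8));

  // The holder keeps the bulk buffer alive for as long as any view exists.
  auto holder = std::make_shared<at::Tensor>(full);

  std::vector<at::Tensor> views;
  uint8_t *ptr = full.data_ptr<uint8_t>();
  for (int64_t r : rows) {
    if (r * cols == 0) {
      // Zero-sized split: return an independent empty tensor instead of
      // aliasing the bulk buffer (mirrors the new branch in this diff).
      views.emplace_back(at::empty({r, cols}, at::device(at::kCPU).dtype(torch::kUInt8)));
    } else {
      // Non-owning view into the bulk buffer; the deleter only drops the
      // holder reference, it never frees the underlying storage itself.
      views.emplace_back(at::from_blob(
          ptr, {r, cols}, [holder](void *) {}, at::device(at::kCPU).dtype(torch::kUInt8)));
      ptr += r * cols;
    }
  }
  return views;
}
```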