 * See LICENSE for license information.
 ************************************************************************/

-#include <optional>
 #include <pybind.h>

+#include <iostream>
+#include <optional>
+
 #include "extensions.h"
 #include "pybind.h"

-#include <iostream>
-
 namespace transformer_engine::pytorch {

-std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits,
-                                                 std::vector<py::handle> quantizer_list) {
+std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits,
+                                                 std::vector<py::handle> quantizer_list) {
   init_extension();
   using namespace pybind11::literals;  // For operator""_a

@@ -58,25 +58,38 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
   std::vector<size_t> columnwise_data_sizes;
   std::vector<size_t> columnwise_scale_sizes;
   for (int i = 0; i < num_splits; i++) {
-    std::pair<size_t, size_t> input_view_i_shape = std::make_pair((size_t)m_splits[i], (size_t)hidden_dim);
+    std::pair<size_t, size_t> input_view_i_shape =
+        std::make_pair((size_t)m_splits[i], (size_t)hidden_dim);
     if (rowwise_usage) {
       rowwise_data_shapes.emplace_back(input_view_i_shape);
-      rowwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, false));
-      rowwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size);
-      rowwise_scale_sizes.emplace_back(rowwise_scale_shapes.back().first * rowwise_scale_shapes.back().second * scale_elem_size);
+      rowwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape(
+          {input_view_i_shape.first, input_view_i_shape.second}, false));
+      rowwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second *
+                                      fp8_elem_size);
+      rowwise_scale_sizes.emplace_back(rowwise_scale_shapes.back().first *
+                                       rowwise_scale_shapes.back().second * scale_elem_size);
     }
     if (columnwise_usage) {
-      columnwise_data_shapes.emplace_back(std::make_pair(input_view_i_shape.second, input_view_i_shape.first));
-      columnwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, true));
-      columnwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size);
-      columnwise_scale_sizes.emplace_back(columnwise_scale_shapes.back().first * columnwise_scale_shapes.back().second * scale_elem_size);
+      columnwise_data_shapes.emplace_back(
+          std::make_pair(input_view_i_shape.second, input_view_i_shape.first));
+      columnwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape(
+          {input_view_i_shape.first, input_view_i_shape.second}, true));
+      columnwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second *
+                                         fp8_elem_size);
+      columnwise_scale_sizes.emplace_back(columnwise_scale_shapes.back().first *
+                                          columnwise_scale_shapes.back().second *
+                                          scale_elem_size);
     }
   }

-  size_t total_size_rowwise_data = std::accumulate(rowwise_data_sizes.begin(), rowwise_data_sizes.end(), 0);
-  size_t total_size_rowwise_scale = std::accumulate(rowwise_scale_sizes.begin(), rowwise_scale_sizes.end(), 0);
-  size_t total_size_columnwise_data = std::accumulate(columnwise_data_sizes.begin(), columnwise_data_sizes.end(), 0);
-  size_t total_size_columnwise_scale = std::accumulate(columnwise_scale_sizes.begin(), columnwise_scale_sizes.end(), 0);
+  size_t total_size_rowwise_data =
+      std::accumulate(rowwise_data_sizes.begin(), rowwise_data_sizes.end(), 0);
+  size_t total_size_rowwise_scale =
+      std::accumulate(rowwise_scale_sizes.begin(), rowwise_scale_sizes.end(), 0);
+  size_t total_size_columnwise_data =
+      std::accumulate(columnwise_data_sizes.begin(), columnwise_data_sizes.end(), 0);
+  size_t total_size_columnwise_scale =
+      std::accumulate(columnwise_scale_sizes.begin(), columnwise_scale_sizes.end(), 0);

   size_t total_size_rowwise = total_size_rowwise_data + total_size_rowwise_scale;
   size_t total_size_columnwise = total_size_columnwise_data + total_size_columnwise_scale;
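A caveat on the `std::accumulate` calls in the hunk above (an editorial note, not part of the commit): `std::accumulate` deduces its accumulator type from the initial value, so the literal `0` sums these `size_t` byte counts in `int`, which can overflow once a bulk allocation approaches 2 GiB. A minimal sketch of the safer form, with made-up sizes:

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-split byte counts; three 1 GiB splits exceed INT_MAX in total.
  std::vector<std::size_t> rowwise_data_sizes = {1u << 30, 1u << 30, 1u << 30};

  // Risky: the int literal 0 makes the accumulator an int, so the sum overflows
  // (undefined behavior) even though the result is stored in a size_t.
  // std::size_t total = std::accumulate(rowwise_data_sizes.begin(),
  //                                     rowwise_data_sizes.end(), 0);

  // Safe: a size_t initial value keeps the whole accumulation in size_t.
  std::size_t total = std::accumulate(rowwise_data_sizes.begin(),
                                      rowwise_data_sizes.end(), std::size_t{0});
  return total == 3 * (std::size_t{1} << 30) ? 0 : 1;
}
```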
@@ -90,49 +103,67 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
   at::Tensor columnwise_full_tensor;

   if (rowwise_usage) {
-    rowwise_full_tensor = at::empty({(int64_t)total_size_rowwise}, at::device(input_view.device()).dtype(torch::kUInt8));
+    rowwise_full_tensor = at::empty({(int64_t)total_size_rowwise},
+                                    at::device(input_view.device()).dtype(torch::kUInt8));
     // use raw pointer math + from blob, avoid torch slice to reduce cpu overhead
     uint8_t* rowwise_data_ptr = rowwise_full_tensor.data_ptr<uint8_t>();
-    uint8_t* rowwise_scale_ptr = rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
+    uint8_t* rowwise_scale_ptr =
+        rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
     // use from_blob to construct rowwise_data_list and rowwise_scale_list
     for (int i = 0; i < num_splits; i++) {
-      rowwise_data_list.emplace_back(at::from_blob(rowwise_data_ptr, {static_cast<int64_t>(rowwise_data_shapes[i].first), static_cast<int64_t>(rowwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8)));
-      rowwise_scale_list.emplace_back(at::from_blob(rowwise_scale_ptr, {static_cast<int64_t>(rowwise_scale_shapes[i].first), static_cast<int64_t>(rowwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32)));
+      rowwise_data_list.emplace_back(
+          at::from_blob(rowwise_data_ptr,
+                        {static_cast<int64_t>(rowwise_data_shapes[i].first),
+                         static_cast<int64_t>(rowwise_data_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kUInt8)));
+      rowwise_scale_list.emplace_back(
+          at::from_blob(rowwise_scale_ptr,
+                        {static_cast<int64_t>(rowwise_scale_shapes[i].first),
+                         static_cast<int64_t>(rowwise_scale_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kFloat32)));
       rowwise_data_ptr += rowwise_data_sizes[i];
       rowwise_scale_ptr += rowwise_scale_sizes[i];
     }
   }

   if (columnwise_usage) {
-    columnwise_full_tensor = at::empty({(int64_t)total_size_columnwise}, at::device(input_view.device()).dtype(torch::kUInt8));
+    columnwise_full_tensor = at::empty({(int64_t)total_size_columnwise},
+                                       at::device(input_view.device()).dtype(torch::kUInt8));
     uint8_t* columnwise_data_ptr = columnwise_full_tensor.data_ptr<uint8_t>();
-    uint8_t* columnwise_scale_ptr = columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
+    uint8_t* columnwise_scale_ptr =
+        columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
     for (int i = 0; i < num_splits; i++) {
-      columnwise_data_list.emplace_back(at::from_blob(columnwise_data_ptr, {static_cast<int64_t>(columnwise_data_shapes[i].first), static_cast<int64_t>(columnwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8)));
-      columnwise_scale_list.emplace_back(at::from_blob(columnwise_scale_ptr, {static_cast<int64_t>(columnwise_scale_shapes[i].first), static_cast<int64_t>(columnwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32)));
+      columnwise_data_list.emplace_back(
+          at::from_blob(columnwise_data_ptr,
+                        {static_cast<int64_t>(columnwise_data_shapes[i].first),
+                         static_cast<int64_t>(columnwise_data_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kUInt8)));
+      columnwise_scale_list.emplace_back(
+          at::from_blob(columnwise_scale_ptr,
+                        {static_cast<int64_t>(columnwise_scale_shapes[i].first),
+                         static_cast<int64_t>(columnwise_scale_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kFloat32)));
       columnwise_data_ptr += columnwise_data_sizes[i];
       columnwise_scale_ptr += columnwise_scale_sizes[i];
     }
   }
-
-  for (int i = 0; i < num_splits; i++) {

+  for (int i = 0; i < num_splits; i++) {
     py::handle Float8BlockwiseQTensorClass(
-      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
+        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));

     // Create the tensor object with proper reference counting
     py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
-    py::object columnwise_data = columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none();
+    py::object columnwise_data =
+        columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none();
     py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
-    py::object columnwise_scale = columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none();
+    py::object columnwise_scale =
+        columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none();

     py::object ret = Float8BlockwiseQTensorClass(
-        "rowwise_data"_a = rowwise_data,
-        "columnwise_data"_a = columnwise_data,
-        "rowwise_scale_inv"_a = rowwise_scale,
-        "columnwise_scale_inv"_a = columnwise_scale,
-        "fp8_dtype"_a = fp8_dtype,
-        "quantizer"_a = quantizer_list[i],
+        "rowwise_data"_a = rowwise_data, "columnwise_data"_a = columnwise_data,
+        "rowwise_scale_inv"_a = rowwise_scale, "columnwise_scale_inv"_a = columnwise_scale,
+        "fp8_dtype"_a = fp8_dtype, "quantizer"_a = quantizer_list[i],
         "is_2D_scaled"_a = is_2D_scaled);

     output_list.emplace_back(std::move(ret));
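The pattern in this hunk, one bulk `at::empty` allocation with non-owning `at::from_blob` views carved out by raw pointer arithmetic, is what avoids the per-split dispatcher cost of `Tensor::slice` that the in-code comment mentions. A standalone sketch of the idea (the `carve_views` helper and its sizes are illustrative, not part of the commit; note that `from_blob` views do not own their storage, which is why the diff later stores the full tensors in a Python object to keep them alive):

```cpp
#include <torch/torch.h>

#include <cstdint>
#include <utility>
#include <vector>

// Carve non-owning 2-D uint8 views out of one bulk allocation.
std::vector<at::Tensor> carve_views(at::Tensor& pool,
                                    const std::vector<std::pair<int64_t, int64_t>>& shapes) {
  std::vector<at::Tensor> views;
  uint8_t* ptr = pool.data_ptr<uint8_t>();
  for (const auto& [rows, cols] : shapes) {
    // from_blob wraps existing memory; the view is only valid while `pool` lives.
    views.emplace_back(at::from_blob(ptr, {rows, cols},
                                     at::device(pool.device()).dtype(torch::kUInt8)));
    ptr += rows * cols;  // advance by the view's byte size (1 byte per uint8 element)
  }
  return views;
}

int main() {
  auto pool = at::empty({6}, at::device(at::kCPU).dtype(torch::kUInt8));
  auto views = carve_views(pool, {{1, 2}, {2, 2}});  // uses 2 + 4 = 6 of the 6 bytes
  return (views.size() == 2 &&
          views[1].data_ptr<uint8_t>() == pool.data_ptr<uint8_t>() + 2) ? 0 : 1;
}
```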
@@ -143,41 +174,20 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec

   // put the two full tensor into a python class to maintain their life cycle
   py::object ret = Float8BlockwiseQTensorClass(
-        "rowwise_data"_a = rowwise_full_tensor,
-        "columnwise_data"_a = columnwise_full_tensor,
-        "rowwise_scale_inv"_a = py::none(),
-        "columnwise_scale_inv"_a = py::none(),
-        "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = py::none(), "is_2D_scaled"_a = true);
-
+        "rowwise_data"_a = rowwise_full_tensor, "columnwise_data"_a = columnwise_full_tensor,
+        "rowwise_scale_inv"_a = py::none(), "columnwise_scale_inv"_a = py::none(),
+        "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = py::none(),
+        "is_2D_scaled"_a = true);
+
   output_list.emplace_back(std::move(ret));

-  }else {
+  } else {
     NVTE_ERROR("Fused bulk alloc is not supported for this quantizer type");
   }

   return output_list;
 }

-py::object simple_sanity_check(at::Tensor input, py::handle quantizer){
-  init_extension();
-  using namespace pybind11::literals;  // For operator""_a
-  py::handle Float8BlockwiseQTensorClass(
-      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
-
-  py::object ret = Float8BlockwiseQTensorClass(
-      "rowwise_data"_a = input,
-      "columnwise_data"_a = input,
-      "rowwise_scale_inv"_a = input,
-      "columnwise_scale_inv"_a = input,
-      "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = quantizer, "is_2D_scaled"_a = true);
-
-  // py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
-  // py::object ret = Float8TensorClass("data"_a = py::none(), "fp8_scale_inv"_a = py::none(),
-  //     "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "data_transpose"_a = py::none(),
-  //     "quantizer"_a = py::none());
-  return ret;
-}
-
 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
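Throughout the diff, `Float8BlockwiseQTensorClass(...)` constructs the Python tensor object by calling a Python class from C++ with pybind11's keyword-argument literals: `using namespace pybind11::literals;` enables `operator""_a`, and `"name"_a = value` becomes a keyword argument of the call. A minimal self-contained sketch of that mechanism (using Python's built-in `dict` as a stand-in callee; nothing here is Transformer Engine API):

```cpp
#include <pybind11/embed.h>

namespace py = pybind11;

int main() {
  // Start an embedded interpreter so the snippet is runnable on its own.
  py::scoped_interpreter guard{};
  using namespace py::literals;  // enables "name"_a

  // Any callable py::object accepts keyword arguments via "name"_a = value,
  // just like the Float8BlockwiseQTensorClass(...) calls in the diff.
  py::object dict_cls = py::module_::import("builtins").attr("dict");
  py::object d = dict_cls("rowwise_data"_a = py::none(), "is_2D_scaled"_a = true);
  return py::len(d) == 2 ? 0 : 1;
}
```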