
Commit 7116a07 (parent: ef12e27)

C++ bulk alloc worked, still draft version

Signed-off-by: zhongboz <[email protected]>

7 files changed: +353 additions, −100 deletions


transformer_engine/pytorch/csrc/common.h
Lines changed: 2 additions & 0 deletions

@@ -194,6 -194,8 @@ class Float8BlockQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(
       const std::vector<size_t>& shape, DType dtype,
       std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+
+  std::pair<size_t, size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
 };
 
 class MXFP8Quantizer : public Quantizer {

transformer_engine/pytorch/csrc/extensions.h
Lines changed: 5 additions & 0 deletions

@@ -107,6 +107,11 @@ namespace transformer_engine::pytorch {
  * Transpose
  **************************************************************************************************/
 
+std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits,
+                                                 std::vector<py::handle> quantizer_list);
+
+py::object simple_sanity_check(at::Tensor input, py::handle quantizer);
+
 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
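
For context, simple_sanity_check declared above is a temporary debug helper: it takes an arbitrary tensor plus a quantizer handle and wraps them into a Float8BlockwiseQTensor, which makes it easy to confirm that the C++-to-Python tensor construction path works before the real bulk-alloc path is exercised. A minimal, hypothetical smoke test is sketched below; the module name transformer_engine_torch and the pre-built quantizer are assumptions, not part of this diff.

    # Hypothetical smoke test for the debug helper; not part of the commit.
    import torch
    import transformer_engine_torch as tex  # assumed name of the compiled extension

    def sanity_check(quantizer) -> None:
        # `quantizer` is assumed to be a Float8BlockQuantizer instance created elsewhere.
        dummy = torch.zeros(8, 8, dtype=torch.uint8, device="cuda")
        qtensor = tex.simple_sanity_check(dummy, quantizer)
        print(type(qtensor))  # expected: a Float8BlockwiseQTensor wrapping `dummy`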

transformer_engine/pytorch/csrc/extensions/pybind.cpp
Lines changed: 6 additions & 0 deletions

@@ -18,6 +18,8 @@
 #include "../extensions.h"
 #include "common.h"
 
+#include <iostream>
+
 namespace transformer_engine::pytorch {
 
 PyTypeObject *Float8TensorPythonClass = nullptr;  /// TODO Remove
@@ -199,6 +201,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
         py::arg("zero_centered_gamma"));
   m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
+  m.def("fused_bulk_alloc_outputs", &transformer_engine::pytorch::fused_bulk_alloc_outputs, "Fused Bulk Alloc Outputs",
+        py::arg("input_view"), py::arg("m_splits"), py::arg("quantizer_list"));
+  m.def("simple_sanity_check", &transformer_engine::pytorch::simple_sanity_check, "foo",
+        py::arg("input"), py::arg("quantizer"));
   m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
         "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
         py::arg("quantizer_list"), py::arg("otype"));

transformer_engine/pytorch/csrc/extensions/quantizer.cpp
Lines changed: 51 additions & 0 deletions

@@ -387,6 +387,57 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   return {std::move(tensor), std::move(ret)};
 }
 
+std::pair<size_t, size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const {
+  using namespace pybind11::literals;
+  std::vector<int64_t> torch_shape;
+  size_t numel = 1;
+  for (auto s : shape) {
+    torch_shape.emplace_back(static_cast<int64_t>(s));
+    numel *= s;
+  }
+
+  size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
+  size_t m_dim = numel / k_dim;
+  constexpr size_t kBlockLen = 128;
+
+  std::pair<size_t, size_t> scale_shape;
+
+  if (columnwise) {
+    size_t sinv0 = 0;
+    size_t sinv1 = 0;
+    if (block_scaling_dim == 2) {
+      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
+    } else if (block_scaling_dim == 1) {
+      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = roundup(k_dim, 4);
+    } else {
+      NVTE_CHECK(false,
+                 "Unsupported block_scaling_dim in get_scale_shape columnwise. "
+                 "Expected 1 or 2. Got ",
+                 block_scaling_dim);
+    }
+    scale_shape = {sinv0, sinv1};
+  } else {
+    size_t sinv0 = 0;
+    size_t sinv1 = 0;
+    if (block_scaling_dim == 2) {
+      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
+    } else if (block_scaling_dim == 1) {
+      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = roundup(m_dim, 4);
+    } else {
+      NVTE_CHECK(false,
+                 "Unsupported block_scaling_dim in get_scale_shape rowwise. "
+                 "Expected 1 or 2. Got ",
+                 block_scaling_dim);
+    }
+    scale_shape = {sinv0, sinv1};
+  }
+  return scale_shape;
+}
+
 MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
   this->dtype = quantizer.attr("dtype").cast<DType>();
 }
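
To make the shape arithmetic in get_scale_shape easier to check by hand, here is a small Python transcription of the two branches together with a worked example. The helper names ceil_div and roundup are illustrative; the block length of 128 and the round-up of the trailing dimension to a multiple of 4 follow the C++ above.

    # Python transcription of Float8BlockQuantizer::get_scale_shape (illustrative only).
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def roundup(a: int, b: int) -> int:
        return ceil_div(a, b) * b

    def get_scale_shape(m_dim: int, k_dim: int, block_scaling_dim: int,
                        columnwise: bool, block_len: int = 128):
        # Only block_scaling_dim 1 (1x128 blocks) and 2 (128x128 blocks) are supported.
        if columnwise:
            if block_scaling_dim == 2:
                return ceil_div(k_dim, block_len), roundup(ceil_div(m_dim, block_len), 4)
            return ceil_div(m_dim, block_len), roundup(k_dim, 4)
        if block_scaling_dim == 2:
            return ceil_div(m_dim, block_len), roundup(ceil_div(k_dim, block_len), 4)
        return ceil_div(k_dim, block_len), roundup(m_dim, 4)

    # Worked example: a 1024 x 4096 input with 1D (1x128) scaling.
    assert get_scale_shape(1024, 4096, 1, columnwise=False) == (32, 1024)
    assert get_scale_shape(1024, 4096, 1, columnwise=True) == (8, 4096)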

transformer_engine/pytorch/csrc/extensions/transpose.cpp
Lines changed: 167 additions & 0 deletions

@@ -5,12 +5,179 @@
  ************************************************************************/
 
 #include <optional>
+#include <pybind.h>
 
 #include "extensions.h"
 #include "pybind.h"
 
+#include <iostream>
+
 namespace transformer_engine::pytorch {
 
+std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits,
+                                                 std::vector<py::handle> quantizer_list) {
+  init_extension();
+  using namespace pybind11::literals;  // For operator""_a
+
+  int num_splits = m_splits.size();
+
+  // convert all the quantizers
+  std::vector<std::unique_ptr<Quantizer>> quantizers;
+  for (int i = 0; i < num_splits; i++) {
+    quantizers.push_back(convert_quantizer(quantizer_list[i]));
+  }
+
+  bool rowwise_usage = quantizers[0]->rowwise_usage;
+  bool columnwise_usage = quantizers[0]->columnwise_usage;
+  size_t hidden_dim = input_view.size(1);
+
+  std::vector<py::object> output_list;
+
+  if (detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr())) {
+    // implement the fused bulk alloc for the blockwise quantizer
+    // downcast quantizers; resources are owned by the unique_ptrs, so raw pointers are used here just to read attributes
+    std::vector<Float8BlockQuantizer*> blockwise_quantizers;
+    for (size_t i = 0; i < quantizers.size(); i++) {
+      Quantizer* raw_ptr = quantizers[i].get();
+      Float8BlockQuantizer* blockwise_quantizer = static_cast<Float8BlockQuantizer*>(raw_ptr);
+      blockwise_quantizers.push_back(blockwise_quantizer);
+    }
+
+    bool is_2D_scaled = blockwise_quantizers[0]->get_scaling_mode() == NVTE_BLOCK_SCALING_2D;
+    transformer_engine::DType fp8_dtype = blockwise_quantizers[0]->dtype;
+
+    size_t fp8_elem_size = 1;
+    size_t scale_elem_size = 4;
+
+    std::vector<std::pair<size_t, size_t>> rowwise_data_shapes;
+    std::vector<std::pair<size_t, size_t>> rowwise_scale_shapes;
+    std::vector<size_t> rowwise_data_sizes;
+    std::vector<size_t> rowwise_scale_sizes;
+    std::vector<std::pair<size_t, size_t>> columnwise_data_shapes;
+    std::vector<std::pair<size_t, size_t>> columnwise_scale_shapes;
+    std::vector<size_t> columnwise_data_sizes;
+    std::vector<size_t> columnwise_scale_sizes;
+    for (int i = 0; i < num_splits; i++) {
+      std::pair<size_t, size_t> input_view_i_shape = std::make_pair((size_t)m_splits[i], (size_t)hidden_dim);
+      if (rowwise_usage) {
+        rowwise_data_shapes.emplace_back(input_view_i_shape);
+        rowwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, false));
+        rowwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size);
+        rowwise_scale_sizes.emplace_back(rowwise_scale_shapes.back().first * rowwise_scale_shapes.back().second * scale_elem_size);
+      }
+      if (columnwise_usage) {
+        columnwise_data_shapes.emplace_back(std::make_pair(input_view_i_shape.second, input_view_i_shape.first));
+        columnwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, true));
+        columnwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size);
+        columnwise_scale_sizes.emplace_back(columnwise_scale_shapes.back().first * columnwise_scale_shapes.back().second * scale_elem_size);
+      }
+    }
+
+    size_t total_size_rowwise_data = std::accumulate(rowwise_data_sizes.begin(), rowwise_data_sizes.end(), size_t{0});
+    size_t total_size_rowwise_scale = std::accumulate(rowwise_scale_sizes.begin(), rowwise_scale_sizes.end(), size_t{0});
+    size_t total_size_columnwise_data = std::accumulate(columnwise_data_sizes.begin(), columnwise_data_sizes.end(), size_t{0});
+    size_t total_size_columnwise_scale = std::accumulate(columnwise_scale_sizes.begin(), columnwise_scale_sizes.end(), size_t{0});
+
+    size_t total_size_rowwise = total_size_rowwise_data + total_size_rowwise_scale;
+    size_t total_size_columnwise = total_size_columnwise_data + total_size_columnwise_scale;
+
+    std::vector<at::Tensor> rowwise_data_list;
+    std::vector<at::Tensor> rowwise_scale_list;
+    std::vector<at::Tensor> columnwise_data_list;
+    std::vector<at::Tensor> columnwise_scale_list;
+
+    at::Tensor rowwise_full_tensor;
+    at::Tensor columnwise_full_tensor;
+
+    if (rowwise_usage) {
+      rowwise_full_tensor = at::empty({(int64_t)total_size_rowwise}, at::device(input_view.device()).dtype(torch::kUInt8));
+      // use raw pointer math + from_blob, avoid torch slicing to reduce cpu overhead
+      uint8_t* rowwise_data_ptr = rowwise_full_tensor.data_ptr<uint8_t>();
+      uint8_t* rowwise_scale_ptr = rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
+      // use from_blob to construct rowwise_data_list and rowwise_scale_list
+      for (int i = 0; i < num_splits; i++) {
+        rowwise_data_list.emplace_back(at::from_blob(rowwise_data_ptr, {static_cast<int64_t>(rowwise_data_shapes[i].first), static_cast<int64_t>(rowwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8)));
+        rowwise_scale_list.emplace_back(at::from_blob(rowwise_scale_ptr, {static_cast<int64_t>(rowwise_scale_shapes[i].first), static_cast<int64_t>(rowwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32)));
+        rowwise_data_ptr += rowwise_data_sizes[i];
+        rowwise_scale_ptr += rowwise_scale_sizes[i];
+      }
+    }
+
+    if (columnwise_usage) {
+      columnwise_full_tensor = at::empty({(int64_t)total_size_columnwise}, at::device(input_view.device()).dtype(torch::kUInt8));
+      uint8_t* columnwise_data_ptr = columnwise_full_tensor.data_ptr<uint8_t>();
+      uint8_t* columnwise_scale_ptr = columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
+      for (int i = 0; i < num_splits; i++) {
+        columnwise_data_list.emplace_back(at::from_blob(columnwise_data_ptr, {static_cast<int64_t>(columnwise_data_shapes[i].first), static_cast<int64_t>(columnwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8)));
+        columnwise_scale_list.emplace_back(at::from_blob(columnwise_scale_ptr, {static_cast<int64_t>(columnwise_scale_shapes[i].first), static_cast<int64_t>(columnwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32)));
+        columnwise_data_ptr += columnwise_data_sizes[i];
+        columnwise_scale_ptr += columnwise_scale_sizes[i];
+      }
+    }
+
+    for (int i = 0; i < num_splits; i++) {
+
+      py::handle Float8BlockwiseQTensorClass(
+          reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
+
+      // Create the tensor object with proper reference counting
+      py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
+      py::object columnwise_data = columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none();
+      py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
+      py::object columnwise_scale = columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none();
+
+      py::object ret = Float8BlockwiseQTensorClass(
+          "rowwise_data"_a = rowwise_data,
+          "columnwise_data"_a = columnwise_data,
+          "rowwise_scale_inv"_a = rowwise_scale,
+          "columnwise_scale_inv"_a = columnwise_scale,
+          "fp8_dtype"_a = fp8_dtype,
+          "quantizer"_a = quantizer_list[i],
+          "is_2D_scaled"_a = is_2D_scaled);
+
+      output_list.emplace_back(std::move(ret));
+    }
+
+    py::handle Float8BlockwiseQTensorClass(
+        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
+
+    // put the two full tensors into a python class to maintain their life cycle
+    py::object ret = Float8BlockwiseQTensorClass(
+        "rowwise_data"_a = rowwise_full_tensor,
+        "columnwise_data"_a = columnwise_full_tensor,
+        "rowwise_scale_inv"_a = py::none(),
+        "columnwise_scale_inv"_a = py::none(),
+        "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = py::none(), "is_2D_scaled"_a = true);
+
+    output_list.emplace_back(std::move(ret));
+
+  } else {
+    NVTE_ERROR("Fused bulk alloc is not supported for this quantizer type");
+  }
+
+  return output_list;
+}
+
+py::object simple_sanity_check(at::Tensor input, py::handle quantizer) {
+  init_extension();
+  using namespace pybind11::literals;  // For operator""_a
+  py::handle Float8BlockwiseQTensorClass(
+      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
+
+  py::object ret = Float8BlockwiseQTensorClass(
+      "rowwise_data"_a = input,
+      "columnwise_data"_a = input,
+      "rowwise_scale_inv"_a = input,
+      "columnwise_scale_inv"_a = input,
+      "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = quantizer, "is_2D_scaled"_a = true);
+
+  // py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
+  // py::object ret = Float8TensorClass("data"_a = py::none(), "fp8_scale_inv"_a = py::none(),
+  //     "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "data_transpose"_a = py::none(),
+  //     "quantizer"_a = py::none());
+  return ret;
+}
+
 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
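
The allocation strategy above packs everything for one usage direction (rowwise or columnwise) into a single uint8 tensor: the per-split FP8 data regions come first, back to back, followed by the per-split float32 scale-inverse regions, and the individual outputs are carved out with raw pointer arithmetic plus at::from_blob so that no torch slicing is needed on the CPU side. A rough Python sketch of the same offset bookkeeping (the function and variable names here are illustrative, not part of the commit):

    # Rough sketch of the single-buffer layout built by fused_bulk_alloc_outputs:
    #   [ data_0 | data_1 | ... | data_{n-1} | scale_0 | scale_1 | ... | scale_{n-1} ]
    # FP8 data is 1 byte per element, float32 scale-inv is 4 bytes per element.
    def plan_rowwise_buffer(m_splits, hidden_dim, scale_shapes):
        data_sizes = [m * hidden_dim * 1 for m in m_splits]
        scale_sizes = [rows * cols * 4 for rows, cols in scale_shapes]
        total_bytes = sum(data_sizes) + sum(scale_sizes)

        regions, cursor = [], 0
        for i, size in enumerate(data_sizes):      # all data regions first
            regions.append(("data", i, cursor, size))
            cursor += size
        for i, size in enumerate(scale_sizes):     # then all scale-inv regions
            regions.append(("scale", i, cursor, size))
            cursor += size
        assert cursor == total_bytes
        return total_bytes, regions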

transformer_engine/pytorch/module/grouped_linear.py
Lines changed: 26 additions & 4 deletions

@@ -50,7 +50,7 @@
     restore_from_saved,
 )
 
-from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer, bulk_alloc_float8_blockwise_tensor
+from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 
 __all__ = ["GroupedLinear"]
 
@@ -124,12 +124,22 @@ def forward(
             output_quantizer.set_usage(rowwise=True, columnwise=False)
 
         fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+        full_buffer_rowwise = None
+        full_buffer_columnwise = None
         if fp8:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if hasattr(recipe, "fp8_gemm_fprop"):
                 fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
-            # TODO(zhongbo): make bulk alloc available for all quantizers
-            output_list = bulk_alloc_float8_blockwise_tensor(inp_view, m_splits, input_quantizers) if isinstance(input_quantizers[0], Float8BlockQuantizer) else None
+
+            alloc_output = tex.fused_bulk_alloc_outputs(inp_view, m_splits, input_quantizers) if isinstance(input_quantizers[0], Float8BlockQuantizer) else None
+            # alloc_output = tex.simple_sanity_check(inp_view, input_quantizers[0]) if isinstance(input_quantizers[0], Float8BlockQuantizer) else None
+            output_list = None
+            if alloc_output is not None:
+                # the last element is the full buffer; all the previous tensors are views of the full buffer
+                output_list = alloc_output[:-1]
+                full_buffer_rowwise = alloc_output[-1]._rowwise_data
+                full_buffer_columnwise = alloc_output[-1]._columnwise_data
+
             inputmats = tex.fused_multi_quantize(
                 inputmats_no_fp8, output_list, input_quantizers, TE_DType[activation_dtype]
             )
@@ -204,6 +214,7 @@ def forward(
         for inputmat in inputmats:
             if isinstance(inputmat, QuantizedTensorBase):
                 inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
+        full_tensor_rowwise = None
         if inp.requires_grad:
             for weight in weights_fp8:
                 if isinstance(weight, QuantizedTensorBase):
@@ -216,6 +227,7 @@ def forward(
                 *biases,
             )
             ctx.save_for_backward(*tensors_to_save)
+            ctx.full_buffer_columnwise = full_buffer_columnwise
             ctx.tensor_objects = tensor_objects
 
             ctx.weights_requires_grad = weights[0].requires_grad
@@ -260,6 +272,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = ctx.main_grads
 
+        full_buffer_columnwise = ctx.full_buffer_columnwise
+
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
             for i in ctx.num_gemms:
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
@@ -275,6 +289,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         )
         grad_output = [None] * ctx.num_gemms
         grad_biases = [None] * ctx.num_gemms
+
+        full_buffer_rowwise_dy = None
+        full_buffer_columnwise_dy = None
         if ctx.fp8:
             if ctx.use_bias:
                 # unfuse bgrad for now until cast_transpose + dgrad calculation is ready
@@ -289,7 +306,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         grad_output_mats[i], ctx.grad_output_quantizers[i]
                     )
             else:
-                output_list = bulk_alloc_float8_blockwise_tensor(grad_output_view, ctx.m_splits, ctx.grad_output_quantizers) if isinstance(ctx.grad_output_quantizers[0], Float8BlockQuantizer) else None
+                alloc_output = tex.fused_bulk_alloc_outputs(grad_output_view, ctx.m_splits, ctx.grad_output_quantizers) if isinstance(ctx.grad_output_quantizers[0], Float8BlockQuantizer) else None
+                output_list = None
+                if alloc_output is not None:
+                    output_list = alloc_output[:-1]
+                    full_buffer_rowwise_dy = alloc_output[-1]._rowwise_data
+                    full_buffer_columnwise_dy = alloc_output[-1]._columnwise_data
                 grad_output = tex.fused_multi_quantize(
                     grad_output_mats,
                     output_list,
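
One point worth calling out in the Python changes: the per-split tensors handed to fused_multi_quantize are from_blob views that do not own their storage, so the last element of alloc_output exists only to own the two backing allocations, and forward() keeps the columnwise buffer alive on ctx until backward has consumed it. A condensed sketch of that ownership pattern, simplified from the diff above (names follow the diff):

    # Condensed sketch of the ownership pattern in forward(); names follow the diff.
    def quantize_splits_with_bulk_alloc(inp_view, inputmats_no_fp8, m_splits,
                                        input_quantizers, activation_dtype, ctx):
        alloc_output = tex.fused_bulk_alloc_outputs(inp_view, m_splits, input_quantizers)
        output_list = alloc_output[:-1]       # per-split views, no storage ownership
        holder = alloc_output[-1]             # dummy tensor owning both full buffers

        inputmats = tex.fused_multi_quantize(
            inputmats_no_fp8, output_list, input_quantizers, TE_DType[activation_dtype]
        )
        # keep the columnwise backing buffer alive until backward has used it
        ctx.full_buffer_columnwise = holder._columnwise_data
        return inputmats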
