Commit 63b10ba

committed
clean up
Signed-off-by: zhongboz <[email protected]>
1 parent 7116a07 commit 63b10ba

8 files changed: +112 −201 lines

benchmarks/linear/benchmark_grouped_linear.py

Lines changed: 14 additions & 18 deletions

@@ -8,17 +8,18 @@
 from transformer_engine.common.recipe import Float8BlockScaling
 from transformer_engine.pytorch.fp8 import fp8_autocast
 from contextlib import nullcontext
+
 RECIPES = {
     "bf16": None,
     "fp8_sub_channel": Float8BlockScaling(),
 }


-def run_linear_multiple_steps(
-    layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None
-):
+def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
     assert mode in ["fwd_only", "fwd_bwd"]
-    fp8_context = fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
+    fp8_context = (
+        fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
+    )
     # print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")

     if mode == "fwd_only":
@@ -67,13 +68,11 @@ def benchmark_linear(
     num_gemms=4,
 ):
     params_dtype = torch.bfloat16
-    recipe =RECIPES[recipe_name]
+    recipe = RECIPES[recipe_name]

     in_features = x.shape[1]
     out_features = ws[0].shape[0]
-    gradient = torch.ones(
-        (x.shape[0], out_features), dtype=torch.bfloat16, device=x.device
-    )
+    gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device)

     layer = GroupedLinear(
         num_gemms,
@@ -97,7 +96,10 @@ def benchmark_linear(
     label = f"{recipe_name}_{'grouped'}"
     torch.cuda.nvtx.range_push(label)
     timing = benchmark.Timer(
-        stmt="run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches, recipe)",
+        stmt=(
+            "run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
+            " recipe)"
+        ),
         globals={
             "run_linear_multiple_steps": run_linear_multiple_steps,
             "layer": layer,
@@ -116,20 +118,15 @@ def benchmark_linear(
     return timing_ms


-def run_benchmark_linear(
-    mkns, recipe_name, use_bias, num_gemms=4
-):
+def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
     data = []
     assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

     print(f"========== Benchmarking {recipe_name} ==========")
     for m, k, n in mkns:
         device = "cuda"
         x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
-        ws = [
-            torch.randn((n, k), dtype=torch.bfloat16, device=device)
-            for _ in range(num_gemms)
-        ]
+        ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
         assert m % num_gemms == 0
         m_splits = [m // num_gemms] * num_gemms
         # Bias is not supported for GroupedLinear benchmark
@@ -192,7 +189,7 @@ def run_benchmark_linear(
 # Set the MKN values to benchmark
 mkns = []
 for m in [1024]:
-# for m in [4096, 8192, 16384]:
+    # for m in [4096, 8192, 16384]:
     # for n in [1024, 2048, 4096, 8192, 16384]:
     for n in [3072]:
         for k in [4096]:
@@ -236,6 +233,5 @@ def run_benchmark_linear(

 print(df_linears)

-
 if args.profile:
     torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
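
The detail worth noting in the first hunk is the recipe-gated FP8 context: when recipe is None (the bf16 baseline), the benchmark enters nullcontext() instead of fp8_autocast, so both precisions share a single call path. Below is a minimal sketch of that pattern outside the benchmark harness. It assumes TE's GroupedLinear(num_gemms, in_features, out_features) constructor and the layer(x, m_splits) call convention the benchmark uses; the shapes mirror the benchmark defaults (m=1024, k=4096, n=3072, num_gemms=4).

from contextlib import nullcontext

import torch
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch import GroupedLinear
from transformer_engine.pytorch.fp8 import fp8_autocast

def forward_with_optional_fp8(layer, x, m_splits, recipe=None):
    # Enter fp8_autocast only when a recipe is supplied; nullcontext() lets
    # the bf16 baseline run through the identical code path.
    ctx = fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
    with ctx:
        return layer(x, m_splits)

layer = GroupedLinear(4, 4096, 3072, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")
out = forward_with_optional_fp8(layer, x, m_splits=[256] * 4, recipe=Float8BlockScaling())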

transformer_engine/pytorch/csrc/common.h

Lines changed: 2 additions & 1 deletion

@@ -195,7 +195,8 @@ class Float8BlockQuantizer : public Quantizer {
                              const std::vector<size_t>& shape, DType dtype,
                              std::optional<at::Tensor> rowwise_data = std::nullopt) const override;

-  std::pair<size_t, size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
+  std::pair<size_t, size_t> get_scale_shape(const std::vector<size_t>& shape,
+                                            bool columnwise) const;
 };

 class MXFP8Quantizer : public Quantizer {
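
get_scale_shape is what the bulk-allocation path in transpose.cpp calls to size per-split scale-inverse buffers before any output tensor exists, which is why it appears on the class here. As a rough illustration only: block scaling stores one fp32 scale per block, so the scale shape is essentially a ceil-division of the (possibly transposed) data shape. A hypothetical Python rendering under that assumption; the real C++ implementation derives 1D vs 2D from the quantizer's block_scaling_dim and may pad or align the result differently:

import math

def scale_shape(rows, cols, columnwise, block=128, block_scaling_dim=2):
    # Hypothetical sketch: 2D scaling keeps one scale per block x block tile;
    # 1D keeps one scale per block-wide segment of each row. Columnwise usage
    # sizes the scales for the transposed data.
    if columnwise:
        rows, cols = cols, rows
    if block_scaling_dim == 2:
        return (math.ceil(rows / block), math.ceil(cols / block))
    return (rows, math.ceil(cols / block))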

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 4 deletions

@@ -107,10 +107,8 @@ namespace transformer_engine::pytorch {
  * Transpose
  **************************************************************************************************/

-std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor inpput_view, std::vector<int> m_splits,
-                                                 std::vector<py::handle> quantizer_list);
-
-py::object simple_sanity_check(at::Tensor input, py::handle quantizer);
+std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor inpput_view, std::vector<int> m_splits,
+                                                 std::vector<py::handle> quantizer_list);

 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 4 additions & 6 deletions

@@ -12,14 +12,13 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

+#include <iostream>
 #include <stdexcept>

 #include "../common.h"
 #include "../extensions.h"
 #include "common.h"

-#include <iostream>
-
 namespace transformer_engine::pytorch {

 PyTypeObject *Float8TensorPythonClass = nullptr;  /// TODO Remove
@@ -201,10 +200,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
         py::arg("zero_centered_gamma"));
   m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
-  m.def("fused_bulk_alloc_outputs", &transformer_engine::pytorch::fused_bulk_alloc_outputs, "Fused Bulk Alloc Outputs",
-        py::arg("input_view"), py::arg("m_splits"), py::arg("quantizer_list"));
-  m.def("simple_sanity_check", &transformer_engine::pytorch::simple_sanity_check, "foo",
-        py::arg("input"), py::arg("quantizer"));
+  m.def("fused_bulk_alloc_outputs", &transformer_engine::pytorch::fused_bulk_alloc_outputs,
+        "Fused Bulk Alloc Outputs", py::arg("input_view"), py::arg("m_splits"),
+        py::arg("quantizer_list"));
   m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
         "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
         py::arg("quantizer_list"), py::arg("otype"));
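
After this hunk the extension exposes fused_bulk_alloc_outputs with three named arguments, and the throwaway simple_sanity_check binding (docstring "foo") is gone. A hedged sketch of the Python call site follows; the diff confirms the binding and its argument names, while the module import name and the surrounding setup are assumptions for illustration:

import transformer_engine_torch as tex  # assumed import name for the extension module

def bulk_alloc_fp8_outputs(input_view, m_splits, quantizer_list):
    # input_view: flat (sum(m_splits), hidden) activations; quantizer_list
    # holds one blockwise quantizer per split, built elsewhere.
    outputs = tex.fused_bulk_alloc_outputs(
        input_view=input_view, m_splits=m_splits, quantizer_list=quantizer_list
    )
    # transpose.cpp appends one trailing object that owns the two flat backing
    # buffers, so the per-split quantized tensors are everything but the last.
    return outputs[:-1], outputs[-1]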

transformer_engine/pytorch/csrc/extensions/quantizer.cpp

Lines changed: 3 additions & 2 deletions

@@ -387,7 +387,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   return {std::move(tensor), std::move(ret)};
 }

-std::pair<size_t, size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const {
+std::pair<size_t, size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
+                                                                bool columnwise) const {
   using namespace pybind11::literals;
   std::vector<int64_t> torch_shape;
   size_t numel = 1;
@@ -418,7 +419,7 @@ std::pair<size_t, size_t> Float8BlockQuantizer::get_scale_shape(const std::vecto
                 block_scaling_dim);
     }
     scale_shape = {sinv0, sinv1};
-  }else {
+  } else {
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {

transformer_engine/pytorch/csrc/extensions/transpose.cpp

Lines changed: 73 additions & 63 deletions

@@ -4,18 +4,18 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <optional>
 #include <pybind.h>

+#include <iostream>
+#include <optional>
+
 #include "extensions.h"
 #include "pybind.h"

-#include <iostream>
-
 namespace transformer_engine::pytorch {

-std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits,
-                                                 std::vector<py::handle> quantizer_list) {
+std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits,
+                                                 std::vector<py::handle> quantizer_list) {
   init_extension();
   using namespace pybind11::literals;  // For operator""_a

@@ -58,25 +58,38 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
   std::vector<size_t> columnwise_data_sizes;
   std::vector<size_t> columnwise_scale_sizes;
   for (int i = 0; i < num_splits; i++) {
-    std::pair<size_t, size_t> input_view_i_shape = std::make_pair((size_t)m_splits[i], (size_t)hidden_dim);
+    std::pair<size_t, size_t> input_view_i_shape =
+        std::make_pair((size_t)m_splits[i], (size_t)hidden_dim);
     if (rowwise_usage) {
       rowwise_data_shapes.emplace_back(input_view_i_shape);
-      rowwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, false));
-      rowwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size);
-      rowwise_scale_sizes.emplace_back(rowwise_scale_shapes.back().first * rowwise_scale_shapes.back().second * scale_elem_size);
+      rowwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape(
+          {input_view_i_shape.first, input_view_i_shape.second}, false));
+      rowwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second *
+                                      fp8_elem_size);
+      rowwise_scale_sizes.emplace_back(rowwise_scale_shapes.back().first *
+                                       rowwise_scale_shapes.back().second * scale_elem_size);
     }
     if (columnwise_usage) {
-      columnwise_data_shapes.emplace_back(std::make_pair(input_view_i_shape.second, input_view_i_shape.first));
-      columnwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, true));
-      columnwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size);
-      columnwise_scale_sizes.emplace_back(columnwise_scale_shapes.back().first * columnwise_scale_shapes.back().second * scale_elem_size);
+      columnwise_data_shapes.emplace_back(
+          std::make_pair(input_view_i_shape.second, input_view_i_shape.first));
+      columnwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape(
+          {input_view_i_shape.first, input_view_i_shape.second}, true));
+      columnwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second *
+                                         fp8_elem_size);
+      columnwise_scale_sizes.emplace_back(columnwise_scale_shapes.back().first *
+                                          columnwise_scale_shapes.back().second *
+                                          scale_elem_size);
     }
   }

-  size_t total_size_rowwise_data = std::accumulate(rowwise_data_sizes.begin(), rowwise_data_sizes.end(), 0);
-  size_t total_size_rowwise_scale = std::accumulate(rowwise_scale_sizes.begin(), rowwise_scale_sizes.end(), 0);
-  size_t total_size_columnwise_data = std::accumulate(columnwise_data_sizes.begin(), columnwise_data_sizes.end(), 0);
-  size_t total_size_columnwise_scale = std::accumulate(columnwise_scale_sizes.begin(), columnwise_scale_sizes.end(), 0);
+  size_t total_size_rowwise_data =
+      std::accumulate(rowwise_data_sizes.begin(), rowwise_data_sizes.end(), 0);
+  size_t total_size_rowwise_scale =
+      std::accumulate(rowwise_scale_sizes.begin(), rowwise_scale_sizes.end(), 0);
+  size_t total_size_columnwise_data =
+      std::accumulate(columnwise_data_sizes.begin(), columnwise_data_sizes.end(), 0);
+  size_t total_size_columnwise_scale =
+      std::accumulate(columnwise_scale_sizes.begin(), columnwise_scale_sizes.end(), 0);

   size_t total_size_rowwise = total_size_rowwise_data + total_size_rowwise_scale;
   size_t total_size_columnwise = total_size_columnwise_data + total_size_columnwise_scale;
@@ -90,49 +103,67 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec
   at::Tensor columnwise_full_tensor;

   if (rowwise_usage) {
-    rowwise_full_tensor = at::empty({(int64_t)total_size_rowwise}, at::device(input_view.device()).dtype(torch::kUInt8));
+    rowwise_full_tensor = at::empty({(int64_t)total_size_rowwise},
+                                    at::device(input_view.device()).dtype(torch::kUInt8));
     // use raw pointer math + from blob, avoid torch slice to reduce cpu overhead
     uint8_t* rowwise_data_ptr = rowwise_full_tensor.data_ptr<uint8_t>();
-    uint8_t* rowwise_scale_ptr = rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
+    uint8_t* rowwise_scale_ptr =
+        rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data;
     // use from_blob to construct rowwise_data_list and rowwise_scale_list
     for (int i = 0; i < num_splits; i++) {
-      rowwise_data_list.emplace_back(at::from_blob(rowwise_data_ptr, {static_cast<int64_t>(rowwise_data_shapes[i].first), static_cast<int64_t>(rowwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8)));
-      rowwise_scale_list.emplace_back(at::from_blob(rowwise_scale_ptr, {static_cast<int64_t>(rowwise_scale_shapes[i].first), static_cast<int64_t>(rowwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32)));
+      rowwise_data_list.emplace_back(
+          at::from_blob(rowwise_data_ptr,
+                        {static_cast<int64_t>(rowwise_data_shapes[i].first),
+                         static_cast<int64_t>(rowwise_data_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kUInt8)));
+      rowwise_scale_list.emplace_back(
+          at::from_blob(rowwise_scale_ptr,
+                        {static_cast<int64_t>(rowwise_scale_shapes[i].first),
+                         static_cast<int64_t>(rowwise_scale_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kFloat32)));
       rowwise_data_ptr += rowwise_data_sizes[i];
       rowwise_scale_ptr += rowwise_scale_sizes[i];
     }
   }

   if (columnwise_usage) {
-    columnwise_full_tensor = at::empty({(int64_t)total_size_columnwise}, at::device(input_view.device()).dtype(torch::kUInt8));
+    columnwise_full_tensor = at::empty({(int64_t)total_size_columnwise},
+                                       at::device(input_view.device()).dtype(torch::kUInt8));
     uint8_t* columnwise_data_ptr = columnwise_full_tensor.data_ptr<uint8_t>();
-    uint8_t* columnwise_scale_ptr = columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
+    uint8_t* columnwise_scale_ptr =
+        columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data;
     for (int i = 0; i < num_splits; i++) {
-      columnwise_data_list.emplace_back(at::from_blob(columnwise_data_ptr, {static_cast<int64_t>(columnwise_data_shapes[i].first), static_cast<int64_t>(columnwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8)));
-      columnwise_scale_list.emplace_back(at::from_blob(columnwise_scale_ptr, {static_cast<int64_t>(columnwise_scale_shapes[i].first), static_cast<int64_t>(columnwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32)));
+      columnwise_data_list.emplace_back(
+          at::from_blob(columnwise_data_ptr,
+                        {static_cast<int64_t>(columnwise_data_shapes[i].first),
+                         static_cast<int64_t>(columnwise_data_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kUInt8)));
+      columnwise_scale_list.emplace_back(
+          at::from_blob(columnwise_scale_ptr,
+                        {static_cast<int64_t>(columnwise_scale_shapes[i].first),
+                         static_cast<int64_t>(columnwise_scale_shapes[i].second)},
+                        at::device(input_view.device()).dtype(torch::kFloat32)));
       columnwise_data_ptr += columnwise_data_sizes[i];
       columnwise_scale_ptr += columnwise_scale_sizes[i];
     }
   }
-
-  for (int i = 0; i < num_splits; i++) {

+  for (int i = 0; i < num_splits; i++) {
     py::handle Float8BlockwiseQTensorClass(
-      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
+        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));

     // Create the tensor object with proper reference counting
     py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
-    py::object columnwise_data = columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none();
+    py::object columnwise_data =
+        columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none();
     py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
-    py::object columnwise_scale = columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none();
+    py::object columnwise_scale =
+        columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none();

     py::object ret = Float8BlockwiseQTensorClass(
-        "rowwise_data"_a = rowwise_data,
-        "columnwise_data"_a = columnwise_data,
-        "rowwise_scale_inv"_a = rowwise_scale,
-        "columnwise_scale_inv"_a = columnwise_scale,
-        "fp8_dtype"_a = fp8_dtype,
-        "quantizer"_a = quantizer_list[i],
+        "rowwise_data"_a = rowwise_data, "columnwise_data"_a = columnwise_data,
+        "rowwise_scale_inv"_a = rowwise_scale, "columnwise_scale_inv"_a = columnwise_scale,
+        "fp8_dtype"_a = fp8_dtype, "quantizer"_a = quantizer_list[i],
        "is_2D_scaled"_a = is_2D_scaled);

     output_list.emplace_back(std::move(ret));
@@ -143,41 +174,20 @@ std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vec

     // put the two full tensor into a python class to maintain their life cycle
     py::object ret = Float8BlockwiseQTensorClass(
-        "rowwise_data"_a = rowwise_full_tensor,
-        "columnwise_data"_a = columnwise_full_tensor,
-        "rowwise_scale_inv"_a = py::none(),
-        "columnwise_scale_inv"_a = py::none(),
-        "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = py::none(), "is_2D_scaled"_a = true);
-
+        "rowwise_data"_a = rowwise_full_tensor, "columnwise_data"_a = columnwise_full_tensor,
+        "rowwise_scale_inv"_a = py::none(), "columnwise_scale_inv"_a = py::none(),
+        "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = py::none(),
+        "is_2D_scaled"_a = true);
+
     output_list.emplace_back(std::move(ret));

-  }else{
+  } else {
     NVTE_ERROR("Fused bulk alloc is not supported for this quantizer type");
   }

   return output_list;
 }

-py::object simple_sanity_check(at::Tensor input, py::handle quantizer){
-  init_extension();
-  using namespace pybind11::literals;  // For operator""_a
-  py::handle Float8BlockwiseQTensorClass(
-      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
-
-  py::object ret = Float8BlockwiseQTensorClass(
-      "rowwise_data"_a = input,
-      "columnwise_data"_a = input,
-      "rowwise_scale_inv"_a = input,
-      "columnwise_scale_inv"_a = input,
-      "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = quantizer, "is_2D_scaled"_a = true);
-
-  // py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
-  // py::object ret = Float8TensorClass("data"_a = py::none(), "fp8_scale_inv"_a = py::none(),
-  //     "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "data_transpose"_a = py::none(),
-  //     "quantizer"_a = py::none());
-  return ret;
-}
-
 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
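
The cleanup does not change what this file does: for blockwise quantizers it computes every split's data and scale byte counts up front, makes one flat uint8 allocation per direction (all data bytes, then all scale bytes), carves the per-split tensors out of it with raw pointer offsets plus at::from_blob, and finally appends a Python-side object that keeps the backing buffers alive. A rough Python analogue of the rowwise layout, for illustration only (the C++ code deliberately avoids torch slicing because of its CPU overhead, and the scale shapes passed in below are made-up examples):

import torch

def bulk_alloc_rowwise(m_splits, hidden_dim, scale_shapes, device="cuda"):
    # One byte per fp8 element, four bytes per fp32 scale.
    data_sizes = [m * hidden_dim for m in m_splits]
    scale_sizes = [4 * r * c for r, c in scale_shapes]
    full = torch.empty(sum(data_sizes) + sum(scale_sizes), dtype=torch.uint8, device=device)

    data_views, off = [], 0
    for m, size in zip(m_splits, data_sizes):
        data_views.append(full[off:off + size].view(m, hidden_dim))
        off += size
    scale_views = []
    for (r, c), size in zip(scale_shapes, scale_sizes):
        # Reinterpreting uint8 bytes as fp32 needs a 4-byte-aligned offset,
        # which holds whenever the data region size is a multiple of 4.
        assert off % 4 == 0
        scale_views.append(full[off:off + size].view(torch.float32).view(r, c))
        off += size
    # `full` must stay alive alongside the views; this is the role the trailing
    # Float8BlockwiseQTensor plays in the C++ code.
    return full, data_views, scale_views

full, data, scales = bulk_alloc_rowwise([256] * 4, 4096, [(2, 32)] * 4)

For four splits this turns eight rowwise allocations (four data, four scale) into a single one per direction, which is where the CPU-overhead saving noted in the in-code comments comes from.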
