|
5 | 5 | ************************************************************************/
|
6 | 6 |
|
7 | 7 | #include <optional>
|
| 8 | +#include <pybind.h> |
8 | 9 |
|
9 | 10 | #include "extensions.h"
|
10 | 11 | #include "pybind.h"
|
11 | 12 |
|
| 13 | +#include <iostream> |
| 14 | + |
12 | 15 | namespace transformer_engine::pytorch {
|
13 | 16 |
|
| 17 | +std::vector<py::object> fused_bulk_alloc_outputs(at::Tensor input_view, std::vector<int> m_splits, |
| 18 | + std::vector<py::handle> quantizer_list) { |
| 19 | + init_extension(); |
| 20 | + using namespace pybind11::literals; // For operator""_a |
| 21 | + |
| 22 | + int num_splits = m_splits.size(); |
| 23 | + |
| 24 | + // convert all the quantizers |
| 25 | + std::vector<std::unique_ptr<Quantizer>> quantizers; |
| 26 | + for (int i = 0; i < num_splits; i++) { |
| 27 | + quantizers.push_back(convert_quantizer(quantizer_list[i])); |
| 28 | + } |
| 29 | + |
| 30 | + bool rowwise_usage = quantizers[0]->rowwise_usage; |
| 31 | + bool columnwise_usage = quantizers[0]->columnwise_usage; |
| 32 | + size_t hidden_dim = input_view.size(1); |
| 33 | + |
| 34 | + std::vector<py::object> output_list; |
| 35 | + |
| 36 | + if (detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr())) { |
| 37 | + // implement the fuse bulk alloc for blockwise quantizer |
| 38 | + // downcast quantizers, resorces are owned by the unique_ptr, so use raw ptr here just to get the attributes |
| 39 | + std::vector<Float8BlockQuantizer*> blockwise_quantizers; |
| 40 | + for (size_t i = 0; i < quantizers.size(); i++) { |
| 41 | + Quantizer* raw_ptr = quantizers[i].get(); |
| 42 | + Float8BlockQuantizer* blockwise_quantizer = static_cast<Float8BlockQuantizer*>(raw_ptr); |
| 43 | + blockwise_quantizers.push_back(blockwise_quantizer); |
| 44 | + } |
| 45 | + |
| 46 | + bool is_2D_scaled = blockwise_quantizers[0]->get_scaling_mode() == NVTE_BLOCK_SCALING_2D; |
| 47 | + transformer_engine::DType fp8_dtype = blockwise_quantizers[0]->dtype; |
| 48 | + |
| 49 | + size_t fp8_elem_size = 1; |
| 50 | + size_t scale_elem_size = 4; |
| 51 | + |
| 52 | + std::vector<std::pair<size_t, size_t>> rowwise_data_shapes; |
| 53 | + std::vector<std::pair<size_t, size_t>> rowwise_scale_shapes; |
| 54 | + std::vector<size_t> rowwise_data_sizes; |
| 55 | + std::vector<size_t> rowwise_scale_sizes; |
| 56 | + std::vector<std::pair<size_t, size_t>> columnwise_data_shapes; |
| 57 | + std::vector<std::pair<size_t, size_t>> columnwise_scale_shapes; |
| 58 | + std::vector<size_t> columnwise_data_sizes; |
| 59 | + std::vector<size_t> columnwise_scale_sizes; |
| 60 | + for (int i = 0; i < num_splits; i++) { |
| 61 | + std::pair<size_t, size_t> input_view_i_shape = std::make_pair((size_t)m_splits[i], (size_t)hidden_dim); |
| 62 | + if (rowwise_usage) { |
| 63 | + rowwise_data_shapes.emplace_back(input_view_i_shape); |
| 64 | + rowwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, false)); |
| 65 | + rowwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size); |
| 66 | + rowwise_scale_sizes.emplace_back(rowwise_scale_shapes.back().first * rowwise_scale_shapes.back().second * scale_elem_size); |
| 67 | + } |
| 68 | + if (columnwise_usage) { |
| 69 | + columnwise_data_shapes.emplace_back(std::make_pair(input_view_i_shape.second, input_view_i_shape.first)); |
| 70 | + columnwise_scale_shapes.emplace_back(blockwise_quantizers[i]->get_scale_shape({input_view_i_shape.first, input_view_i_shape.second}, true)); |
| 71 | + columnwise_data_sizes.emplace_back(input_view_i_shape.first * input_view_i_shape.second * fp8_elem_size); |
| 72 | + columnwise_scale_sizes.emplace_back(columnwise_scale_shapes.back().first * columnwise_scale_shapes.back().second * scale_elem_size); |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + size_t total_size_rowwise_data = std::accumulate(rowwise_data_sizes.begin(), rowwise_data_sizes.end(), 0); |
| 77 | + size_t total_size_rowwise_scale = std::accumulate(rowwise_scale_sizes.begin(), rowwise_scale_sizes.end(), 0); |
| 78 | + size_t total_size_columnwise_data = std::accumulate(columnwise_data_sizes.begin(), columnwise_data_sizes.end(), 0); |
| 79 | + size_t total_size_columnwise_scale = std::accumulate(columnwise_scale_sizes.begin(), columnwise_scale_sizes.end(), 0); |
| 80 | + |
| 81 | + size_t total_size_rowwise = total_size_rowwise_data + total_size_rowwise_scale; |
| 82 | + size_t total_size_columnwise = total_size_columnwise_data + total_size_columnwise_scale; |
| 83 | + |
| 84 | + std::vector<at::Tensor> rowwise_data_list; |
| 85 | + std::vector<at::Tensor> rowwise_scale_list; |
| 86 | + std::vector<at::Tensor> columnwise_data_list; |
| 87 | + std::vector<at::Tensor> columnwise_scale_list; |
| 88 | + |
| 89 | + at::Tensor rowwise_full_tensor; |
| 90 | + at::Tensor columnwise_full_tensor; |
| 91 | + |
| 92 | + if (rowwise_usage) { |
| 93 | + rowwise_full_tensor = at::empty({(int64_t)total_size_rowwise}, at::device(input_view.device()).dtype(torch::kUInt8)); |
| 94 | + // use raw pointer math + from blob, avoid torch slice to reduce cpu overhead |
| 95 | + uint8_t* rowwise_data_ptr = rowwise_full_tensor.data_ptr<uint8_t>(); |
| 96 | + uint8_t* rowwise_scale_ptr = rowwise_full_tensor.data_ptr<uint8_t>() + total_size_rowwise_data; |
| 97 | + // use from_blob to construct rowwise_data_list and rowwise_scale_list |
| 98 | + for (int i = 0; i < num_splits; i++) { |
| 99 | + rowwise_data_list.emplace_back(at::from_blob(rowwise_data_ptr, {static_cast<int64_t>(rowwise_data_shapes[i].first), static_cast<int64_t>(rowwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8))); |
| 100 | + rowwise_scale_list.emplace_back(at::from_blob(rowwise_scale_ptr, {static_cast<int64_t>(rowwise_scale_shapes[i].first), static_cast<int64_t>(rowwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32))); |
| 101 | + rowwise_data_ptr += rowwise_data_sizes[i]; |
| 102 | + rowwise_scale_ptr += rowwise_scale_sizes[i]; |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + if (columnwise_usage) { |
| 107 | + columnwise_full_tensor = at::empty({(int64_t)total_size_columnwise}, at::device(input_view.device()).dtype(torch::kUInt8)); |
| 108 | + uint8_t* columnwise_data_ptr = columnwise_full_tensor.data_ptr<uint8_t>(); |
| 109 | + uint8_t* columnwise_scale_ptr = columnwise_full_tensor.data_ptr<uint8_t>() + total_size_columnwise_data; |
| 110 | + for (int i = 0; i < num_splits; i++) { |
| 111 | + columnwise_data_list.emplace_back(at::from_blob(columnwise_data_ptr, {static_cast<int64_t>(columnwise_data_shapes[i].first), static_cast<int64_t>(columnwise_data_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kUInt8))); |
| 112 | + columnwise_scale_list.emplace_back(at::from_blob(columnwise_scale_ptr, {static_cast<int64_t>(columnwise_scale_shapes[i].first), static_cast<int64_t>(columnwise_scale_shapes[i].second)}, at::device(input_view.device()).dtype(torch::kFloat32))); |
| 113 | + columnwise_data_ptr += columnwise_data_sizes[i]; |
| 114 | + columnwise_scale_ptr += columnwise_scale_sizes[i]; |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + for (int i = 0; i < num_splits; i++) { |
| 119 | + |
| 120 | + py::handle Float8BlockwiseQTensorClass( |
| 121 | + reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass)); |
| 122 | + |
| 123 | + // Create the tensor object with proper reference counting |
| 124 | + py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); |
| 125 | + py::object columnwise_data = columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none(); |
| 126 | + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); |
| 127 | + py::object columnwise_scale = columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none(); |
| 128 | + |
| 129 | + py::object ret = Float8BlockwiseQTensorClass( |
| 130 | + "rowwise_data"_a = rowwise_data, |
| 131 | + "columnwise_data"_a = columnwise_data, |
| 132 | + "rowwise_scale_inv"_a = rowwise_scale, |
| 133 | + "columnwise_scale_inv"_a = columnwise_scale, |
| 134 | + "fp8_dtype"_a = fp8_dtype, |
| 135 | + "quantizer"_a = quantizer_list[i], |
| 136 | + "is_2D_scaled"_a = is_2D_scaled); |
| 137 | + |
| 138 | + output_list.emplace_back(std::move(ret)); |
| 139 | + } |
| 140 | + |
| 141 | + py::handle Float8BlockwiseQTensorClass( |
| 142 | + reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass)); |
| 143 | + |
| 144 | + // put the two full tensor into a python class to maintain their life cycle |
| 145 | + py::object ret = Float8BlockwiseQTensorClass( |
| 146 | + "rowwise_data"_a = rowwise_full_tensor, |
| 147 | + "columnwise_data"_a = columnwise_full_tensor, |
| 148 | + "rowwise_scale_inv"_a = py::none(), |
| 149 | + "columnwise_scale_inv"_a = py::none(), |
| 150 | + "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = py::none(), "is_2D_scaled"_a = true); |
| 151 | + |
| 152 | + output_list.emplace_back(std::move(ret)); |
| 153 | + |
| 154 | + }else{ |
| 155 | + NVTE_ERROR("Fused bulk alloc is not supported for this quantizer type"); |
| 156 | + } |
| 157 | + |
| 158 | + return output_list; |
| 159 | +} |
| 160 | + |
| 161 | +py::object simple_sanity_check(at::Tensor input, py::handle quantizer){ |
| 162 | + init_extension(); |
| 163 | + using namespace pybind11::literals; // For operator""_a |
| 164 | + py::handle Float8BlockwiseQTensorClass( |
| 165 | + reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass)); |
| 166 | + |
| 167 | + py::object ret = Float8BlockwiseQTensorClass( |
| 168 | + "rowwise_data"_a = input, |
| 169 | + "columnwise_data"_a = input, |
| 170 | + "rowwise_scale_inv"_a = input, |
| 171 | + "columnwise_scale_inv"_a = input, |
| 172 | + "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "quantizer"_a = quantizer, "is_2D_scaled"_a = true); |
| 173 | + |
| 174 | + // py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass)); |
| 175 | + // py::object ret = Float8TensorClass("data"_a = py::none(), "fp8_scale_inv"_a = py::none(), |
| 176 | + // "fp8_dtype"_a = transformer_engine::DType::kFloat8E4M3, "data_transpose"_a = py::none(), |
| 177 | + // "quantizer"_a = py::none()); |
| 178 | + return ret; |
| 179 | +} |
| 180 | + |
14 | 181 | std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
|
15 | 182 | std::optional<std::vector<py::object>> output_list,
|
16 | 183 | std::vector<py::handle> quantizer_list,
|
|
0 commit comments