|
| 1 | +// Copyright 2024 Advanced Micro Devices, Inc |
| 2 | +// |
| 3 | +// Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +#include "./lib_ext.h" |
| 8 | +#include "./utils.h" |
| 9 | +#include "shortfin/array/api.h" |
| 10 | + |
| 11 | +using namespace shortfin::array; |
| 12 | + |
| 13 | +namespace shortfin::python { |
| 14 | + |
| 15 | +void BindArray(py::module_ &m) { |
| 16 | + py::class_<DType>(m, "DType") |
| 17 | + .def_prop_ro("is_boolean", &DType::is_boolean) |
| 18 | + .def_prop_ro("is_integer", &DType::is_integer) |
| 19 | + .def_prop_ro("is_float", &DType::is_float) |
| 20 | + .def_prop_ro("is_complex", &DType::is_complex) |
| 21 | + .def_prop_ro("bit_count", &DType::bit_count) |
| 22 | + .def_prop_ro("is_byte_aligned", &DType::is_byte_aligned) |
| 23 | + .def_prop_ro("dense_byte_count", &DType::dense_byte_count) |
| 24 | + .def("is_integer_bitwidth", &DType::is_integer_bitwidth) |
| 25 | + .def(py::self == py::self) |
| 26 | + .def("__repr__", &DType::name); |
| 27 | + |
| 28 | + m.attr("opaque8") = DType::opaque8(); |
| 29 | + m.attr("opaque16") = DType::opaque16(); |
| 30 | + m.attr("opaque32") = DType::opaque32(); |
| 31 | + m.attr("opaque64") = DType::opaque64(); |
| 32 | + m.attr("bool8") = DType::bool8(); |
| 33 | + m.attr("int4") = DType::int4(); |
| 34 | + m.attr("sint4") = DType::sint4(); |
| 35 | + m.attr("uint4") = DType::uint4(); |
| 36 | + m.attr("int8") = DType::int8(); |
| 37 | + m.attr("sint8") = DType::sint8(); |
| 38 | + m.attr("uint8") = DType::uint8(); |
| 39 | + m.attr("int16") = DType::int16(); |
| 40 | + m.attr("sint16") = DType::sint16(); |
| 41 | + m.attr("uint16") = DType::uint16(); |
| 42 | + m.attr("int32") = DType::int32(); |
| 43 | + m.attr("sint32") = DType::sint32(); |
| 44 | + m.attr("uint32") = DType::uint32(); |
| 45 | + m.attr("int64") = DType::int64(); |
| 46 | + m.attr("sint64") = DType::sint64(); |
| 47 | + m.attr("uint64") = DType::uint64(); |
| 48 | + m.attr("float16") = DType::float16(); |
| 49 | + m.attr("float32") = DType::float32(); |
| 50 | + m.attr("float64") = DType::float64(); |
| 51 | + m.attr("bfloat16") = DType::bfloat16(); |
| 52 | + m.attr("complex64") = DType::complex64(); |
| 53 | + m.attr("complex128") = DType::complex128(); |
| 54 | + |
| 55 | + py::class_<storage>(m, "storage") |
| 56 | + .def_static( |
| 57 | + "allocate_host", |
| 58 | + [](local::ScopedDevice &device, iree_device_size_t allocation_size) { |
| 59 | + return storage::AllocateHost(device, allocation_size); |
| 60 | + }, |
| 61 | + py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) |
| 62 | + .def_static( |
| 63 | + "allocate_device", |
| 64 | + [](local::ScopedDevice &device, iree_device_size_t allocation_size) { |
| 65 | + return storage::AllocateDevice(device, allocation_size); |
| 66 | + }, |
| 67 | + py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) |
| 68 | + .def("fill", |
| 69 | + [](storage &self, py::handle buffer) { |
| 70 | + Py_buffer py_view; |
| 71 | + int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. |
| 72 | + if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { |
| 73 | + throw py::python_error(); |
| 74 | + } |
| 75 | + PyBufferReleaser py_view_releaser(py_view); |
| 76 | + self.Fill(py_view.buf, py_view.len); |
| 77 | + }) |
| 78 | + .def("__repr__", &storage::to_s); |
| 79 | + |
| 80 | + py::class_<base_array>(m, "base_array") |
| 81 | + .def_prop_ro("dtype", &base_array::dtype) |
| 82 | + .def_prop_ro("shape", &base_array::shape); |
| 83 | + py::class_<device_array, base_array>(m, "device_array") |
| 84 | + .def("__init__", [](py::args, py::kwargs) {}) |
| 85 | + .def_static("__new__", |
| 86 | + [](py::handle py_type, class storage storage, |
| 87 | + std::span<const size_t> shape, DType dtype) { |
| 88 | + return custom_new_keep_alive<device_array>( |
| 89 | + py_type, /*keep_alive=*/storage.scope(), storage, shape, |
| 90 | + dtype); |
| 91 | + }) |
| 92 | + .def_static("__new__", |
| 93 | + [](py::handle py_type, local::ScopedDevice &device, |
| 94 | + std::span<const size_t> shape, DType dtype) { |
| 95 | + return custom_new_keep_alive<device_array>( |
| 96 | + py_type, /*keep_alive=*/device.scope(), |
| 97 | + device_array::allocate(device, shape, dtype)); |
| 98 | + }) |
| 99 | + .def_prop_ro("device", &device_array::device, |
| 100 | + py::rv_policy::reference_internal) |
| 101 | + .def_prop_ro("storage", &device_array::storage, |
| 102 | + py::rv_policy::reference_internal) |
| 103 | + .def("__repr__", &device_array::to_s); |
| 104 | + py::class_<host_array, base_array>(m, "host_array") |
| 105 | + .def("__init__", [](py::args, py::kwargs) {}) |
| 106 | + .def_static("__new__", |
| 107 | + [](py::handle py_type, class storage storage, |
| 108 | + std::span<const size_t> shape, DType dtype) { |
| 109 | + return custom_new_keep_alive<host_array>( |
| 110 | + py_type, /*keep_alive=*/storage.scope(), storage, shape, |
| 111 | + dtype); |
| 112 | + }) |
| 113 | + .def_static("__new__", |
| 114 | + [](py::handle py_type, local::ScopedDevice &device, |
| 115 | + std::span<const size_t> shape, DType dtype) { |
| 116 | + return custom_new_keep_alive<host_array>( |
| 117 | + py_type, /*keep_alive=*/device.scope(), |
| 118 | + host_array::allocate(device, shape, dtype)); |
| 119 | + }) |
| 120 | + .def_static("__new__", |
| 121 | + [](py::handle py_type, device_array &device_array) { |
| 122 | + return custom_new_keep_alive<host_array>( |
| 123 | + py_type, /*keep_alive=*/device_array.device().scope(), |
| 124 | + host_array::for_transfer(device_array)); |
| 125 | + }) |
| 126 | + .def_prop_ro("device", &host_array::device, |
| 127 | + py::rv_policy::reference_internal) |
| 128 | + .def_prop_ro("storage", &host_array::storage, |
| 129 | + py::rv_policy::reference_internal) |
| 130 | + .def("__repr__", &host_array::to_s); |
| 131 | +} |
| 132 | + |
| 133 | +} // namespace shortfin::python |
0 commit comments