diff --git a/deps/kizunapi b/deps/kizunapi index c4dbe5c3..f04ca7c5 160000 --- a/deps/kizunapi +++ b/deps/kizunapi @@ -1 +1 @@ -Subproject commit c4dbe5c3f1eb0fd850be85e9c5a188238874bbe3 +Subproject commit f04ca7c5b7389b057c88fbc7f9debbe22f5a88ec diff --git a/src/array.cc b/src/array.cc index 5a1bc1fe..c523a42f 100644 --- a/src/array.cc +++ b/src/array.cc @@ -1,10 +1,19 @@ -#include "mlx/mlx.h" -#include "src/bindings.h" - -namespace mx = mlx::core; +#include "src/stream.h" +#include "src/util.h" namespace ki { +// Allow passing Dtype to js directly, no memory management involved as they are +// static globals. +template<> +struct TypeBridge { + static inline mx::Dtype* Wrap(mx::Dtype* ptr) { + return ptr; + } + static inline void Finalize(mx::Dtype* ptr) { + } +}; + template<> struct Type { static constexpr const char* name = "Dtype"; @@ -14,12 +23,12 @@ struct Type { DefineProperties(env, prototype, Property("size", Getter(&mx::Dtype::size))); } - // Since Dtype is represented as a class, we have to store it as a pointer in - // js, so converting it to js usually would involve a heap allocation. To - // avoid that let's just find the global const. static inline napi_status ToNode(napi_env env, const mx::Dtype& value, napi_value* result) { + // Since Dtype is represented as a class, we have to store it as a pointer + // in js, so converting it to js usually would involve a heap allocation. To + // avoid that let's just find the global const. if (value == mx::bool_) return ConvertToNode(env, &mx::bool_, result); if (value == mx::uint8) @@ -48,30 +57,23 @@ struct Type { return ConvertToNode(env, &mx::complex64, result); return napi_generic_failure; } - // Dtype is stored as pointer in js, convert it to value in C++ by copy. static inline std::optional FromNode(napi_env env, napi_value value) { - std::optional ptr = ki::FromNode(env, value); - if (!ptr) - return std::nullopt; - return *ptr.value(); + return NodeObjToCppValue(env, value); } }; -// Allow passing Dtype to js directly, no memory management involved as they are -// static globals. template<> -struct TypeBridge { - static inline mx::Dtype* Wrap(mx::Dtype* ptr) { - return ptr; - } - static inline void Finalize(mx::Dtype* ptr) { +struct TypeBridge { + static inline void Finalize(mx::array* ptr) { + delete ptr; } }; template<> struct Type { static constexpr const char* name = "array"; + static mx::array* Constructor(napi_env env, napi_value value, std::optional dtype) { @@ -89,9 +91,7 @@ struct Type { return nullptr; } } - static inline void Destructor(mx::array* ptr) { - delete ptr; - } + static void Define(napi_env env, napi_value constructor, napi_value prototype) { @@ -108,47 +108,51 @@ struct Type { Property("dtype", Getter(&mx::array::dtype))); // Define array's methods. Set(env, prototype, - "item", MemberFunction(&Item)); + "item", MemberFunction(&Item), + "astype", MemberFunction(&mx::astype)); } + static napi_value Item(mx::array* a, napi_env env) { a->eval(); switch (a->dtype()) { case mx::bool_: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::uint8: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::uint16: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::uint32: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::uint64: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::int8: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::int16: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::int32: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::int64: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::float16: - return ToNode(env, static_cast(a->item())); + return ki::ToNode(env, static_cast(a->item())); case mx::float32: - return ToNode(env, a->item()); + return ki::ToNode(env, a->item()); case mx::bfloat16: - return ToNode(env, static_cast(a->item())); + return ki::ToNode(env, static_cast(a->item())); case mx::complex64: // FIXME(zcbenz): Represent complex number in JS. return Undefined(env); } } - // array is stored as pointer in js, convert it to value in C++ by copy. + + static inline napi_status ToNode(napi_env env, + mx::array a, + napi_value* result) { + return ManagePointerInJSWrapper(env, new mx::array(std::move(a)), result); + } static inline std::optional FromNode(napi_env env, napi_value value) { - std::optional ptr = ki::FromNode(env, value); - if (!ptr) - return std::nullopt; - return *ptr.value(); + return NodeObjToCppValue(env, value); } }; diff --git a/src/bindings.cc b/src/bindings.cc index 6353c36f..7b1bb41a 100644 --- a/src/bindings.cc +++ b/src/bindings.cc @@ -3,6 +3,8 @@ namespace { napi_value Init(napi_env env, napi_value exports) { + InitDevice(env, exports); + InitStream(env, exports); InitArray(env, exports); return exports; } diff --git a/src/bindings.h b/src/bindings.h index 15abbcda..bba9c6e6 100644 --- a/src/bindings.h +++ b/src/bindings.h @@ -1,8 +1,13 @@ #ifndef SRC_BINDINGS_H_ #define SRC_BINDINGS_H_ +#include #include +namespace mx = mlx::core; + +void InitDevice(napi_env env, napi_value exports); +void InitStream(napi_env env, napi_value exports); void InitArray(napi_env env, napi_value exports); #endif // SRC_BINDINGS_H_ diff --git a/src/device.cc b/src/device.cc new file mode 100644 index 00000000..3f94a632 --- /dev/null +++ b/src/device.cc @@ -0,0 +1,74 @@ +#include "src/device.h" +#include "src/util.h" + +namespace ki { + +template<> +struct TypeBridge { + static inline void Finalize(mx::Device* ptr) { + delete ptr; + } +}; + +// static +napi_status Type::ToNode( + napi_env env, mx::Device::DeviceType type, napi_value* result) { + return ConvertToNode(env, static_cast(type), result); +} + +// static +std::optional Type::FromNode( + napi_env env, napi_value value) { + std::optional type = ki::FromNode(env, value); + if (!type) + return std::nullopt; + if (*type == static_cast(mx::Device::DeviceType::cpu)) + return mx::Device::DeviceType::cpu; + if (*type == static_cast(mx::Device::DeviceType::gpu)) + return mx::Device::DeviceType::gpu; + return std::nullopt; +} + +// static +mx::Device* Type::Constructor(mx::Device::DeviceType type, + int index) { + return new mx::Device(type, index); +} + +// static +void Type::Define(napi_env env, + napi_value constructor, + napi_value prototype) { + DefineProperties(env, prototype, + Property("type", Getter(&mx::Device::type))); +} + +// static +napi_status Type::ToNode(napi_env env, + mx::Device device, + napi_value* result) { + return ManagePointerInJSWrapper( + env, new mx::Device(std::move(device)), result); +} + +// static +std::optional Type::FromNode(napi_env env, + napi_value value) { + // Try creating a Device when value is a DeviceType. + auto type = ki::FromNode(env, value); + if (type) + return mx::Device(*type); + // Otherwise try converting from Device. + return NodeObjToCppValue(env, value); +} + +} // namespace ki + +void InitDevice(napi_env env, napi_value exports) { + ki::Set(env, exports, + "cpu", mx::Device::DeviceType::cpu, + "gpu", mx::Device::DeviceType::gpu, + "Device", ki::Class(), + "defaultDevice", mx::default_device, + "setDefaultDevice", mx::set_default_device); +} diff --git a/src/device.h b/src/device.h new file mode 100644 index 00000000..19081bb0 --- /dev/null +++ b/src/device.h @@ -0,0 +1,36 @@ +#ifndef SRC_DEVICE_H_ +#define SRC_DEVICE_H_ + +#include "src/bindings.h" + +namespace ki { + +template<> +struct Type { + static constexpr const char* name = "DeviceType"; + static napi_status ToNode(napi_env env, + mx::Device::DeviceType type, + napi_value* result); + static std::optional FromNode(napi_env env, + napi_value value); +}; + +template<> +struct Type { + static constexpr const char* name = "Device"; + + static mx::Device* Constructor(mx::Device::DeviceType type, int index); + static void Define(napi_env env, + napi_value constructor, + napi_value prototype); + + static napi_status ToNode(napi_env env, + mx::Device device, + napi_value* result); + static std::optional FromNode(napi_env env, + napi_value value); +}; + +} // namespace ki + +#endif // SRC_DEVICE_H_ diff --git a/src/stream.cc b/src/stream.cc new file mode 100644 index 00000000..e98ecfc7 --- /dev/null +++ b/src/stream.cc @@ -0,0 +1,61 @@ +#include "src/stream.h" +#include "src/util.h" + +namespace ki { + +template<> +struct TypeBridge { + static inline void Finalize(mx::Stream* ptr) { + delete ptr; + } +}; + +// static +mx::Stream* Type::Constructor(int index, const mx::Device& device) { + return new mx::Stream(index, device); +} + +// static +void Type::Define(napi_env env, + napi_value constructor, + napi_value prototype) { + DefineProperties(env, prototype, + Property("device", Getter(&mx::Stream::device))); +} + +// static +napi_status Type::ToNode(napi_env env, + mx::Stream stream, + napi_value* result) { + return ManagePointerInJSWrapper( + env, new mx::Stream(std::move(stream)), result); +} + +// static +std::optional Type::FromNode(napi_env env, + napi_value value) { + return NodeObjToCppValue(env, value); +} + +// static +std::optional Type::FromNode( + napi_env env, + napi_value value) { + std::optional stream = Type::FromNode(env, value); + if (stream) + return *stream; + std::optional device = Type::FromNode(env, value); + if (device) + return *device; + return std::nullopt; +} + +} // namespace ki + +void InitStream(napi_env env, napi_value exports) { + ki::Set(env, exports, + "Stream", ki::Class(), + "defaultStream", mx::default_stream, + "setDefaultStream", mx::set_default_stream, + "newStream", mx::new_stream); +} diff --git a/src/stream.h b/src/stream.h new file mode 100644 index 00000000..353bf09f --- /dev/null +++ b/src/stream.h @@ -0,0 +1,33 @@ +#ifndef SRC_STREAM_H_ +#define SRC_STREAM_H_ + +#include "src/device.h" + +namespace ki { + +template<> +struct Type { + static constexpr const char* name = "Stream"; + + static mx::Stream* Constructor(int index, const mx::Device& device); + static void Define(napi_env env, + napi_value constructor, + napi_value prototype); + + static napi_status ToNode(napi_env env, + mx::Stream stream, + napi_value* result); + static std::optional FromNode(napi_env env, + napi_value value); +}; + +template<> +struct Type { + static constexpr const char* name = "StreamOrDevice"; + static std::optional FromNode(napi_env env, + napi_value value); +}; + +} // namespace ki + +#endif // SRC_STREAM_H_ diff --git a/src/util.h b/src/util.h new file mode 100644 index 00000000..fc90a9c6 --- /dev/null +++ b/src/util.h @@ -0,0 +1,16 @@ +#ifndef SRC_UTIL_H_ +#define SRC_UTIL_H_ + +#include + +// In js land the objects are always stored as pointers, when a value is needed +// from C++ land, we do a copy. +template +inline std::optional NodeObjToCppValue(napi_env env, napi_value value) { + std::optional ptr = ki::FromNode(env, value); + if (!ptr) + return std::nullopt; + return *ptr.value(); +} + +#endif // SRC_UTIL_H_