Skip to content

Commit

Permalink
Add device and stream
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Apr 12, 2024
1 parent ce5a163 commit 13f2aaa
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 41 deletions.
2 changes: 1 addition & 1 deletion deps/kizunapi
84 changes: 44 additions & 40 deletions src/array.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
#include "mlx/mlx.h"
#include "src/bindings.h"

namespace mx = mlx::core;
#include "src/stream.h"
#include "src/util.h"

namespace ki {

// Allow passing Dtype to js directly, no memory management involved as they are
// static globals.
template<>
struct TypeBridge<mx::Dtype> {
static inline mx::Dtype* Wrap(mx::Dtype* ptr) {
return ptr;
}
static inline void Finalize(mx::Dtype* ptr) {
}
};

template<>
struct Type<mx::Dtype> {
static constexpr const char* name = "Dtype";
Expand All @@ -14,12 +23,12 @@ struct Type<mx::Dtype> {
DefineProperties(env, prototype,
Property("size", Getter(&mx::Dtype::size)));
}
// Since Dtype is represented as a class, we have to store it as a pointer in
// js, so converting it to js usually would involve a heap allocation. To
// avoid that let's just find the global const.
static inline napi_status ToNode(napi_env env,
const mx::Dtype& value,
napi_value* result) {
// Since Dtype is represented as a class, we have to store it as a pointer
// in js, so converting it to js usually would involve a heap allocation. To
// avoid that let's just find the global const.
if (value == mx::bool_)
return ConvertToNode(env, &mx::bool_, result);
if (value == mx::uint8)
Expand Down Expand Up @@ -48,30 +57,23 @@ struct Type<mx::Dtype> {
return ConvertToNode(env, &mx::complex64, result);
return napi_generic_failure;
}
// Dtype is stored as pointer in js, convert it to value in C++ by copy.
static inline std::optional<mx::Dtype> FromNode(napi_env env,
napi_value value) {
std::optional<mx::Dtype*> ptr = ki::FromNode<mx::Dtype*>(env, value);
if (!ptr)
return std::nullopt;
return *ptr.value();
return NodeObjToCppValue<mx::Dtype>(env, value);
}
};

// Allow passing Dtype to js directly, no memory management involved as they are
// static globals.
template<>
struct TypeBridge<mx::Dtype> {
static inline mx::Dtype* Wrap(mx::Dtype* ptr) {
return ptr;
}
static inline void Finalize(mx::Dtype* ptr) {
struct TypeBridge<mx::array> {
static inline void Finalize(mx::array* ptr) {
delete ptr;
}
};

template<>
struct Type<mx::array> {
static constexpr const char* name = "array";

static mx::array* Constructor(napi_env env,
napi_value value,
std::optional<mx::Dtype> dtype) {
Expand All @@ -89,9 +91,7 @@ struct Type<mx::array> {
return nullptr;
}
}
static inline void Destructor(mx::array* ptr) {
delete ptr;
}

static void Define(napi_env env,
napi_value constructor,
napi_value prototype) {
Expand All @@ -108,47 +108,51 @@ struct Type<mx::array> {
Property("dtype", Getter(&mx::array::dtype)));
// Define array's methods.
Set(env, prototype,
"item", MemberFunction(&Item));
"item", MemberFunction(&Item),
"astype", MemberFunction(&mx::astype));
}

static napi_value Item(mx::array* a, napi_env env) {
a->eval();
switch (a->dtype()) {
case mx::bool_:
return ToNode(env, a->item<bool>());
return ki::ToNode(env, a->item<bool>());
case mx::uint8:
return ToNode(env, a->item<uint8_t>());
return ki::ToNode(env, a->item<uint8_t>());
case mx::uint16:
return ToNode(env, a->item<uint16_t>());
return ki::ToNode(env, a->item<uint16_t>());
case mx::uint32:
return ToNode(env, a->item<uint32_t>());
return ki::ToNode(env, a->item<uint32_t>());
case mx::uint64:
return ToNode(env, a->item<uint64_t>());
return ki::ToNode(env, a->item<uint64_t>());
case mx::int8:
return ToNode(env, a->item<int8_t>());
return ki::ToNode(env, a->item<int8_t>());
case mx::int16:
return ToNode(env, a->item<int16_t>());
return ki::ToNode(env, a->item<int16_t>());
case mx::int32:
return ToNode(env, a->item<int32_t>());
return ki::ToNode(env, a->item<int32_t>());
case mx::int64:
return ToNode(env, a->item<int64_t>());
return ki::ToNode(env, a->item<int64_t>());
case mx::float16:
return ToNode(env, static_cast<float>(a->item<mx::float16_t>()));
return ki::ToNode(env, static_cast<float>(a->item<mx::float16_t>()));
case mx::float32:
return ToNode(env, a->item<float>());
return ki::ToNode(env, a->item<float>());
case mx::bfloat16:
return ToNode(env, static_cast<float>(a->item<mx::bfloat16_t>()));
return ki::ToNode(env, static_cast<float>(a->item<mx::bfloat16_t>()));
case mx::complex64:
// FIXME(zcbenz): Represent complex number in JS.
return Undefined(env);
}
}
// array is stored as pointer in js, convert it to value in C++ by copy.

static inline napi_status ToNode(napi_env env,
mx::array a,
napi_value* result) {
return ManagePointerInJSWrapper(env, new mx::array(std::move(a)), result);
}
static inline std::optional<mx::array> FromNode(napi_env env,
napi_value value) {
std::optional<mx::array*> ptr = ki::FromNode<mx::array*>(env, value);
if (!ptr)
return std::nullopt;
return *ptr.value();
return NodeObjToCppValue<mx::array>(env, value);
}
};

Expand Down
2 changes: 2 additions & 0 deletions src/bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
namespace {

napi_value Init(napi_env env, napi_value exports) {
InitDevice(env, exports);
InitStream(env, exports);
InitArray(env, exports);
return exports;
}
Expand Down
5 changes: 5 additions & 0 deletions src/bindings.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#ifndef SRC_BINDINGS_H_
#define SRC_BINDINGS_H_

#include <mlx/mlx.h>
#include <kizunapi.h>

namespace mx = mlx::core;

void InitDevice(napi_env env, napi_value exports);
void InitStream(napi_env env, napi_value exports);
void InitArray(napi_env env, napi_value exports);

#endif // SRC_BINDINGS_H_
74 changes: 74 additions & 0 deletions src/device.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include "src/device.h"
#include "src/util.h"

namespace ki {

template<>
struct TypeBridge<mx::Device> {
static inline void Finalize(mx::Device* ptr) {
delete ptr;
}
};

// static
napi_status Type<mx::Device::DeviceType>::ToNode(
napi_env env, mx::Device::DeviceType type, napi_value* result) {
return ConvertToNode(env, static_cast<int>(type), result);
}

// static
std::optional<mx::Device::DeviceType> Type<mx::Device::DeviceType>::FromNode(
napi_env env, napi_value value) {
std::optional<int> type = ki::FromNode<int>(env, value);
if (!type)
return std::nullopt;
if (*type == static_cast<int>(mx::Device::DeviceType::cpu))
return mx::Device::DeviceType::cpu;
if (*type == static_cast<int>(mx::Device::DeviceType::gpu))
return mx::Device::DeviceType::gpu;
return std::nullopt;
}

// static
mx::Device* Type<mx::Device>::Constructor(mx::Device::DeviceType type,
int index) {
return new mx::Device(type, index);
}

// static
void Type<mx::Device>::Define(napi_env env,
napi_value constructor,
napi_value prototype) {
DefineProperties(env, prototype,
Property("type", Getter(&mx::Device::type)));
}

// static
napi_status Type<mx::Device>::ToNode(napi_env env,
mx::Device device,
napi_value* result) {
return ManagePointerInJSWrapper(
env, new mx::Device(std::move(device)), result);
}

// static
std::optional<mx::Device> Type<mx::Device>::FromNode(napi_env env,
napi_value value) {
// Try creating a Device when value is a DeviceType.
auto type = ki::FromNode<mx::Device::DeviceType>(env, value);
if (type)
return mx::Device(*type);
// Otherwise try converting from Device.
return NodeObjToCppValue<mx::Device>(env, value);
}

} // namespace ki

void InitDevice(napi_env env, napi_value exports) {
ki::Set(env, exports,
"cpu", mx::Device::DeviceType::cpu,
"gpu", mx::Device::DeviceType::gpu,
"Device", ki::Class<mx::Device>(),
"defaultDevice", mx::default_device,
"setDefaultDevice", mx::set_default_device);
}
36 changes: 36 additions & 0 deletions src/device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef SRC_DEVICE_H_
#define SRC_DEVICE_H_

#include "src/bindings.h"

namespace ki {

template<>
struct Type<mx::Device::DeviceType> {
static constexpr const char* name = "DeviceType";
static napi_status ToNode(napi_env env,
mx::Device::DeviceType type,
napi_value* result);
static std::optional<mx::Device::DeviceType> FromNode(napi_env env,
napi_value value);
};

template<>
struct Type<mx::Device> {
static constexpr const char* name = "Device";

static mx::Device* Constructor(mx::Device::DeviceType type, int index);
static void Define(napi_env env,
napi_value constructor,
napi_value prototype);

static napi_status ToNode(napi_env env,
mx::Device device,
napi_value* result);
static std::optional<mx::Device> FromNode(napi_env env,
napi_value value);
};

} // namespace ki

#endif // SRC_DEVICE_H_
61 changes: 61 additions & 0 deletions src/stream.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "src/stream.h"
#include "src/util.h"

namespace ki {

template<>
struct TypeBridge<mx::Stream> {
static inline void Finalize(mx::Stream* ptr) {
delete ptr;
}
};

// static
mx::Stream* Type<mx::Stream>::Constructor(int index, const mx::Device& device) {
return new mx::Stream(index, device);
}

// static
void Type<mx::Stream>::Define(napi_env env,
napi_value constructor,
napi_value prototype) {
DefineProperties(env, prototype,
Property("device", Getter(&mx::Stream::device)));
}

// static
napi_status Type<mx::Stream>::ToNode(napi_env env,
mx::Stream stream,
napi_value* result) {
return ManagePointerInJSWrapper(
env, new mx::Stream(std::move(stream)), result);
}

// static
std::optional<mx::Stream> Type<mx::Stream>::FromNode(napi_env env,
napi_value value) {
return NodeObjToCppValue<mx::Stream>(env, value);
}

// static
std::optional<mx::StreamOrDevice> Type<mx::StreamOrDevice>::FromNode(
napi_env env,
napi_value value) {
std::optional<mx::Stream> stream = Type<mx::Stream>::FromNode(env, value);
if (stream)
return *stream;
std::optional<mx::Device> device = Type<mx::Device>::FromNode(env, value);
if (device)
return *device;
return std::nullopt;
}

} // namespace ki

void InitStream(napi_env env, napi_value exports) {
ki::Set(env, exports,
"Stream", ki::Class<mx::Stream>(),
"defaultStream", mx::default_stream,
"setDefaultStream", mx::set_default_stream,
"newStream", mx::new_stream);
}
33 changes: 33 additions & 0 deletions src/stream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef SRC_STREAM_H_
#define SRC_STREAM_H_

#include "src/device.h"

namespace ki {

template<>
struct Type<mx::Stream> {
static constexpr const char* name = "Stream";

static mx::Stream* Constructor(int index, const mx::Device& device);
static void Define(napi_env env,
napi_value constructor,
napi_value prototype);

static napi_status ToNode(napi_env env,
mx::Stream stream,
napi_value* result);
static std::optional<mx::Stream> FromNode(napi_env env,
napi_value value);
};

template<>
struct Type<mx::StreamOrDevice> {
static constexpr const char* name = "StreamOrDevice";
static std::optional<mx::StreamOrDevice> FromNode(napi_env env,
napi_value value);
};

} // namespace ki

#endif // SRC_STREAM_H_
Loading

0 comments on commit 13f2aaa

Please sign in to comment.