Skip to content

Commit

Permalink
Support tree input in eval
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Apr 29, 2024
1 parent 78c4dbd commit d3b03f2
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 7 deletions.
2 changes: 1 addition & 1 deletion deps/kizunapi
5 changes: 2 additions & 3 deletions src/transforms.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "src/array.h"
#include "src/ops.h"
#include "src/trees.h"

// Needed for detail::compile.
#include "mlx/transforms_impl.h"
Expand Down Expand Up @@ -45,9 +46,7 @@ inline
std::function<void(ki::Arguments* args)>
EvalOpWrapper(void(*func)(std::vector<mx::array>)) {
return [func](ki::Arguments* args) {
std::vector<mx::array> arrays;
if (ReadArgs(args, &arrays))
func(std::move(arrays));
func(TreeFlatten(args));
};
}

Expand Down
59 changes: 59 additions & 0 deletions src/trees.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "src/array.h"
#include "src/trees.h"

void TreeVisit(napi_env env, napi_value tree,
const std::function<void(napi_env env,
napi_value tree)>& callback) {
std::function<void(napi_env env, napi_value value)> recurse;
recurse = [&callback, &recurse](napi_env env, napi_value value) {
// Iterate arrays.
if (ki::IsArray(env, value)) {
uint32_t length = 0;
napi_get_array_length(env, value, &length);
for (uint32_t i = 0; i < length; ++i) {
napi_value item;
if (napi_get_element(env, value, i, &item) != napi_ok)
break;
recurse(env, item);
}
return;
}
// Only iterate objects when they do not wrap a native instance.
void* ptr;
if (napi_unwrap(env, value, &ptr) != napi_ok) {
auto m = ki::FromNodeTo<std::map<napi_value, napi_value>>(env, value);
if (m) {
for (auto [key, item] : *m)
recurse(env, item);
return;
}
}
callback(env, value);
};

recurse(env, tree);
}

std::vector<mx::array> TreeFlatten(napi_env env, napi_value tree, bool strict) {
std::vector<mx::array> flat;
TreeVisit(env, tree, [strict, &flat](napi_env env, napi_value value) {
if (auto a = ki::FromNodeTo<mx::array*>(env, value); a) {
flat.push_back(*a.value());
} else if (strict) {
throw std::invalid_argument(
"[TreeFlatten] The argument should contain only arrays");
}
});
return flat;
}

std::vector<mx::array> TreeFlatten(ki::Arguments* args, bool strict) {
if (args->Length() == 1)
return TreeFlatten(args->Env(), (*args)[0], strict);
std::vector<mx::array> ret;
for (uint32_t i = 0; i < args->Length(); ++i) {
std::vector<mx::array> flat = TreeFlatten(args->Env(), (*args)[i], strict);
std::move(flat.begin(), flat.end(), std::back_inserter(ret));
}
return ret;
}
9 changes: 9 additions & 0 deletions src/trees.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include "src/bindings.h"

void TreeVisit(napi_env env, napi_value tree,
const std::function<void(napi_env env,
napi_value tree)>& callback);

std::vector<mx::array> TreeFlatten(napi_env env, napi_value tree,
bool strict = false);
std::vector<mx::array> TreeFlatten(ki::Arguments* args, bool strict = false);
6 changes: 3 additions & 3 deletions tests/eval.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {assert} from 'chai';
describe('eval', () => {
it('eval', () => {
const arrs: mx.array[] = [];
for(let i = 0; i < 4; i++) {
for (let i = 0; i < 4; i++) {
arrs.push(mx.ones([2, 2]));
}
mx.eval(...arrs);
Expand All @@ -29,8 +29,8 @@ describe('eval', () => {
const one = mx.array(1);
let x = mx.add(mx.add(one, 1), 1);
let y = 0;
let z = true;
mx.eval(x, y, z);
let z = 'hello' as unknown as number; // pass typecheck to test native code
mx.eval([x, y, z]);
assert.equal(x.item(), 3);
});

Expand Down

0 comments on commit d3b03f2

Please sign in to comment.