Skip to content

Commit d3b03f2

Browse files
committed
Support tree input in eval
1 parent 78c4dbd commit d3b03f2

File tree

5 files changed

+74
-7
lines changed

5 files changed

+74
-7
lines changed

deps/kizunapi

src/transforms.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "src/array.h"
22
#include "src/ops.h"
3+
#include "src/trees.h"
34

45
// Needed for detail::compile.
56
#include "mlx/transforms_impl.h"
@@ -45,9 +46,7 @@ inline
4546
std::function<void(ki::Arguments* args)>
4647
EvalOpWrapper(void(*func)(std::vector<mx::array>)) {
4748
return [func](ki::Arguments* args) {
48-
std::vector<mx::array> arrays;
49-
if (ReadArgs(args, &arrays))
50-
func(std::move(arrays));
49+
func(TreeFlatten(args));
5150
};
5251
}
5352

src/trees.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "src/array.h"
2+
#include "src/trees.h"
3+
4+
void TreeVisit(napi_env env, napi_value tree,
5+
const std::function<void(napi_env env,
6+
napi_value tree)>& callback) {
7+
std::function<void(napi_env env, napi_value value)> recurse;
8+
recurse = [&callback, &recurse](napi_env env, napi_value value) {
9+
// Iterate arrays.
10+
if (ki::IsArray(env, value)) {
11+
uint32_t length = 0;
12+
napi_get_array_length(env, value, &length);
13+
for (uint32_t i = 0; i < length; ++i) {
14+
napi_value item;
15+
if (napi_get_element(env, value, i, &item) != napi_ok)
16+
break;
17+
recurse(env, item);
18+
}
19+
return;
20+
}
21+
// Only iterate objects when they do not wrap a native instance.
22+
void* ptr;
23+
if (napi_unwrap(env, value, &ptr) != napi_ok) {
24+
auto m = ki::FromNodeTo<std::map<napi_value, napi_value>>(env, value);
25+
if (m) {
26+
for (auto [key, item] : *m)
27+
recurse(env, item);
28+
return;
29+
}
30+
}
31+
callback(env, value);
32+
};
33+
34+
recurse(env, tree);
35+
}
36+
37+
std::vector<mx::array> TreeFlatten(napi_env env, napi_value tree, bool strict) {
38+
std::vector<mx::array> flat;
39+
TreeVisit(env, tree, [strict, &flat](napi_env env, napi_value value) {
40+
if (auto a = ki::FromNodeTo<mx::array*>(env, value); a) {
41+
flat.push_back(*a.value());
42+
} else if (strict) {
43+
throw std::invalid_argument(
44+
"[TreeFlatten] The argument should contain only arrays");
45+
}
46+
});
47+
return flat;
48+
}
49+
50+
std::vector<mx::array> TreeFlatten(ki::Arguments* args, bool strict) {
51+
if (args->Length() == 1)
52+
return TreeFlatten(args->Env(), (*args)[0], strict);
53+
std::vector<mx::array> ret;
54+
for (uint32_t i = 0; i < args->Length(); ++i) {
55+
std::vector<mx::array> flat = TreeFlatten(args->Env(), (*args)[i], strict);
56+
std::move(flat.begin(), flat.end(), std::back_inserter(ret));
57+
}
58+
return ret;
59+
}

src/trees.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#include "src/bindings.h"
2+
3+
void TreeVisit(napi_env env, napi_value tree,
4+
const std::function<void(napi_env env,
5+
napi_value tree)>& callback);
6+
7+
std::vector<mx::array> TreeFlatten(napi_env env, napi_value tree,
8+
bool strict = false);
9+
std::vector<mx::array> TreeFlatten(ki::Arguments* args, bool strict = false);

tests/eval.spec.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {assert} from 'chai';
55
describe('eval', () => {
66
it('eval', () => {
77
const arrs: mx.array[] = [];
8-
for(let i = 0; i < 4; i++) {
8+
for (let i = 0; i < 4; i++) {
99
arrs.push(mx.ones([2, 2]));
1010
}
1111
mx.eval(...arrs);
@@ -29,8 +29,8 @@ describe('eval', () => {
2929
const one = mx.array(1);
3030
let x = mx.add(mx.add(one, 1), 1);
3131
let y = 0;
32-
let z = true;
33-
mx.eval(x, y, z);
32+
let z = 'hello' as unknown as number; // pass typecheck to test native code
33+
mx.eval([x, y, z]);
3434
assert.equal(x.item(), 3);
3535
});
3636

0 commit comments

Comments
 (0)