Skip to content

Commit 6f4634f

Browse files
committed
Simplify results unflatten when calling js func
1 parent e8ffafe commit 6f4634f

File tree

2 files changed

+26
-45
lines changed

2 files changed

+26
-45
lines changed

src/transforms.cc

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,17 @@
33

44
namespace {
55

6+
// Unflatten the function call result.
7+
inline napi_value UnflattenResults(napi_env env,
8+
const std::vector<mx::array>& results) {
9+
if (results.size() > 1)
10+
return ki::ToNodeValue(env, results);
11+
else
12+
return ki::ToNodeValue(env, results[0]);
13+
}
14+
615
// Execute JS function with primals.
7-
std::optional<std::vector<mx::array>> ExecuteWithPrimals(
16+
std::vector<mx::array> ExecuteWithPrimals(
817
napi_env env,
918
napi_value js_func,
1019
const std::vector<mx::array>& primals) {
@@ -17,15 +26,15 @@ std::optional<std::vector<mx::array>> ExecuteWithPrimals(
1726
if (napi_make_callback(env, nullptr, js_func, js_func,
1827
args.size(), args.empty() ? nullptr : &args.front(),
1928
&result) != napi_ok) {
20-
return std::nullopt;
29+
return {};
2130
}
2231
// Convert result to vector.
2332
if (auto a = ki::FromNodeTo<mx::array*>(env, result); a)
2433
return std::vector<mx::array>{*a.value()};
2534
if (auto v = ki::FromNodeTo<std::vector<mx::array>>(env, result); v)
2635
return std::move(*v);
2736
ki::ThrowError(env, "function does not return mx.array or Array of mx.array");
28-
return std::nullopt;
37+
return {};
2938
}
3039

3140
// A template converter for ops that accept infinite |array|s.
@@ -58,8 +67,7 @@ JVPOpWrapper(
5867
std::vector<mx::array> primals,
5968
std::vector<mx::array> tangents) {
6069
auto vfunc = [env, js_func](const std::vector<mx::array>& primals) {
61-
return ExecuteWithPrimals(env, js_func, primals)
62-
.value_or(std::vector<mx::array>());
70+
return ExecuteWithPrimals(env, js_func, primals);
6371
};
6472
return func(vfunc, primals, tangents);
6573
};
@@ -75,34 +83,22 @@ ValueAndGrad(napi_env env,
7583
std::optional<std::variant<int, std::vector<int>>> argnums) {
7684
// Reference the JS function as napi_value only lives at current tick.
7785
ki::Persistent js_func(env, value);
78-
// Get the indices of gradients.
79-
std::vector<int> indices = ToIntVector(std::move(argnums.value_or(0)));
80-
bool multi_gradients = indices.size() > 1;
8186
// Call value_and_grad with the JS function.
8287
auto func = mx::value_and_grad(
8388
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
84-
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals)
85-
.value_or(std::vector<mx::array>());
86-
}, std::move(indices));
89+
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals);
90+
}, ToIntVector(std::move(argnums.value_or(0))));
8791
// Return a JS function that converts JS args into primals.
88-
return [env, func = std::move(func), multi_gradients](ki::Arguments* args) {
89-
std::pair<napi_value, napi_value> ret;
92+
return [env, func = std::move(func)](ki::Arguments* args)
93+
-> std::pair<napi_value, napi_value> {
9094
std::vector<mx::array> arrays;
9195
if (!ReadArgs(args, &arrays))
92-
return ret;
96+
return {nullptr, nullptr};
9397
auto results = func(std::move(arrays));
9498
if (ki::IsExceptionPending(env))
95-
return ret;
96-
// Unflatten the results.
97-
if (results.first.size() > 1)
98-
ret.first = ki::ToNodeValue(env, results.first);
99-
else
100-
ret.first = ki::ToNodeValue(env, results.first[0]);
101-
if (multi_gradients)
102-
ret.second = ki::ToNodeValue(env, results.second);
103-
else
104-
ret.second = ki::ToNodeValue(env, results.second[0]);
105-
return ret;
99+
return {nullptr, nullptr};
100+
return {UnflattenResults(env, results.first),
101+
UnflattenResults(env, results.second)};
106102
};
107103
}
108104

@@ -123,34 +119,22 @@ VMap(napi_env env,
123119
std::optional<std::variant<int, std::vector<int>>> out_axes) {
124120
// Reference the JS function as napi_value only lives at current tick.
125121
ki::Persistent js_func(env, value);
126-
// Whether the function has multiple outputs.
127-
bool multi_outs = false;
128-
if (out_axes) {
129-
auto v = std::get_if<std::vector<int>>(&out_axes.value());
130-
multi_outs = v && v->size() > 1;
131-
}
132122
// Call vmap with the JS function.
133123
auto func = mx::vmap(
134124
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
135-
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals)
136-
.value_or(std::vector<mx::array>());
125+
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals);
137126
},
138127
ToIntVector(std::move(in_axes.value_or(std::vector<int>()))),
139128
ToIntVector(std::move(out_axes.value_or(std::vector<int>()))));
140129
// Return a JS function that converts JS args into primals.
141-
return [env, func = std::move(func), multi_outs](ki::Arguments* args)
142-
-> napi_value {
130+
return [env, func = std::move(func)](ki::Arguments* args) -> napi_value {
143131
std::vector<mx::array> arrays;
144132
if (!ReadArgs(args, &arrays))
145133
return nullptr;
146134
auto results = func(std::move(arrays));
147135
if (ki::IsExceptionPending(env))
148136
return nullptr;
149-
// Unflatten the results.
150-
if (multi_outs)
151-
return ki::ToNodeValue(env, results);
152-
else
153-
return ki::ToNodeValue(env, results[0]);
137+
return UnflattenResults(env, results);
154138
};
155139
}
156140

tests/vmap.spec.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,7 @@ describe('vmap', () => {
245245
const a = mx.random.uniform(0, 1, [3, 4, 2]);
246246
const cpuSvd = x => mx.linalg.svd(x, mx.cpu);
247247

248-
// FIXME(zcbenz): Since tree flatten is not supported yet, the results of
249-
// svd is treated as returning 3 arrays instead of one Array. Specify out
250-
// axes explicitly to make test work.
251-
let [Us, Ss, Vts] = mx.vmap(cpuSvd, 0, [0, 0, 0])(a);
248+
let [Us, Ss, Vts] = mx.vmap(cpuSvd, 0)(a);
252249
assert.deepEqual(Us.shape, [a.shape[0], a.shape[1], a.shape[1]]);
253250
assert.deepEqual(Ss.shape, [a.shape[0], a.shape[2]]);
254251
assert.deepEqual(Vts.shape, [a.shape[0], a.shape[2], a.shape[2]]);
@@ -263,7 +260,7 @@ describe('vmap', () => {
263260
M, 1e-5, 1e-7));
264261
}
265262

266-
[Us, Ss, Vts] = mx.vmap(cpuSvd, 1, [0, 0, 0])(a);
263+
[Us, Ss, Vts] = mx.vmap(cpuSvd, 1)(a);
267264
assert.deepEqual(Us.shape, [a.shape[1], a.shape[0], a.shape[0]]);
268265
assert.deepEqual(Ss.shape, [a.shape[1], a.shape[2]]);
269266
assert.deepEqual(Vts.shape, [a.shape[1], a.shape[2], a.shape[2]]);

0 commit comments

Comments
 (0)