Skip to content

Commit

Permalink
Simplify results unflatten when calling js func
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Apr 28, 2024
1 parent e8ffafe commit 6f4634f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 45 deletions.
64 changes: 24 additions & 40 deletions src/transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@

namespace {

// Unflatten the function call result.
inline napi_value UnflattenResults(napi_env env,
const std::vector<mx::array>& results) {
if (results.size() > 1)
return ki::ToNodeValue(env, results);
else
return ki::ToNodeValue(env, results[0]);
}

// Execute JS function with primals.
std::optional<std::vector<mx::array>> ExecuteWithPrimals(
std::vector<mx::array> ExecuteWithPrimals(
napi_env env,
napi_value js_func,
const std::vector<mx::array>& primals) {
Expand All @@ -17,15 +26,15 @@ std::optional<std::vector<mx::array>> ExecuteWithPrimals(
if (napi_make_callback(env, nullptr, js_func, js_func,
args.size(), args.empty() ? nullptr : &args.front(),
&result) != napi_ok) {
return std::nullopt;
return {};
}
// Convert result to vector.
if (auto a = ki::FromNodeTo<mx::array*>(env, result); a)
return std::vector<mx::array>{*a.value()};
if (auto v = ki::FromNodeTo<std::vector<mx::array>>(env, result); v)
return std::move(*v);
ki::ThrowError(env, "function does not return mx.array or Array of mx.array");
return std::nullopt;
return {};
}

// A template converter for ops that accept infinite |array|s.
Expand Down Expand Up @@ -58,8 +67,7 @@ JVPOpWrapper(
std::vector<mx::array> primals,
std::vector<mx::array> tangents) {
auto vfunc = [env, js_func](const std::vector<mx::array>& primals) {
return ExecuteWithPrimals(env, js_func, primals)
.value_or(std::vector<mx::array>());
return ExecuteWithPrimals(env, js_func, primals);
};
return func(vfunc, primals, tangents);
};
Expand All @@ -75,34 +83,22 @@ ValueAndGrad(napi_env env,
std::optional<std::variant<int, std::vector<int>>> argnums) {
// Reference the JS function as napi_value only lives at current tick.
ki::Persistent js_func(env, value);
// Get the indices of gradients.
std::vector<int> indices = ToIntVector(std::move(argnums.value_or(0)));
bool multi_gradients = indices.size() > 1;
// Call value_and_grad with the JS function.
auto func = mx::value_and_grad(
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals)
.value_or(std::vector<mx::array>());
}, std::move(indices));
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals);
}, ToIntVector(std::move(argnums.value_or(0))));
// Return a JS function that converts JS args into primals.
return [env, func = std::move(func), multi_gradients](ki::Arguments* args) {
std::pair<napi_value, napi_value> ret;
return [env, func = std::move(func)](ki::Arguments* args)
-> std::pair<napi_value, napi_value> {
std::vector<mx::array> arrays;
if (!ReadArgs(args, &arrays))
return ret;
return {nullptr, nullptr};
auto results = func(std::move(arrays));
if (ki::IsExceptionPending(env))
return ret;
// Unflatten the results.
if (results.first.size() > 1)
ret.first = ki::ToNodeValue(env, results.first);
else
ret.first = ki::ToNodeValue(env, results.first[0]);
if (multi_gradients)
ret.second = ki::ToNodeValue(env, results.second);
else
ret.second = ki::ToNodeValue(env, results.second[0]);
return ret;
return {nullptr, nullptr};
return {UnflattenResults(env, results.first),
UnflattenResults(env, results.second)};
};
}

Expand All @@ -123,34 +119,22 @@ VMap(napi_env env,
std::optional<std::variant<int, std::vector<int>>> out_axes) {
// Reference the JS function as napi_value only lives at current tick.
ki::Persistent js_func(env, value);
// Whether the function has multiple outputs.
bool multi_outs = false;
if (out_axes) {
auto v = std::get_if<std::vector<int>>(&out_axes.value());
multi_outs = v && v->size() > 1;
}
// Call vmap with the JS function.
auto func = mx::vmap(
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals)
.value_or(std::vector<mx::array>());
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals);
},
ToIntVector(std::move(in_axes.value_or(std::vector<int>()))),
ToIntVector(std::move(out_axes.value_or(std::vector<int>()))));
// Return a JS function that converts JS args into primals.
return [env, func = std::move(func), multi_outs](ki::Arguments* args)
-> napi_value {
return [env, func = std::move(func)](ki::Arguments* args) -> napi_value {
std::vector<mx::array> arrays;
if (!ReadArgs(args, &arrays))
return nullptr;
auto results = func(std::move(arrays));
if (ki::IsExceptionPending(env))
return nullptr;
// Unflatten the results.
if (multi_outs)
return ki::ToNodeValue(env, results);
else
return ki::ToNodeValue(env, results[0]);
return UnflattenResults(env, results);
};
}

Expand Down
7 changes: 2 additions & 5 deletions tests/vmap.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,7 @@ describe('vmap', () => {
const a = mx.random.uniform(0, 1, [3, 4, 2]);
const cpuSvd = x => mx.linalg.svd(x, mx.cpu);

// FIXME(zcbenz): Since tree flatten is not supported yet, the results of
// svd is treated as returning 3 arrays instead of one Array. Specify out
// axes explicitly to make test work.
let [Us, Ss, Vts] = mx.vmap(cpuSvd, 0, [0, 0, 0])(a);
let [Us, Ss, Vts] = mx.vmap(cpuSvd, 0)(a);
assert.deepEqual(Us.shape, [a.shape[0], a.shape[1], a.shape[1]]);
assert.deepEqual(Ss.shape, [a.shape[0], a.shape[2]]);
assert.deepEqual(Vts.shape, [a.shape[0], a.shape[2], a.shape[2]]);
Expand All @@ -263,7 +260,7 @@ describe('vmap', () => {
M, 1e-5, 1e-7));
}

[Us, Ss, Vts] = mx.vmap(cpuSvd, 1, [0, 0, 0])(a);
[Us, Ss, Vts] = mx.vmap(cpuSvd, 1)(a);
assert.deepEqual(Us.shape, [a.shape[1], a.shape[0], a.shape[0]]);
assert.deepEqual(Ss.shape, [a.shape[1], a.shape[2]]);
assert.deepEqual(Vts.shape, [a.shape[1], a.shape[2], a.shape[2]]);
Expand Down

0 comments on commit 6f4634f

Please sign in to comment.