3
3
4
4
namespace {
5
5
6
+ // Unflatten the function call result.
7
+ inline napi_value UnflattenResults (napi_env env,
8
+ const std::vector<mx::array>& results) {
9
+ if (results.size () > 1 )
10
+ return ki::ToNodeValue (env, results);
11
+ else
12
+ return ki::ToNodeValue (env, results[0 ]);
13
+ }
14
+
6
15
// Execute JS function with primals.
7
- std::optional<std:: vector<mx::array> > ExecuteWithPrimals (
16
+ std::vector<mx::array> ExecuteWithPrimals (
8
17
napi_env env,
9
18
napi_value js_func,
10
19
const std::vector<mx::array>& primals) {
@@ -17,15 +26,15 @@ std::optional<std::vector<mx::array>> ExecuteWithPrimals(
17
26
if (napi_make_callback (env, nullptr , js_func, js_func,
18
27
args.size (), args.empty () ? nullptr : &args.front (),
19
28
&result) != napi_ok) {
20
- return std::nullopt ;
29
+ return {} ;
21
30
}
22
31
// Convert result to vector.
23
32
if (auto a = ki::FromNodeTo<mx::array*>(env, result); a)
24
33
return std::vector<mx::array>{*a.value ()};
25
34
if (auto v = ki::FromNodeTo<std::vector<mx::array>>(env, result); v)
26
35
return std::move (*v);
27
36
ki::ThrowError (env, " function does not return mx.array or Array of mx.array" );
28
- return std::nullopt ;
37
+ return {} ;
29
38
}
30
39
31
40
// A template converter for ops that accept infinite |array|s.
@@ -58,8 +67,7 @@ JVPOpWrapper(
58
67
std::vector<mx::array> primals,
59
68
std::vector<mx::array> tangents) {
60
69
auto vfunc = [env, js_func](const std::vector<mx::array>& primals) {
61
- return ExecuteWithPrimals (env, js_func, primals)
62
- .value_or (std::vector<mx::array>());
70
+ return ExecuteWithPrimals (env, js_func, primals);
63
71
};
64
72
return func (vfunc, primals, tangents);
65
73
};
@@ -75,34 +83,22 @@ ValueAndGrad(napi_env env,
75
83
std::optional<std::variant<int , std::vector<int >>> argnums) {
76
84
// Reference the JS function as napi_value only lives at current tick.
77
85
ki::Persistent js_func (env, value);
78
- // Get the indices of gradients.
79
- std::vector<int > indices = ToIntVector (std::move (argnums.value_or (0 )));
80
- bool multi_gradients = indices.size () > 1 ;
81
86
// Call value_and_grad with the JS function.
82
87
auto func = mx::value_and_grad (
83
88
[js_func = std::move (js_func)](const std::vector<mx::array>& primals) {
84
- return ExecuteWithPrimals (js_func.Env (), js_func.Value (), primals)
85
- .value_or (std::vector<mx::array>());
86
- }, std::move (indices));
89
+ return ExecuteWithPrimals (js_func.Env (), js_func.Value (), primals);
90
+ }, ToIntVector (std::move (argnums.value_or (0 ))));
87
91
// Return a JS function that converts JS args into primals.
88
- return [env, func = std::move (func), multi_gradients ](ki::Arguments* args) {
89
- std::pair<napi_value, napi_value> ret;
92
+ return [env, func = std::move (func)](ki::Arguments* args)
93
+ -> std::pair<napi_value, napi_value> {
90
94
std::vector<mx::array> arrays;
91
95
if (!ReadArgs (args, &arrays))
92
- return ret ;
96
+ return { nullptr , nullptr } ;
93
97
auto results = func (std::move (arrays));
94
98
if (ki::IsExceptionPending (env))
95
- return ret;
96
- // Unflatten the results.
97
- if (results.first .size () > 1 )
98
- ret.first = ki::ToNodeValue (env, results.first );
99
- else
100
- ret.first = ki::ToNodeValue (env, results.first [0 ]);
101
- if (multi_gradients)
102
- ret.second = ki::ToNodeValue (env, results.second );
103
- else
104
- ret.second = ki::ToNodeValue (env, results.second [0 ]);
105
- return ret;
99
+ return {nullptr , nullptr };
100
+ return {UnflattenResults (env, results.first ),
101
+ UnflattenResults (env, results.second )};
106
102
};
107
103
}
108
104
@@ -123,34 +119,22 @@ VMap(napi_env env,
123
119
std::optional<std::variant<int , std::vector<int >>> out_axes) {
124
120
// Reference the JS function as napi_value only lives at current tick.
125
121
ki::Persistent js_func (env, value);
126
- // Whether the function has multiple outputs.
127
- bool multi_outs = false ;
128
- if (out_axes) {
129
- auto v = std::get_if<std::vector<int >>(&out_axes.value ());
130
- multi_outs = v && v->size () > 1 ;
131
- }
132
122
// Call vmap with the JS function.
133
123
auto func = mx::vmap (
134
124
[js_func = std::move (js_func)](const std::vector<mx::array>& primals) {
135
- return ExecuteWithPrimals (js_func.Env (), js_func.Value (), primals)
136
- .value_or (std::vector<mx::array>());
125
+ return ExecuteWithPrimals (js_func.Env (), js_func.Value (), primals);
137
126
},
138
127
ToIntVector (std::move (in_axes.value_or (std::vector<int >()))),
139
128
ToIntVector (std::move (out_axes.value_or (std::vector<int >()))));
140
129
// Return a JS function that converts JS args into primals.
141
- return [env, func = std::move (func), multi_outs](ki::Arguments* args)
142
- -> napi_value {
130
+ return [env, func = std::move (func)](ki::Arguments* args) -> napi_value {
143
131
std::vector<mx::array> arrays;
144
132
if (!ReadArgs (args, &arrays))
145
133
return nullptr ;
146
134
auto results = func (std::move (arrays));
147
135
if (ki::IsExceptionPending (env))
148
136
return nullptr ;
149
- // Unflatten the results.
150
- if (multi_outs)
151
- return ki::ToNodeValue (env, results);
152
- else
153
- return ki::ToNodeValue (env, results[0 ]);
137
+ return UnflattenResults (env, results);
154
138
};
155
139
}
156
140
0 commit comments