Skip to content

Commit

Permalink
Add vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Apr 28, 2024
1 parent 39231f7 commit e8ffafe
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 12 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ There are a few exceptions due to limitations of JavaScript:

Some features are not supported yet and will be implemented in future:

* The function passed to `mx.grad` and `mx.valueAndGrad` must have all its
parameters being `mx.array`.
* The function passed to `mx.grad`/`mx.valueAndGrad`/`mx.vmap` must have all its
parameters taking `mx.array`.
* When creating a `mx.array` from JavaScript Array, the Array must only include
primitive values.

Expand Down
8 changes: 6 additions & 2 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ export class array {
square(s?: StreamOrDevice): array;
squeeze(axis?: number | number[], s?: StreamOrDevice): array;
swapaxes(axis1: number, axis2: number, s?: StreamOrDevice): array;
sum(s?: StreamOrDevice): array;
sum(keepdims?: boolean, s?: StreamOrDevice): array;
sum(axis?: number | number[], keepdims?: boolean, s?: StreamOrDevice): array;
transpose(s?: StreamOrDevice): array;
variance(indicesOrSections?: number | number[], keepdims?: boolean, ddof?: number, s?: StreamOrDevice): array;

Expand All @@ -163,6 +164,7 @@ export class array {
// Ops.
export function abs(array: ScalarOrArray, s?: StreamOrDevice): array;
export function add(array1: ScalarOrArray, array2: ScalarOrArray, s?: StreamOrDevice): array;
export function addmm(a: ScalarOrArray, b: ScalarOrArray, c: ScalarOrArray, alpha?: number, beta?: number, s?: StreamOrDevice): array;
export function all(array: ScalarOrArray, keepdims?: boolean, s?: StreamOrDevice): array;
export function all(array: ScalarOrArray, axis?: number | number[], keepdims?: boolean, s?: StreamOrDevice): array;
export function allclose(array1: ScalarOrArray, array2: ScalarOrArray, rtol?: number, atol?: number, equalNan?: boolean, s?: StreamOrDevice): boolean;
Expand All @@ -182,7 +184,7 @@ export function argmin(array: ScalarOrArray, keepdims?: boolean, s?: StreamOrDev
export function argmin(array: ScalarOrArray, axis?: number, keepdims?: boolean, s?: StreamOrDevice): array;
export function argpartition(array: ScalarOrArray, kth: number, axis?: number, s?: StreamOrDevice): array;
export function argsort(array: ScalarOrArray, s?: StreamOrDevice): array;
export function arrayEqual(array1: ScalarOrArray, array2: ScalarOrArray, s?: StreamOrDevice): array;
export function arrayEqual(array1: ScalarOrArray, array2: ScalarOrArray, equalNan?: boolean, s?: StreamOrDevice): array;
export function asStrided(array: ScalarOrArray, shape?: number[], strides?: number[], offset?: number, s?: StreamOrDevice): array;
export function atleast1d(...arrays: array[]): array;
export function atleast2d(...arrays: array[]): array;
Expand Down Expand Up @@ -321,6 +323,8 @@ type GradFunctionScalar = (...args: array[]) => array
type GradFunctionGeneric = (...args: array[]) => array[]
export function grad(func: (...args: array[]) => array, argnums?: number | number[]): GradFunctionScalar;
export function grad(func: (...args: array[]) => array[], argnums?: number | number[]): GradFunctionGeneric;
export function vmap(func: (...args: array[]) => array, inAxes?: number | number[], outAxis?: number): GradFunctionScalar;
export function vmap(func: (...args: array[]) => array[], inAxes?: number | number[], outAxes?: number[]): GradFunctionGeneric;

// Metal.
export namespace metal {
Expand Down
12 changes: 11 additions & 1 deletion src/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,16 @@ mx::array Tile(const mx::array& a,
return mx::tile(a, std::get<std::vector<int>>(reps), s);
}

mx::array AddMM(mx::array a,
mx::array b,
mx::array c,
std::optional<float> alpha,
std::optional<float> beta,
mx::StreamOrDevice s) {
return mx::addmm(std::move(a), std::move(b), std::move(c),
alpha.value_or(1), beta.value_or(1), s);
}

mx::array Diagonal(const mx::array& a,
std::optional<int> offset,
std::optional<int> axis1,
Expand Down Expand Up @@ -666,7 +676,7 @@ void InitOps(napi_env env, napi_value exports) {
"inner", &mx::inner,
"outer", &mx::outer,
"tile", &ops::Tile,
"addmm", &mx::addmm,
"addmm", &ops::AddMM,
"diagonal", &ops::Diagonal,
"diag", &ops::Diag,
"atleast1d", NdOpWrapper(&mx::atleast_1d, &mx::atleast_1d),
Expand Down
48 changes: 43 additions & 5 deletions src/transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,14 @@ ValueAndGrad(napi_env env,
// Reference the JS function as napi_value only lives at current tick.
ki::Persistent js_func(env, value);
// Get the indices of gradients.
std::vector<int> gradient_indices = ToIntVector(
std::move(argnums.value_or(std::vector<int>{0})));
bool multi_gradients = gradient_indices.size() > 1;
std::vector<int> indices = ToIntVector(std::move(argnums.value_or(0)));
bool multi_gradients = indices.size() > 1;
// Call value_and_grad with the JS function.
auto func = mx::value_and_grad(
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals)
.value_or(std::vector<mx::array>());
}, std::move(gradient_indices));
}, std::move(indices));
// Return a JS function that converts JS args into primals.
return [env, func = std::move(func), multi_gradients](ki::Arguments* args) {
std::pair<napi_value, napi_value> ret;
Expand Down Expand Up @@ -117,6 +116,44 @@ Grad(napi_env env,
};
}

std::function<napi_value(ki::Arguments*)>
VMap(napi_env env,
napi_value value,
std::optional<std::variant<int, std::vector<int>>> in_axes,
std::optional<std::variant<int, std::vector<int>>> out_axes) {
// Reference the JS function as napi_value only lives at current tick.
ki::Persistent js_func(env, value);
// Whether the function has multiple outputs.
bool multi_outs = false;
if (out_axes) {
auto v = std::get_if<std::vector<int>>(&out_axes.value());
multi_outs = v && v->size() > 1;
}
// Call vmap with the JS function.
auto func = mx::vmap(
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals)
.value_or(std::vector<mx::array>());
},
ToIntVector(std::move(in_axes.value_or(std::vector<int>()))),
ToIntVector(std::move(out_axes.value_or(std::vector<int>()))));
// Return a JS function that converts JS args into primals.
return [env, func = std::move(func), multi_outs](ki::Arguments* args)
-> napi_value {
std::vector<mx::array> arrays;
if (!ReadArgs(args, &arrays))
return nullptr;
auto results = func(std::move(arrays));
if (ki::IsExceptionPending(env))
return nullptr;
// Unflatten the results.
if (multi_outs)
return ki::ToNodeValue(env, results);
else
return ki::ToNodeValue(env, results[0]);
};
}

} // namespace transforms_ops

void InitTransforms(napi_env env, napi_value exports) {
Expand All @@ -126,5 +163,6 @@ void InitTransforms(napi_env env, napi_value exports) {
"jvp", JVPOpWrapper(&mx::jvp),
"vjp", JVPOpWrapper(&mx::vjp),
"valueAndGrad", &transforms_ops::ValueAndGrad,
"grad", &transforms_ops::Grad);
"grad", &transforms_ops::Grad,
"vmap", &transforms_ops::VMap);
}
7 changes: 5 additions & 2 deletions tests/fast.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ describe('fast', () => {
// TODO(zcbenz): Port the test_layer_norm_grad_params test.

it('fastTransforms', () => {
const x = mx.random.uniform(0, 1, [2, 2, 8]);
let x = mx.random.uniform(0, 1, [2, 2, 8]);

const defaults: [number, boolean, number, number, number] = [8, false, 10000.0, 1.0, 0];
const [dims, traditional, base, scale, offset] = defaults;
Expand All @@ -357,7 +357,10 @@ describe('fast', () => {
);
assertArrayAllTrue(mx.allclose(vjpOut[0], vjpFastOut[0]));

// TODO(zcbenz): Port the vmap test.
x = mx.random.uniform(0, 1, [2, 2, 2, 8]);
const vmapOut = mx.vmap(x => ropeOrig(x, ...defaults))(x);
const vmapFastOut = mx.vmap(x => mx.fast.rope(x, dims, traditional, base, scale, offset))(x);
assertArrayAllTrue(mx.allclose(vmapOut, vmapFastOut));
});
});

Expand Down
Loading

0 comments on commit e8ffafe

Please sign in to comment.