diff --git a/README.md b/README.md index 4862f786..2c9e9257 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,8 @@ There are a few exceptions due to limitations of JavaScript: Some features are not supported yet and will be implemented in future: -* The function passed to `mx.grad` and `mx.valueAndGrad` must have all its - parameters being `mx.array`. +* The function passed to `mx.grad`/`mx.valueAndGrad`/`mx.vmap` must only take + parameters of type `mx.array`. * When creating a `mx.array` from JavaScript Array, the Array must only include primitive values. diff --git a/lib/index.d.ts b/lib/index.d.ts index 656bf6f9..eb97c44e 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -151,7 +151,8 @@ export class array { square(s?: StreamOrDevice): array; squeeze(axis?: number | number[], s?: StreamOrDevice): array; swapaxes(axis1: number, axis2: number, s?: StreamOrDevice): array; - sum(s?: StreamOrDevice): array; + sum(keepdims?: boolean, s?: StreamOrDevice): array; + sum(axis?: number | number[], keepdims?: boolean, s?: StreamOrDevice): array; transpose(s?: StreamOrDevice): array; variance(indicesOrSections?: number | number[], keepdims?: boolean, ddof?: number, s?: StreamOrDevice): array; @@ -163,6 +164,7 @@ export class array { // Ops. 
export function abs(array: ScalarOrArray, s?: StreamOrDevice): array; export function add(array1: ScalarOrArray, array2: ScalarOrArray, s?: StreamOrDevice): array; +export function addmm(a: ScalarOrArray, b: ScalarOrArray, c: ScalarOrArray, alpha?: number, beta?: number, s?: StreamOrDevice): array; export function all(array: ScalarOrArray, keepdims?: boolean, s?: StreamOrDevice): array; export function all(array: ScalarOrArray, axis?: number | number[], keepdims?: boolean, s?: StreamOrDevice): array; export function allclose(array1: ScalarOrArray, array2: ScalarOrArray, rtol?: number, atol?: number, equalNan?: boolean, s?: StreamOrDevice): boolean; @@ -182,7 +184,7 @@ export function argmin(array: ScalarOrArray, keepdims?: boolean, s?: StreamOrDev export function argmin(array: ScalarOrArray, axis?: number, keepdims?: boolean, s?: StreamOrDevice): array; export function argpartition(array: ScalarOrArray, kth: number, axis?: number, s?: StreamOrDevice): array; export function argsort(array: ScalarOrArray, s?: StreamOrDevice): array; -export function arrayEqual(array1: ScalarOrArray, array2: ScalarOrArray, s?: StreamOrDevice): array; +export function arrayEqual(array1: ScalarOrArray, array2: ScalarOrArray, equalNan?: boolean, s?: StreamOrDevice): array; export function asStrided(array: ScalarOrArray, shape?: number[], strides?: number[], offset?: number, s?: StreamOrDevice): array; export function atleast1d(...arrays: array[]): array; export function atleast2d(...arrays: array[]): array; @@ -321,6 +323,8 @@ type GradFunctionScalar = (...args: array[]) => array type GradFunctionGeneric = (...args: array[]) => array[] export function grad(func: (...args: array[]) => array, argnums?: number | number[]): GradFunctionScalar; export function grad(func: (...args: array[]) => array[], argnums?: number | number[]): GradFunctionGeneric; +export function vmap(func: (...args: array[]) => array, inAxes?: number | number[], outAxis?: number): GradFunctionScalar; +export function 
vmap(func: (...args: array[]) => array[], inAxes?: number | number[], outAxes?: number[]): GradFunctionGeneric; // Metal. export namespace metal { diff --git a/src/ops.cc b/src/ops.cc index 9a915f84..86498d24 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -508,6 +508,16 @@ mx::array Tile(const mx::array& a, return mx::tile(a, std::get>(reps), s); } +mx::array AddMM(mx::array a, + mx::array b, + mx::array c, + std::optional alpha, + std::optional beta, + mx::StreamOrDevice s) { + return mx::addmm(std::move(a), std::move(b), std::move(c), + alpha.value_or(1), beta.value_or(1), s); +} + mx::array Diagonal(const mx::array& a, std::optional offset, std::optional axis1, @@ -666,7 +676,7 @@ void InitOps(napi_env env, napi_value exports) { "inner", &mx::inner, "outer", &mx::outer, "tile", &ops::Tile, - "addmm", &mx::addmm, + "addmm", &ops::AddMM, "diagonal", &ops::Diagonal, "diag", &ops::Diag, "atleast1d", NdOpWrapper(&mx::atleast_1d, &mx::atleast_1d), diff --git a/src/transforms.cc b/src/transforms.cc index 761c89a0..97581d54 100644 --- a/src/transforms.cc +++ b/src/transforms.cc @@ -76,15 +76,14 @@ ValueAndGrad(napi_env env, // Reference the JS function as napi_value only lives at current tick. ki::Persistent js_func(env, value); // Get the indices of gradients. - std::vector gradient_indices = ToIntVector( - std::move(argnums.value_or(std::vector{0}))); - bool multi_gradients = gradient_indices.size() > 1; + std::vector indices = ToIntVector(std::move(argnums.value_or(0))); + bool multi_gradients = indices.size() > 1; // Call value_and_grad with the JS function. auto func = mx::value_and_grad( [js_func = std::move(js_func)](const std::vector& primals) { return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals) .value_or(std::vector()); - }, std::move(gradient_indices)); + }, std::move(indices)); // Return a JS function that converts JS args into primals. 
return [env, func = std::move(func), multi_gradients](ki::Arguments* args) { std::pair ret; @@ -117,6 +116,44 @@ Grad(napi_env env, }; } +std::function +VMap(napi_env env, + napi_value value, + std::optional>> in_axes, + std::optional>> out_axes) { + // Reference the JS function as napi_value only lives at current tick. + ki::Persistent js_func(env, value); + // Whether the function has multiple outputs. + bool multi_outs = false; + if (out_axes) { + auto v = std::get_if>(&out_axes.value()); + multi_outs = v && v->size() > 1; + } + // Call vmap with the JS function. + auto func = mx::vmap( + [js_func = std::move(js_func)](const std::vector& primals) { + return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals) + .value_or(std::vector()); + }, + ToIntVector(std::move(in_axes.value_or(std::vector()))), + ToIntVector(std::move(out_axes.value_or(std::vector())))); + // Return a JS function that converts JS args into primals. + return [env, func = std::move(func), multi_outs](ki::Arguments* args) + -> napi_value { + std::vector arrays; + if (!ReadArgs(args, &arrays)) + return nullptr; + auto results = func(std::move(arrays)); + if (ki::IsExceptionPending(env)) + return nullptr; + // Unflatten the results. + if (multi_outs) + return ki::ToNodeValue(env, results); + else + return ki::ToNodeValue(env, results[0]); + }; +} + } // namespace transforms_ops void InitTransforms(napi_env env, napi_value exports) { @@ -126,5 +163,6 @@ void InitTransforms(napi_env env, napi_value exports) { "jvp", JVPOpWrapper(&mx::jvp), "vjp", JVPOpWrapper(&mx::vjp), "valueAndGrad", &transforms_ops::ValueAndGrad, - "grad", &transforms_ops::Grad); + "grad", &transforms_ops::Grad, + "vmap", &transforms_ops::VMap); } diff --git a/tests/fast.spec.ts b/tests/fast.spec.ts index 9210344b..22aa6e0a 100644 --- a/tests/fast.spec.ts +++ b/tests/fast.spec.ts @@ -336,7 +336,7 @@ describe('fast', () => { // TODO(zcbenz): Port the test_layer_norm_grad_params test. 
it('fastTransforms', () => { - const x = mx.random.uniform(0, 1, [2, 2, 8]); + let x = mx.random.uniform(0, 1, [2, 2, 8]); const defaults: [number, boolean, number, number, number] = [8, false, 10000.0, 1.0, 0]; const [dims, traditional, base, scale, offset] = defaults; @@ -357,7 +357,10 @@ describe('fast', () => { ); assertArrayAllTrue(mx.allclose(vjpOut[0], vjpFastOut[0])); - // TODO(zcbenz): Port the vmap test. + x = mx.random.uniform(0, 1, [2, 2, 2, 8]); + const vmapOut = mx.vmap(x => ropeOrig(x, ...defaults))(x); + const vmapFastOut = mx.vmap(x => mx.fast.rope(x, dims, traditional, base, scale, offset))(x); + assertArrayAllTrue(mx.allclose(vmapOut, vmapFastOut)); }); }); diff --git a/tests/vmap.spec.ts b/tests/vmap.spec.ts new file mode 100644 index 00000000..7965a7a3 --- /dev/null +++ b/tests/vmap.spec.ts @@ -0,0 +1,305 @@ +import mx from '..'; +import {assertArrayAllTrue} from './utils'; +import {assert} from 'chai'; + +describe('vmap', () => { + it('basics', () => { + const exp = x => mx.exp(x); + assert.throws(() => { + mx.vmap(exp)(mx.array(1.0)); + }, Error); + assert.throws(() => { + mx.vmap(exp, 2)(mx.array([0, 1])); + }, Error); + assert.throws(() => { + mx.vmap(exp, null, 2)(mx.array([0, 1])); + }, Error); + }); + + it('unary', () => { + const ops = [ + mx.abs, + mx.cos, + mx.erf, + mx.erfinv, + mx.exp, + mx.log, + mx.log1p, + mx.log2, + mx.log10, + mx.logicalNot, + mx.negative, + mx.reciprocal, + mx.rsqrt, + mx.sigmoid, + mx.sign, + mx.sin, + mx.sqrt, + mx.square, + ]; + + for (const op of ops) { + let x = mx.arange(5); + let y = mx.vmap(x => op(x))(x); + assertArrayAllTrue(mx.arrayEqual(y, op(x), true)); + + x = mx.arange(8).reshape([2, 4]); + y = mx.vmap(x => op(x))(x); + assertArrayAllTrue(mx.arrayEqual(y, op(x), true)); + + y = mx.vmap(x => op(x), 1, 1)(x); + assertArrayAllTrue(mx.arrayEqual(y, op(x), true)); + }; + }); + + it('binary', () => { + const ops = [ + mx.add, + mx.divide, + mx.equal, + mx.greater, + mx.greaterEqual, + mx.less, + 
mx.lessEqual, + mx.logaddexp, + mx.maximum, + mx.minimum, + mx.multiply, + mx.power, + mx.subtract, + mx.logicalOr, + mx.logicalAnd, + ]; + for (const op of ops) { + let x = mx.random.uniform(0, 1, [5]); + let y = mx.random.uniform(0, 1, [5]); + let out = mx.vmap((needle, haystack) => op(needle, haystack))(x, y); + assertArrayAllTrue(mx.arrayEqual(out, op(x, y))); + + x = mx.random.uniform(0, 1, [2, 4]); + y = mx.random.uniform(0, 1, [2, 4]); + out = mx.vmap((needle, haystack) => op(needle, haystack))(x, y); + assertArrayAllTrue(mx.arrayEqual(out, op(x, y))); + + out = mx.vmap((needle, haystack) => op(needle, haystack), [0, 0], 0)(x, y); + assertArrayAllTrue(mx.arrayEqual(out, op(x, y))); + + y = mx.random.uniform(0, 1, [4, 2]); + out = mx.vmap((needle, haystack) => op(needle, haystack), [0, 1], 0)(x, y); + assertArrayAllTrue(mx.arrayEqual(out, op(x, y.T))); + + out = mx.vmap((needle, haystack) => op(needle, haystack), [0, 1], 1)(x, y); + assertArrayAllTrue(mx.arrayEqual(out, op(x, y.T).T)); + } + }); + + it('vmapIndexing', () => { + const x = mx.arange(16).reshape([2, 2, 2, 2]); + const inds = mx.array([[0, 1, 0], [1, 1, 0]], mx.int32); + + let out = mx.vmap((x, y) => x.index(y), [0, 0])(x, inds); + const expected = mx.array( + [ + [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]], + [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]], + ] + ); + assertArrayAllTrue(mx.arrayEqual(out, expected)); + + out = mx.vmap((x, y) => x.index(y), [0, null])(x, inds); + const expected2 = mx.array( + [ + [ + [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]], + [[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]], + ], + [ + [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]], + [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]], + ], + ] + ); + assertArrayAllTrue(mx.arrayEqual(out, expected2)); + + out = mx.vmap((x, y) => x.index(y), [null, 0])(x, inds); + const expected3 = mx.array( + [ + [ + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + [[[8, 9], [10, 
11]], [[12, 13], [14, 15]]], + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + ], + [ + [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], + [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + ], + ] + ); + assertArrayAllTrue(mx.arrayEqual(out, expected3)); + + const inds2 = mx.array([[0, 1, 0], [0, 1, 0]], mx.int32); + out = mx.vmap((x, y, z) => x.index(y, z), [null, 0, 0])(x, inds, inds2); + const expected4 = mx.array( + [ + [[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]], + [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]], + ] + ); + assertArrayAllTrue(mx.arrayEqual(out, expected4)); + }); + + it('vmapReduce', () => { + let a = mx.ones([5, 5], mx.int32); + let out = mx.vmap((x) => x.sum())(a); + assertArrayAllTrue(mx.arrayEqual(out, mx.full([5], 5))); + + out = mx.vmap((x) => x.sum(null, true))(a); + assertArrayAllTrue(mx.arrayEqual(out, mx.full([5, 1], 5))); + + out = mx.vmap((x) => x.sum(0))(a); + assertArrayAllTrue(mx.arrayEqual(out, mx.full([5], 5))); + + a = mx.ones([5, 3, 2], mx.int32); + out = mx.vmap((x) => x.sum([0, 1]))(a); + assertArrayAllTrue(mx.arrayEqual(out, mx.full([5], 6))); + + a = mx.ones([5, 3, 2], mx.int32); + out = mx.vmap((x) => x.sum([0, 1]), [1])(a); + assertArrayAllTrue(mx.arrayEqual(out, mx.full([3], 10))); + + a = mx.ones([5, 3, 2], mx.int32); + out = mx.vmap((x) => x.sum([0, 1]), [2])(a); + assertArrayAllTrue(mx.arrayEqual(out, mx.full([2], 15))); + }); + + it('vmapArgreduce', () => { + const a = mx.array([[1, 2, 3], [2, 3, 1]]); + let out = mx.vmap(x => mx.argmin(x))(a); + let expected = mx.array([0, 2]); + assertArrayAllTrue(mx.arrayEqual(out, expected)); + + out = mx.vmap(x => mx.argmax(x))(a); + expected = mx.array([2, 1]); + assertArrayAllTrue(mx.arrayEqual(out, expected)); + }); + + it('vmapMean', () => { + let a = mx.reshape(mx.arange(8), [2, 4]); + let out = mx.vmap(x => mx.mean(x))(a); + let expected = mx.mean(a, 1); + assertArrayAllTrue(mx.allclose(out, expected)); + + a = mx.reshape(mx.arange(16), [2, 2, 
4]); + out = mx.vmap(mx.vmap(x => mx.mean(x)))(a); + expected = mx.mean(a, 2); + assertArrayAllTrue(mx.allclose(out, expected)); + }); + + it('mismatchInputSizes', () => { + const a = mx.ones([10, 1]); + let b = mx.ones([1, 1, 1, 5]); + assert.throws(() => { + let out = mx.vmap((x, y) => mx.add(x, y))(a, b); + }, Error); + + b = mx.ones([10, 5]); + assert.throws(() => { + let out = mx.vmap((x, y) => mx.add(x, y), [0, 1])(a, b); + }, Error); + }); + + it('vmapMatmul', () => { + let a = mx.random.uniform(0, 1, [2, 3, 4]); + let b = mx.random.uniform(0, 1, [4, 3]); + + let out = mx.vmap((a, b) => mx.matmul(a, b), [0, -1])(a, b); + assertArrayAllTrue(mx.allclose(out, mx.matmul(a, b))); + + let c = mx.random.uniform(0, 1, [3]); + out = mx.vmap((c, a, b) => mx.addmm(c, a, b), [-1, 0, -1])(c, a, b); + assertArrayAllTrue(mx.allclose(out, mx.addmm(c, a, b))); + + b = mx.random.uniform(0, 1, [4, 2]); + out = mx.vmap((a, b) => mx.matmul(a, b), [1, -1], 1)(a, b); + let expected = mx.moveaxis(mx.matmul(mx.moveaxis(a, 1, 0), b), 0, 1); + assertArrayAllTrue(mx.allclose(out, expected)); + + c = mx.random.uniform(0, 1, [2]); + out = mx.vmap((c, a, b) => mx.addmm(c, a, b), [-1, 1, -1])(c, a, b); + assertArrayAllTrue(mx.allclose(out, mx.addmm(c, mx.moveaxis(a, 1, 0), b))); + + a = mx.random.uniform(0, 1, [2, 3, 4]); + b = mx.random.uniform(0, 1, [4, 2, 3]); + out = mx.vmap((a, b) => mx.matmul(a, b), [0, 1])(a, b); + expected = mx.matmul(a, mx.moveaxis(b, 1, 0)); + assertArrayAllTrue(mx.allclose(out, expected)); + + c = mx.random.uniform(0, 1, [3, 3, 2]); + out = mx.vmap((c, a, b) => mx.addmm(c, a, b), [2, 0, 1])(c, a, b); + expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0)); + assertArrayAllTrue(mx.allclose(out, expected)); + }); + + it('vmapSvd', () => { + const a = mx.random.uniform(0, 1, [3, 4, 2]); + const cpuSvd = x => mx.linalg.svd(x, mx.cpu); + + // FIXME(zcbenz): Since tree flatten is not supported yet, the results of + // svd is treated as returning 3 arrays 
instead of one Array. Specify out + // axes explicitly to make test work. + let [Us, Ss, Vts] = mx.vmap(cpuSvd, 0, [0, 0, 0])(a); + assert.deepEqual(Us.shape, [a.shape[0], a.shape[1], a.shape[1]]); + assert.deepEqual(Ss.shape, [a.shape[0], a.shape[2]]); + assert.deepEqual(Vts.shape, [a.shape[0], a.shape[2], a.shape[2]]); + + for (let i = 0; i < a.shape[0]; i++) { + const M = a.index(i); + const U = Us.index(i); + const S = Ss.index(i); + const Vt = Vts.index(i); + assertArrayAllTrue( + mx.allclose(mx.matmul(mx.matmul(U.index(mx.Slice(), mx.Slice(null, S.length)), mx.diag(S)), Vt), + M, 1e-5, 1e-7)); + } + + [Us, Ss, Vts] = mx.vmap(cpuSvd, 1, [0, 0, 0])(a); + assert.deepEqual(Us.shape, [a.shape[1], a.shape[0], a.shape[0]]); + assert.deepEqual(Ss.shape, [a.shape[1], a.shape[2]]); + assert.deepEqual(Vts.shape, [a.shape[1], a.shape[2], a.shape[2]]); + + for (let i = 0; i < a.shape[1]; i++) { + const M = a.index(mx.Slice(), i, mx.Slice()); + const U = Us.index(i); + const S = Ss.index(i); + const Vt = Vts.index(i); + assertArrayAllTrue( + mx.allclose(mx.matmul(mx.matmul(U.index(mx.Slice(), mx.Slice(null, S.length)), mx.diag(S)), Vt), + M, 1e-5, 1e-7)); + } + }); + + it('vmapInverse', () => { + let a = mx.random.uniform(0, 1, [3, 4, 4]); + const cpuInv = x => mx.linalg.inv(x, mx.cpu); + + let invs = mx.vmap(cpuInv, 0)(a); + for (let i = 0; i < a.shape[0]; i++) { + assertArrayAllTrue( + mx.allclose(mx.matmul(a.index(i), invs.index(i)), + mx.eye(a.shape[1]), 0, 1e-5)); + } + + a = mx.random.uniform(0, 1, [4, 3, 4]); + assert.throws(() => { + mx.eval(cpuInv(a)); + }, Error); + + invs = mx.vmap(cpuInv, [1])(a); + for (let i = 0; i < a.shape[1]; i++) { + assertArrayAllTrue( + mx.allclose(mx.matmul(a.index(mx.Slice(), i, mx.Slice()), invs.index(i)), + mx.eye(a.shape[0]), 0, 1e-5)); + } + }); +});