diff --git a/deps/mlx b/deps/mlx index f76a49e5..f17536af 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit f76a49e555e4fe76d46a3584bb72dce8287f24b2 +Subproject commit f17536af9c484ebb26a056292ed10acc32de3910 diff --git a/index.d.ts b/index.d.ts index 08ad638b..7ede6a5a 100644 --- a/index.d.ts +++ b/index.d.ts @@ -334,6 +334,7 @@ export namespace core { function tri(N: number, M?: number, k?: number, dtype?: Dtype, s?: StreamOrDevice): array; function tril(array: ScalarOrArray, k?: number, s?: StreamOrDevice): array; function triu(array: ScalarOrArray, k?: number, s?: StreamOrDevice): array; + function unflatten(array: ScalarOrArray, axis: number, shape: number[], s?: StreamOrDevice): array; function variance(array: ScalarOrArray, indicesOrSections?: number | number[], keepdims?: boolean, ddof?: number, s?: StreamOrDevice): array; function where(condition: ScalarOrArray, x: ScalarOrArray, y: ScalarOrArray, s?: StreamOrDevice): array; function zeros(shape: number | number[], dtype?: Dtype, s?: StreamOrDevice): array; diff --git a/src/indexing.cc b/src/indexing.cc index 7a0c7101..9d2cd5bf 100644 --- a/src/indexing.cc +++ b/src/indexing.cc @@ -304,21 +304,23 @@ mx::array IndexNDimensional(const mx::array* a, } if (unsqueeze_needed || squeeze_needed) { - mx::Shape out_shape; - size_t axis = 0; - for (const ArrayIndex& index : remaining_indices) { + std::vector squeeze_axes; + std::vector unsqueeze_axes; + for (int axis = 0; axis < remaining_indices.size(); ++axis) { + ArrayIndex& index = remaining_indices[axis]; if (unsqueeze_needed && std::holds_alternative(index)) { - out_shape.push_back(1); + unsqueeze_axes.push_back(axis - squeeze_axes.size()); } else if (squeeze_needed && std::holds_alternative(index)) { - axis++; - } else { - out_shape.push_back(gathered.shape(axis++)); + squeeze_axes.push_back(axis - unsqueeze_axes.size()); } } - - out_shape.insert(out_shape.end(), - gathered.shape().begin() + axis, gathered.shape().end()); - gathered = mx::reshape(std::move(gathered), std::move(out_shape)); + if (!squeeze_axes.empty()) { + gathered = mx::squeeze(std::move(gathered), std::move(squeeze_axes)); + } + if (!unsqueeze_axes.empty()) { + gathered = mx::expand_dims(std::move(gathered), + std::move(unsqueeze_axes)); + } } return gathered; diff --git a/src/ops.cc b/src/ops.cc index 163aa977..258febc8 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -764,6 +764,7 @@ void InitOps(napi_env env, napi_value exports) { ki::Set(env, exports, "reshape", &ops::Reshape, "flatten", &ops::Flatten, + "unflatten", &mx::unflatten, "squeeze", &ops::Squeeze, "expandDims", &ops::ExpandDims, "abs", &mx::abs, diff --git a/tests/autograd.spec.ts b/tests/autograd.spec.ts index f3d4f8bd..3b5bb4af 100644 --- a/tests/autograd.spec.ts +++ b/tests/autograd.spec.ts @@ -485,4 +485,22 @@ describe('autograd', () => { const expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32); assertArrayAllTrue(mx.arrayEqual(out, expected)); }); + + it('flattenUnflattenVjps', () => { + const fun1 = (x: mx.array) => { + const y = mx.unflatten(x, 0, [2, 2]); + return y.sum(); + } + + let x = mx.zeros([4, 8]); + assert.deepEqual(mx.grad(fun1)(x).shape, [4, 8]); + + const fun2 = (x: mx.array) => { + const y = mx.flatten(x, 0, 2); + return y.sum(); + } + + x = mx.zeros([2, 4, 8]); + assert.deepEqual(mx.grad(fun2)(x).shape, [2, 4, 8]); + }); }); diff --git a/tests/compile.spec.ts b/tests/compile.spec.ts index 9c7a8c64..bba8eb95 100644 --- a/tests/compile.spec.ts +++ b/tests/compile.spec.ts @@ -355,6 +355,18 @@ describe('compile', function() { assertArrayAllTrue(mx.arrayEqual(fun2(x2), cfun2(x2))); }); + describe('shapelessCompileUnflatten', () => { + const x = mx.zeros([1, 1, 4 * 32]); + const fun = (x: mx.array) => mx.unflatten(x, -1, [4, -1]); + assert.deepEqual(mx.compile(fun, true)(x).shape, [1, 1, 4, 32]); + }); + + describe('shapelessCompileGather', () => { + const x = mx.zeros([1, 1, 32]); + const fun = (x: mx.array) => x.index(mx.Slice(), -1, mx.Slice()); + assert.deepEqual(mx.compile(fun, true)(x).shape, [1, 32]); + }); + describe('compileWithConstant', () => { it('float', () => { const fun = (x, y) => { diff --git a/tests/ops.spec.ts b/tests/ops.spec.ts index c555089b..7b4f7124 100644 --- a/tests/ops.spec.ts +++ b/tests/ops.spec.ts @@ -1367,7 +1367,7 @@ describe('ops', () => { }); it('asStrided', () => { - const x = mx.random.normal([128]).astype(mx.float32); + let x = mx.random.normal([128]).astype(mx.float32); const shapes = [[10, 10], [5, 5], [2, 20], [10]]; const strides = [[3, 3], [7, 1], [1, 5], [4]]; for (let i = 0; i < shapes.length; i++) { @@ -1378,6 +1378,10 @@ describe('ops', () => { assert.deepEqual(y.shape, shape); } } + + x = mx.random.uniform(0, 1, [32]); + let y = mx.asStrided(x, [x.size], [-1], x.size - 1); + assertArrayAllTrue(mx.arrayEqual(y, x.index(mx.Slice(null, null, -1)))); }); it('squeezeExpand', () => { @@ -1401,6 +1405,11 @@ describe('ops', () => { const x = mx.array([3, 1, 2]); const sortedX = mx.sort(x); assert.deepEqual(sortedX.tolist(), [1, 2, 3]); + + const a = mx.array([[4, 3], [2, 1], [5, 4], [3, 2]], mx.uint32); + const out = mx.argsort(a.index(mx.Slice(), 1)); + const expected = mx.array([1, 3, 0, 2], mx.uint32); + assertArrayAllTrue(mx.arrayEqual(out, expected)); }); it('argpartition', () => {