Skip to content

Commit

Permalink
Update to MLX f17536a
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Dec 28, 2024
1 parent 718ca88 commit 66c8754
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 13 deletions.
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated 91 files
+4 −4 .circleci/config.yml
+5 −3 CMakeLists.txt
+11 −11 benchmarks/cpp/autograd.cpp
+7 −7 benchmarks/cpp/compare_devices.cpp
+82 −82 benchmarks/cpp/irregular_strides.cpp
+130 −128 benchmarks/cpp/single_ops.cpp
+1 −1 docs/src/install.rst
+1 −0 docs/src/python/ops.rst
+5 −5 examples/cpp/distributed.cpp
+14 −14 examples/cpp/linear_regression.cpp
+13 −13 examples/cpp/logistic_regression.cpp
+10 −10 examples/cpp/metal_capture.cpp
+19 −19 examples/cpp/tutorial.cpp
+1 −2 examples/extensions/CMakeLists.txt
+60 −60 examples/extensions/axpby/axpby.cpp
+30 −24 examples/extensions/axpby/axpby.h
+1 −3 examples/extensions/bindings.cpp
+2 −2 examples/extensions/pyproject.toml
+1 −1 examples/extensions/requirements.txt
+13 −1 mlx/array.cpp
+0 −1 mlx/backend/accelerate/primitives.cpp
+11 −3 mlx/backend/common/CMakeLists.txt
+2 −4 mlx/backend/common/common.cpp
+9 −4 mlx/backend/common/compiled_cpu.cpp
+0 −1 mlx/backend/common/default_primitives.cpp
+38 −0 mlx/backend/common/make_compiled_preamble.ps1
+6 −5 mlx/backend/common/make_compiled_preamble.sh
+28 −38 mlx/backend/common/primitives.cpp
+6 −0 mlx/backend/common/utils.h
+34 −14 mlx/backend/metal/allocator.cpp
+4 −2 mlx/backend/metal/allocator.h
+4 −2 mlx/backend/metal/binary.cpp
+8 −3 mlx/backend/metal/device.cpp
+10 −10 mlx/backend/metal/indexing.cpp
+19 −11 mlx/backend/metal/jit_kernels.cpp
+15 −15 mlx/backend/metal/kernels/binary.metal
+15 −15 mlx/backend/metal/kernels/binary_two.metal
+6 −6 mlx/backend/metal/kernels/reduce.metal
+11 −11 mlx/backend/metal/kernels/ternary.metal
+3 −3 mlx/backend/metal/kernels/unary.metal
+7 −18 mlx/backend/metal/kernels/utils.h
+28 −20 mlx/backend/metal/primitives.cpp
+16 −16 mlx/backend/metal/reduce.cpp
+7 −24 mlx/backend/metal/slicing.cpp
+11 −3 mlx/backend/metal/sort.cpp
+4 −4 mlx/backend/metal/ternary.cpp
+4 −2 mlx/backend/metal/unary.cpp
+2 −0 mlx/backend/no_cpu/primitives.cpp
+2 −0 mlx/backend/no_metal/primitives.cpp
+1 −0 mlx/compile.cpp
+18 −6 mlx/fast.cpp
+21 −20 mlx/io/gguf.cpp
+4 −2 mlx/io/load.cpp
+14 −2 mlx/io/load.h
+81 −31 mlx/ops.cpp
+3 −0 mlx/ops.h
+102 −0 mlx/primitives.cpp
+44 −10 mlx/primitives.h
+1 −1 mlx/transforms.cpp
+2 −2 pyproject.toml
+1 −1 python/mlx/nn/layers/upsample.py
+296 −266 python/src/array.cpp
+16 −16 python/src/buffer.h
+120 −118 python/src/convert.cpp
+17 −12 python/src/convert.h
+26 −23 python/src/device.cpp
+13 −12 python/src/distributed.cpp
+12 −12 python/src/fast.cpp
+57 −58 python/src/fft.cpp
+113 −119 python/src/indexing.cpp
+18 −15 python/src/indexing.h
+18 −20 python/src/linalg.cpp
+33 −30 python/src/load.cpp
+11 −9 python/src/load.h
+14 −14 python/src/metal.cpp
+539 −440 python/src/ops.cpp
+74 −65 python/src/random.cpp
+19 −18 python/src/stream.cpp
+63 −60 python/src/transforms.cpp
+17 −17 python/src/trees.cpp
+8 −8 python/src/trees.h
+24 −21 python/src/utils.cpp
+12 −11 python/src/utils.h
+12 −2 python/tests/test_array.py
+15 −0 python/tests/test_autograd.py
+16 −0 python/tests/test_compile.py
+6 −0 python/tests/test_nn.py
+10 −0 python/tests/test_ops.py
+10 −3 setup.py
+2 −2 tests/autograd_tests.cpp
+18 −1 tests/ops_tests.cpp
1 change: 1 addition & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ export namespace core {
function tri(N: number, M?: number, k?: number, dtype?: Dtype, s?: StreamOrDevice): array;
function tril(array: ScalarOrArray, k?: number, s?: StreamOrDevice): array;
function triu(array: ScalarOrArray, k?: number, s?: StreamOrDevice): array;
function unflatten(array: ScalarOrArray, axis: number, shape: number[], s?: StreamOrDevice): array;
function variance(array: ScalarOrArray, indicesOrSections?: number | number[], keepdims?: boolean, ddof?: number, s?: StreamOrDevice): array;
function where(condition: ScalarOrArray, x: ScalarOrArray, y: ScalarOrArray, s?: StreamOrDevice): array;
function zeros(shape: number | number[], dtype?: Dtype, s?: StreamOrDevice): array;
Expand Down
24 changes: 13 additions & 11 deletions src/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,23 @@ mx::array IndexNDimensional(const mx::array* a,
}

if (unsqueeze_needed || squeeze_needed) {
mx::Shape out_shape;
size_t axis = 0;
for (const ArrayIndex& index : remaining_indices) {
std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
ArrayIndex& index = remaining_indices[axis];
if (unsqueeze_needed && std::holds_alternative<std::monostate>(index)) {
out_shape.push_back(1);
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && std::holds_alternative<int>(index)) {
axis++;
} else {
out_shape.push_back(gathered.shape(axis++));
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}

out_shape.insert(out_shape.end(),
gathered.shape().begin() + axis, gathered.shape().end());
gathered = mx::reshape(std::move(gathered), std::move(out_shape));
if (!squeeze_axes.empty()) {
gathered = mx::squeeze(std::move(gathered), std::move(squeeze_axes));
}
if (!unsqueeze_axes.empty()) {
gathered = mx::expand_dims(std::move(gathered),
std::move(unsqueeze_axes));
}
}

return gathered;
Expand Down
1 change: 1 addition & 0 deletions src/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ void InitOps(napi_env env, napi_value exports) {
ki::Set(env, exports,
"reshape", &ops::Reshape,
"flatten", &ops::Flatten,
"unflatten", &mx::unflatten,
"squeeze", &ops::Squeeze,
"expandDims", &ops::ExpandDims,
"abs", &mx::abs,
Expand Down
18 changes: 18 additions & 0 deletions tests/autograd.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -485,4 +485,22 @@ describe('autograd', () => {
const expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32);
assertArrayAllTrue(mx.arrayEqual(out, expected));
});

it('flattenUnflattenVjps', () => {
const fun1 = (x: mx.array) => {
const y = mx.unflatten(x, 0, [2, 2]);
return y.sum();
}

let x = mx.zeros([4, 8]);
assert.deepEqual(mx.grad(fun1)(x).shape, [4, 8]);

const fun2 = (x: mx.array) => {
const y = mx.flatten(x, 0, 2);
return y.sum();
}

x = mx.zeros([2, 4, 8]);
assert.deepEqual(mx.grad(fun2)(x).shape, [2, 4, 8]);
});
});
12 changes: 12 additions & 0 deletions tests/compile.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ describe('compile', function() {
assertArrayAllTrue(mx.arrayEqual(fun2(x2), cfun2(x2)));
});

describe('shapelessCompileUnflatten', () => {
const x = mx.zeros([1, 1, 4 * 32]);
const fun = (x: mx.array) => mx.unflatten(x, -1, [4, -1]);
assert.deepEqual(mx.compile(fun, true)(x).shape, [1, 1, 4, 32]);
});

describe('shapelessCompileGather', () => {
const x = mx.zeros([1, 1, 32]);
const fun = (x: mx.array) => x.index(mx.Slice(), -1, mx.Slice());
assert.deepEqual(mx.compile(fun, true)(x).shape, [1, 32]);
});

describe('compileWithConstant', () => {
it('float', () => {
const fun = (x, y) => {
Expand Down
11 changes: 10 additions & 1 deletion tests/ops.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ describe('ops', () => {
});

it('asStrided', () => {
const x = mx.random.normal([128]).astype(mx.float32);
let x = mx.random.normal([128]).astype(mx.float32);
const shapes = [[10, 10], [5, 5], [2, 20], [10]];
const strides = [[3, 3], [7, 1], [1, 5], [4]];
for (let i = 0; i < shapes.length; i++) {
Expand All @@ -1378,6 +1378,10 @@ describe('ops', () => {
assert.deepEqual(y.shape, shape);
}
}

x = mx.random.uniform(0, 1, [32]);
let y = mx.asStrided(x, [x.size], [-1], x.size - 1);
assertArrayAllTrue(mx.arrayEqual(y, x.index(mx.Slice(null, null, -1))));
});

it('squeezeExpand', () => {
Expand All @@ -1401,6 +1405,11 @@ describe('ops', () => {
const x = mx.array([3, 1, 2]);
const sortedX = mx.sort(x);
assert.deepEqual(sortedX.tolist(), [1, 2, 3]);

const a = mx.array([[4, 3], [2, 1], [5, 4], [3, 2]], mx.uint32);
const out = mx.argsort(a.index(mx.Slice(), 1));
const expected = mx.array([1, 3, 0, 2], mx.uint32);
assertArrayAllTrue(mx.arrayEqual(out, expected));
});

it('argpartition', () => {
Expand Down

0 comments on commit 66c8754

Please sign in to comment.