Skip to content

Commit

Permalink
Update to MLX v0.12.2
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed May 4, 2024
1 parent e53f4be commit 09ddb40
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 15 deletions.
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated 79 files
+2 −2 .pre-commit-config.yaml
+1 −1 ACKNOWLEDGMENTS.md
+2 −2 CMakeLists.txt
+2 −0 docs/src/python/metal.rst
+1 −0 docs/src/python/ops.rst
+7 −0 docs/src/python/optimizers.rst
+1 −0 docs/src/python/tree_utils.rst
+1 −1 examples/extensions/axpby/axpby.h
+22 −22 examples/extensions/axpby/axpby.metal
+1 −1 mlx/allocator.h
+33 −1 mlx/array.cpp
+8 −12 mlx/array.h
+1 −0 mlx/backend/accelerate/primitives.cpp
+1 −0 mlx/backend/common/default_primitives.cpp
+1 −1 mlx/backend/common/make_compiled_preamble.sh
+87 −0 mlx/backend/common/masked_mm.cpp
+12 −4 mlx/backend/metal/allocator.cpp
+4 −0 mlx/backend/metal/allocator.h
+23 −4 mlx/backend/metal/device.cpp
+1 −1 mlx/backend/metal/device.h
+11 −11 mlx/backend/metal/kernels/arange.metal
+41 −40 mlx/backend/metal/kernels/arg_reduce.metal
+93 −87 mlx/backend/metal/kernels/binary.metal
+109 −90 mlx/backend/metal/kernels/binary_two.metal
+1 −1 mlx/backend/metal/kernels/complex.h
+197 −203 mlx/backend/metal/kernels/conv.metal
+101 −104 mlx/backend/metal/kernels/copy.metal
+38 −34 mlx/backend/metal/kernels/fft.metal
+84 −98 mlx/backend/metal/kernels/gather.metal
+457 −217 mlx/backend/metal/kernels/gemv.metal
+17 −14 mlx/backend/metal/kernels/layer_norm.metal
+322 −232 mlx/backend/metal/kernels/quantized.metal
+4 −6 mlx/backend/metal/kernels/random.metal
+44 −41 mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal
+99 −103 mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal
+8 −9 mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal
+127 −120 mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal
+11 −6 mlx/backend/metal/kernels/rms_norm.metal
+21 −17 mlx/backend/metal/kernels/rope.metal
+512 −412 mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+146 −100 mlx/backend/metal/kernels/scan.metal
+136 −157 mlx/backend/metal/kernels/scatter.metal
+22 −23 mlx/backend/metal/kernels/softmax.metal
+203 −213 mlx/backend/metal/kernels/sort.metal
+98 −63 mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal
+77 −53 mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal
+95 −66 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal
+284 −222 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal
+168 −0 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal
+296 −237 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
+208 −159 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal
+4 −4 mlx/backend/metal/kernels/steel/gemm/params.h
+96 −92 mlx/backend/metal/kernels/ternary.metal
+30 −28 mlx/backend/metal/kernels/unary.metal
+333 −24 mlx/backend/metal/matmul.cpp
+12 −2 mlx/backend/metal/metal.h
+7 −0 mlx/backend/no_metal/metal.cpp
+2 −1 mlx/backend/no_metal/primitives.cpp
+1 −1 mlx/device.h
+1 −1 mlx/dtype.h
+1 −1 mlx/event.h
+7 −7 mlx/fast_primitives.h
+138 −15 mlx/ops.cpp
+8 −0 mlx/ops.h
+55 −1 mlx/primitives.cpp
+104 −84 mlx/primitives.h
+3 −3 mlx/transforms.cpp
+4 −4 mlx/types/complex.h
+33 −1 python/mlx/optimizers/optimizers.py
+42 −0 python/mlx/utils.py
+26 −18 python/src/indexing.cpp
+26 −2 python/src/metal.cpp
+32 −0 python/src/ops.cpp
+35 −0 python/tests/test_array.py
+193 −0 python/tests/test_blas.py
+3 −0 python/tests/test_metal.py
+42 −0 python/tests/test_optimizers.py
+1 −1 python/tests/test_quantized.py
+1 −1 setup.py
3 changes: 3 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ export namespace core {
function bitwiseXor(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
function broadcastTo(array: ScalarOrArray, shape: number | number[], s?: StreamOrDevice): array;
function blockMaskedMM(a: ScalarOrArray, b: ScalarOrArray, blockSize: number, maskOut?: ScalarOrArray, maskLhs?: ScalarOrArray, maskRhs?: ScalarOrArray, s?: StreamOrDevice): array;
function blockSparseMM(a: ScalarOrArray, b: ScalarOrArray, indicesLhs?: ScalarOrArray, indicesRhs?: ScalarOrArray, s?: StreamOrDevice): array;
function ceil(array: ScalarOrArray, s?: StreamOrDevice): array;
function clip(array: ScalarOrArray, min: ScalarOrArray, max: ScalarOrArray, s?: StreamOrDevice): array;
function concatenate(arrays?: array[], axis?: number, s?: StreamOrDevice): array;
Expand Down Expand Up @@ -337,12 +338,14 @@ export namespace core {
function isAvailable(): boolean;
function getActiveMemory(): number;
function getPeakMemory(): number;
function resetPeakMemory(): void;
function getCacheMemory(): number;
function setMemoryLimit(limit: number, relaxed?: boolean): number;
function clearCache(): void;
function setCacheLimit(limit: number): number;
function startCapture(path: string): boolean;
function stopCapture(): void;
function deviceInfo(): {[key: string]: string | number};
}

// Random.
Expand Down
37 changes: 37 additions & 0 deletions lib/optimizers/optimizers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -886,3 +886,40 @@ export class Adafactor extends Optimizer {
mx.expandDims(cFactor, 0));
}
}

/**
* Clips the global norm of the gradients.
*
* @remarks
*
* This function ensures that the global norm of the gradients does not exceed
* `maxNorm`. It scales down the gradients proportionally if their norm is
* greater than `maxNorm`.
*
* Example:
* ```typescript
* const grads = {'w1': mx.array([2, 3]), 'w2': mx.array([1])};
* const [clippedGrads, totalNorm] = clipGradNorm(grads, 2.0);
* console.log(clippedGrads);
* // {"w1": mx.array([...]), "w2": mx.array([...])}
* ```
*
* @param grads A dictionary containing the gradient arrays.
* @param maxNorm The maximum allowed global norm of the gradients.
* @returns The possibly rescaled gradients and the original gradient norm.
*/
export function clipGradNorm(grads: Nested<mx.array>,
maxNorm: number): [Nested<mx.array>, mx.array] {
const normSquared = utils.treeReduce((acc: number | mx.array, g: mx.array) => {
return mx.add(acc, g.square().sum());
}, grads, 0);
const totalNorm = mx.sqrt(normSquared);
const normalizer = mx.divide(maxNorm, mx.add(totalNorm, 1e-6));

function clipper(g: mx.array) {
return mx.where(mx.less(totalNorm, maxNorm), g, mx.multiply(g, normalizer));
}

const clippedGrads = utils.treeMap(clipper, grads) as Nested<mx.array>;
return [clippedGrads, totalNorm];
}
47 changes: 47 additions & 0 deletions lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,53 @@ export function treeUnflatten(tree: [string, unknown][]): unknown {
}
}

/**
* Applies a reduction to the leaves of a tree.
*
* @remarks
*
* This function reduces trees into an accumulated result by applying
* the provided function `func` to the leaves of the tree.
*
* @example
* ```
* const tree = {a: [1, 2, 3], b: [4, 5]};
* treeReduce((acc, x) => acc + x, tree, 0); // Returns 15
* ```
*
* @param func - The reducer function that takes two arguments (accumulator,
* current value) and returns the updated accumulator.
* @param tree - The Python tree to reduce. It can be any nested combination of
* lists, tuples, or dictionaries.
* @param initializer - The initial value to start the reduction. If
* not provided, the first leaf value is used.
* @param isLeaf - A function to determine if an object is a
* leaf, returning `true` for leaf nodes and `false` otherwise.
*
* @returns The accumulated value.
*/
export function treeReduce<U>(func: (accumulator: U, currentValue: U) => U,
tree: Nested<U>,
initializer?: U,
isLeaf?: (node: unknown) => boolean): U {
if (isLeaf && isLeaf(tree))
return initializer != null ? func(initializer, tree as U) : tree as U;

let accumulator = initializer;

if (Array.isArray(tree)) {
for (const item of tree)
accumulator = treeReduce(func, item, accumulator, isLeaf);
} else if (typeof tree === 'object' && isDict(tree)) {
for (const item of Object.values(tree))
accumulator = treeReduce(func, item, accumulator, isLeaf);
} else {
return accumulator != null ? func(accumulator, tree as U) : tree as U;
}

return accumulator;
}

// A nested type that always has T as leaves.
export type Nested<T> = T | T[] | {[key: string]: Nested<T>};
export type NestedDict<T> = {[key: string]: Nested<T>};
Expand Down
54 changes: 41 additions & 13 deletions src/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,65 +603,93 @@ std::pair<bool, mx::array> SliceUpdate(
ScalarOrArray vals) {
bool is_slice = std::holds_alternative<ArrayIndex>(obj) &&
std::holds_alternative<Slice>(std::get<ArrayIndex>(obj));
// Can't route to slice update if not slice or tuple.
if (a->ndim() == 0 ||
(!is_slice && !std::holds_alternative<ArrayIndices>(obj))) {
return std::make_pair(false, *a);
}
if (std::holds_alternative<ArrayIndices>(obj)) {
// Can't route to slice update if any arrays are present.
for (const ArrayIndex& index : std::get<ArrayIndices>(obj)) {
if (std::holds_alternative<mx::array*>(index))
return std::make_pair(false, *a);
}
}

// Should be able to route to slice update.

// Pre process tuple.
mx::array up = ToArray(std::move(vals), a->dtype());

// Remove leading singletons dimensions from the update.
std::vector<int> up_shape = GetUpShape(up);
up = mx::reshape(std::move(up), up_shape.empty() ? std::vector<int>{1}
: std::move(up_shape));

// Build slice update params.
std::vector<int> starts(a->ndim(), 0);
std::vector<int> stops(a->shape());
std::vector<int> steps(a->ndim(), 1);
// If it's just a simple slice, just do a slice update and return.
if (is_slice) {
ReadSlice(std::get<Slice>(std::get<ArrayIndex>(obj)), a->shape(0),
&starts[0], &stops[0], &steps[0]);
// Do slice update.
return {true,
mx::slice_update(*a, std::move(up), std::move(starts),
std::move(stops), std::move(steps))};
}

// It must be a tuple.
ArrayIndices entries = std::move(std::get<ArrayIndices>(obj));
for (const ArrayIndex& index : entries) {
if (std::holds_alternative<mx::array*>(index))
return std::make_pair(false, *a);
}

// Expand ellipses into a series of ':' slices.
auto [non_none_indices, indices] = ExpandEllipsis(a->shape(),
std::move(entries));
// Dimension check.
if (non_none_indices > a->ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << a->ndim() << "dimensions.";
throw std::invalid_argument(msg.str());
}
// If no non-None indices return the broadcasted update.
if (non_none_indices == 0) {
return std::make_pair(true, mx::broadcast_to(std::move(up), a->shape()));
}

std::vector<int> upd_expand_dims;
size_t axis = 0;
for (const ArrayIndex& index : indices) {
// Process entries.
std::vector<int> up_reshape(a->ndim());
int axis = a->ndim() - 1;
int up_axis = up.ndim() - 1;
while (axis >= non_none_indices) {
if (up_axis >= 0) {
up_reshape[axis] = up.shape(up_axis);
up_axis--;
} else {
up_reshape[axis] = 1;
}
axis--;
}

for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
const ArrayIndex& index = *it;
if (std::holds_alternative<Slice>(index)) {
ReadSlice(std::get<Slice>(index), a->shape(axis),
&starts[axis], &stops[axis], &steps[axis]);
axis++;
up_reshape[axis] = (up_axis >= 0) ? up.shape(up_axis--) : 1;
axis--;
} else if (std::holds_alternative<int>(index)) {
int start = std::get<int>(index);
if (start < 0)
start += a->shape(axis);
starts[axis] = start;
stops[axis] = start + 1;
if (a->ndim() - axis < up.ndim()) {
upd_expand_dims.push_back(axis - a->ndim());
}
axis++;
up_reshape[axis] = 1;
axis--;
}
}

up = mx::expand_dims(std::move(up), std::move(upd_expand_dims));
up = mx::reshape(std::move(up), std::move(up_reshape));
return {true,
mx::slice_update(*a, std::move(up), std::move(starts),
std::move(stops), std::move(steps))};
Expand Down
4 changes: 3 additions & 1 deletion src/metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ void InitMetal(napi_env env, napi_value exports) {
"isAvailable", &mx::metal::is_available,
"getActiveMemory", &mx::metal::get_active_memory,
"getPeakMemory", &mx::metal::get_peak_memory,
"resetPeakMemory", &mx::metal::reset_peak_memory,
"getCacheMemory", &mx::metal::get_cache_memory,
"setMemoryLimit", &metal_ops::SetMemoryLimit,
"clearCache", &mx::metal::clear_cache,
"setCacheLimit", &mx::metal::set_cache_limit,
"startCapture", &mx::metal::start_capture,
"stopCapture", &mx::metal::stop_capture);
"stopCapture", &mx::metal::stop_capture,
"deviceInfo", &mx::metal::device_info);
}
2 changes: 2 additions & 0 deletions src/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,8 @@ void InitOps(napi_env env, napi_value exports) {
"outer", &mx::outer,
"tile", &ops::Tile,
"addmm", &ops::AddMM,
"blockMaskedMM", &mx::block_masked_mm,
"blockSparseMM", &mx::block_sparse_mm,
"diagonal", &ops::Diagonal,
"diag", &ops::Diag,
"atleast1d", NdOpWrapper(&mx::atleast_1d, &mx::atleast_1d),
Expand Down
8 changes: 8 additions & 0 deletions tests/array.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,14 @@ describe('array', () => {
a = mx.zeros([2, 2, 2, 2]);
a.indexPut_([null, '...', null], 1);
assert.deepEqual(a.tolist(), mx.ones([2, 2, 2, 2]).tolist());

a = mx.zeros([2, 3, 4, 5, 3]);
a.indexPut_(['...', 0], 1);
assert.deepEqual(a.index('...', 0).tolist(), mx.ones([2, 3, 4, 5]).tolist());

a = mx.zeros([2, 3, 4, 5, 3]);
a.indexPut_([mx.Slice(), 0], 1);
assert.deepEqual(a.index(mx.Slice(), 0).tolist(), mx.ones([2, 4, 5, 3]).tolist());
});
});

Expand Down
36 changes: 36 additions & 0 deletions tests/optimizers.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,42 @@ describe('optimizers', () => {
});

// TODO(zcbenz): Add test_update_lr_compiled after implementing captures for mx.compile.

it('clipGradNorm', () => {
// Test with small gradients that do not require clipping
const smallGrads = {
first: [mx.array([0.1, 0.2]), mx.array([0.1])],
second: mx.array([0.3]),
};
let maxNorm = 10.0; // A large maxNorm that shouldn't trigger clipping
let [clippedGrads, totalNorm] = opt.clipGradNorm(smallGrads, maxNorm);
utils.treeMap((x: mx.array, y: mx.array) => {
assertArrayAllTrue(mx.arrayEqual(x, y));
}, smallGrads, [clippedGrads]);

// Test with large gradients that require clipping
const largeGrads = {
first: [mx.array([10, 20]), mx.array([10])],
second: mx.array([30]),
};
maxNorm = 1.0; // A small maxNorm that should trigger clipping
[clippedGrads, totalNorm] = opt.clipGradNorm(largeGrads, maxNorm);
// Correctly extract only the gradient values for norm calculation
const clippedValues = utils.treeFlatten(clippedGrads);
let normOfClipped = mx.array(0);
for (const [_, g] of clippedValues) {
normOfClipped = mx.add(normOfClipped, mx.square(g as mx.array).sum());
}
normOfClipped = mx.sqrt(normOfClipped);
assert.closeTo(normOfClipped.item() as number, maxNorm, 1e-6);

// Ensures that the scaling was done correctly
const scale = mx.divide(maxNorm, totalNorm);
const expectedGrads = utils.treeMap((g: mx.array) => mx.multiply(g, scale), largeGrads);
utils.treeMap((x: mx.array, y: mx.array) => {
assertArrayAllTrue(mx.allclose(x, y, 1e-6));
}, expectedGrads, [clippedGrads]);
});
});

describe('schedulers', () => {
Expand Down

0 comments on commit 09ddb40

Please sign in to comment.