Add Module.saveWeights/loadWeights
zcbenz committed May 2, 2024
1 parent 34318e0 commit 7a379cf
Showing 12 changed files with 219 additions and 51 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -74,6 +74,7 @@ Some features are not supported yet and will be implemented in future:
primitive values.
* The APIs only accept plain parameters, e.g. `mx.uniform(0, 1, [2, 2])`. Named
parameter calls like `mx.uniform({shape: [2, 2]})` have not been implemented.
+* The `.npz` tensor format is not supported yet.

### Complex numbers

8 changes: 3 additions & 5 deletions lib/index.d.ts
@@ -282,11 +282,9 @@ export namespace core {
function rightShift(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
function round(array: ScalarOrArray, s?: StreamOrDevice): array;
function rsqrt(array: ScalarOrArray, s?: StreamOrDevice): array;
-function save(array: ScalarOrArray, filepath: string, s?: StreamOrDevice): void;
-function savez(dict: { [key: string]: array }, filepath: string, s?: StreamOrDevice): void;
-function savezCompressed(dict: { [key: string]: array }, filepath: string, s?: StreamOrDevice): void;
-function saveGguf(array: ScalarOrArray, filepath: string, s?: StreamOrDevice): void;
-function saveSafetensors(dict: { [key: string]: array }, filepath: string, s?: StreamOrDevice): void;
+function save(filepath: string, array: array): void;
+function saveGguf(filepath: string, arrays: Record<string, array>, metadata?: Record<string, string>): void;
+function saveSafetensors(filepath: string, arrays: Record<string, array>, metadata?: Record<string, string>): void;
function sigmoid(array: ScalarOrArray, s?: StreamOrDevice): array;
function sign(array: ScalarOrArray, s?: StreamOrDevice): array;
function sin(array: ScalarOrArray, s?: StreamOrDevice): array;
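The reworked declarations flip the argument order so the file path comes first. A minimal usage sketch based on the signatures above (file names and metadata values are illustrative placeholders):

```typescript
import {core as mx} from '@frost-beta/mlx';

const a = mx.zeros([2, 2]);

// Single-array save: path first, then the array.
mx.save('a.npy', a);

// The dictionary-based savers take a Record of named arrays, plus
// optional string-valued metadata.
mx.saveSafetensors('weights.safetensors', {weight: a}, {step: '100'});
mx.saveGguf('weights.gguf', {weight: a}, {author: 'example'});
```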
102 changes: 101 additions & 1 deletion lib/nn/layers/base.ts
@@ -1,4 +1,5 @@
import {core as mx, utils} from '../../..';
+import {deepEqual} from './pytools';

// A nested type that always has T as leaves.
type Nested<T> = T | Nested<T>[] | {[key: string]: Nested<T>};
@@ -150,6 +151,105 @@ export abstract class Module {
*/
abstract forward(...inputs: unknown[]): unknown;

+/**
+* Update the model's weights from a `.npz` or `.safetensors` file, or a list.
+*
+* @param fileOrWeights - The path to the weights file (`.npz` or `.safetensors`) or a list of
+* pairs of parameter names and arrays.
+* @param strict - If `true` then checks that the provided weights exactly match the parameters of the
+* model. Otherwise, only the weights actually contained in the model are loaded and shapes are not checked.
+* Default: `true`.
+*
+* @returns The module instance after updating the weights.
+*
+* @example
+* ```typescript
+* import {core as mx, nn} from '@frost-beta/mlx';
+*
+* let model = new nn.Linear(10, 10);
+*
+* // Load from file
+* model.loadWeights('weights.npz');
+*
+* // Load from .safetensors file
+* model.loadWeights('weights.safetensors');
+*
+* // Load from list
+* let weights = [
+* ['weight', mx.random.uniform(0, 1, [10, 10])],
+* ['bias', mx.zeros([10])],
+* ];
+* model.loadWeights(weights);
+*
+* // Missing weights raise an exception
+* weights = [
+* ['weight', mx.random.uniform(0, 1, [10, 10])]
+* ];
+* try {
+* model.loadWeights(weights);
+* } catch (e) {
+* console.log(e);
+* }
+*
+* // Ok, only updates the weight but not the bias
+* model.loadWeights(weights, false);
+* ```
+*/
+loadWeights(fileOrWeights: string | [string, mx.array][], strict = true): this {
+let weights = fileOrWeights;
+if (typeof weights === 'string') {
+weights = Object.entries(mx.load(weights));
+}
+
+if (strict) {
+const newWeights = Object.fromEntries(weights) as Record<string, mx.array>;
+const currentWeights = Object.fromEntries(utils.treeFlatten(this.parameters())) as Record<string, mx.array>;
+const extras = Object.keys(newWeights).filter(key => !(key in currentWeights));
+if (extras.length > 0) {
+throw Error(`Received parameters not in model: ${extras.join(' ')}.`);
+}
+const missing = Object.keys(currentWeights).filter(key => !(key in newWeights));
+if (missing.length > 0) {
+throw Error(`Missing parameters: ${missing.join(' ')}.`);
+}
+
+Object.keys(currentWeights).forEach(key => {
+const vNew = newWeights[key];
+if (!(vNew instanceof mx.array)) {
+throw Error(`Expected mx.array but received ${typeof vNew} for parameter ${key}`);
+}
+if (!deepEqual(vNew.shape, currentWeights[key].shape)) {
+throw Error(`Expected shape ${currentWeights[key].shape} but received shape ${vNew.shape} for parameter ${key}`);
+}
+});
+}
+
+this.update(utils.treeUnflatten(weights) as Record<string, unknown>);
+return this;
+}
+
+/**
+* Save the model's weights to a file.
+*
+* @remarks
+*
+* The saving method is determined by the file extension:
+* - `.npz` will use `mx.savez` (not implemented yet)
+* - `.safetensors` will use `mx.saveSafetensors`
+*
+* @param filepath - The name of the file to save the weights to.
+*/
+saveWeights(filepath: string): void {
+const params = Object.fromEntries(utils.treeFlatten(this.parameters())) as Record<string, mx.array>;
+if (filepath.endsWith('.npz')) {
+throw Error('Support for .npz format has not been implemented yet.');
+} else if (filepath.endsWith('.safetensors')) {
+mx.saveSafetensors(filepath, params);
+} else {
+throw Error("Unsupported file extension. Use '.npz' or '.safetensors'.");
+}
+}
+
/**
* Recursively filter the contents of the module using `filterFn`, namely only
* select keys and values where `filterFn` returns true.
@@ -546,5 +646,5 @@ function unwrap(model: Module,
return newValue;
}

-throw new Error("Unexpected leaf found while traversing the module");
+throw Error("Unexpected leaf found while traversing the module");
}
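Taken together, the two new methods give a simple persistence round trip. A minimal sketch (the file name is a placeholder; note that `.npz` saving throws until it is implemented):

```typescript
import {nn} from '@frost-beta/mlx';

const model = new nn.Linear(10, 10);

// The file extension selects the serializer; only '.safetensors'
// currently works, '.npz' throws.
model.saveWeights('linear.safetensors');

// loadWeights returns `this`, so it chains off the constructor. In
// strict mode (the default) names and shapes are verified first.
const restored = new nn.Linear(10, 10).loadWeights('linear.safetensors');
```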
15 changes: 8 additions & 7 deletions lib/nn/layers/pooling.ts
@@ -129,18 +129,19 @@ export class AvgPool1d extends Pool1d {
*
* @remarks
*
-* Assuming an input of shape `(N, H, W, C)` and `kernelSize` is `(k_H, k_W)`, the output
-* is a tensor of shape `(N, H_out, W_out, C)`, given by:
+* Assuming an input of shape `(N, H, W, C)` and `kernelSize` is `(k_H, k_W)`,
+* the output is a tensor of shape `(N, H_out, W_out, C)`, given by:
*
* `out(N_i, h, w, C_j) = max_{m=0,...,k_H-1} max_{n=0,...,k_W-1} input(N_i, stride[0] * h + m, stride[1] * w + n, C_j)`
*
* where `H_out = floor((H + 2 * padding[0] - kernelSize[0]) / stride[0]) + 1`
* `W_out = floor((W + 2 * padding[1] - kernelSize[1]) / stride[1]) + 1`
*
* The parameters `kernelSize`, `stride`, `padding`, can either be:
-* - a single `number` -- in which case the same value is used for both the height and width axis;
-* - a `tuple` of two `numbers`s -- in which case, the first `number`
-* is used for the height axis, the second `number` for the width axis.
+* - a single `number` -- in which case the same value is used for both the
+* height and width axis;
+* - a `tuple` of two `number`s -- in which case, the first `number` is used
+* for the height axis, the second `number` for the width axis.
*
* @param kernelSize - The size of the pooling window.
* @param stride - The stride of the pooling window. Default: `kernelSize`.
@@ -173,8 +174,8 @@ export class MaxPool2d extends Pool2d {
*
* - a single `number` -- in which case the same value is used for both the
* height and width axis
-* - a `number[]` of two numbers -- in which case, the first number is
-* used for the height axis, the second number for the width axis.
+* - a `number[]` of two numbers -- in which case, the first number is used for
+* the height axis, the second number for the width axis.
*
* @param kernelSize - The size of the pooling window.
* @param stride - The stride of the pooling window. Default: `kernelSize`.
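The output-shape formula in the re-wrapped comments is easy to check by hand. A small sketch with illustrative values:

```typescript
// H_out = floor((H + 2 * padding - kernelSize) / stride) + 1, per the
// pooling docs above.
const outDim = (size: number, kernel: number, stride: number, pad: number) =>
  Math.floor((size + 2 * pad - kernel) / stride) + 1;

// A [4, 32, 32, 3] input with kernelSize = [2, 2], stride = [2, 2],
// and no padding:
const hOut = outDim(32, 2, 2, 0); // floor((32 - 2) / 2) + 1 = 16
const wOut = outDim(32, 2, 2, 0); // 16, so the output is [4, 16, 16, 3]
```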
4 changes: 4 additions & 0 deletions lib/nn/layers/pytools.ts
@@ -1,3 +1,7 @@
+export function deepEqual(s1: number[], s2: number[]): boolean {
+return s1.length === s2.length && s1.every((u, i) => u === s2[i]);
+}
+
export function range(start: number, end: number, step = 1): number[] {
return Array.from({length: Math.ceil((end - start) / step)},
(_, i) => start + i * step);
18 changes: 10 additions & 8 deletions lib/nn/layers/quantized.ts
@@ -110,7 +110,8 @@ export class QuantizedEmbedding extends Module {
*
* @remarks
*
-* Use this for example when input embedding and output projection weights are tied.
+* Use this for example when input embedding and output projection weights are
+* tied.
*/
asLinear(x: mx.array): mx.array {
return mx.quantizedMatmul(x,
@@ -128,16 +129,17 @@
}

/**
-* Applies an affine transformation to the input using a quantized weight matrix.
+* Applies an affine transformation to the input using a quantized weight
+* matrix.
*
* @remarks
*
* It is the quantized equivalent of `Linear`. For now its parameters are frozen
-* and will not be included in any gradient computation but this will probably change
-* in the future.
+* and will not be included in any gradient computation but this will probably
+* change in the future.
*
-* `QuantizedLinear` also provides a classmethod `fromLinear` to convert
-* linear layers to `QuantizedLinear` layers.
+* `QuantizedLinear` also provides a classmethod `fromLinear` to convert linear
+* layers to `QuantizedLinear` layers.
*/
export class QuantizedLinear extends Module {
/**
@@ -165,9 +167,9 @@
* @param inDims - The dimensionality of the input features.
* @param outDims - The dimensionality of the output features.
* @param bias - If set to `false` then the layer will not use a bias.
-Default: `true`.
+Default: `true`.
* @param groupSize - The group size to use for the quantized weight.
-Default: `64`.
+Default: `64`.
* @param bits - The bit width to use for the quantized weight. Default: `4`.
*/
constructor(inDims: number, outDims: number, bias = true, groupSize = 64, bits = 4) {
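The `fromLinear` classmethod mentioned in the comment would be used along these lines (a sketch; the exact argument list is an assumption modeled on the constructor parameters documented above):

```typescript
import {nn} from '@frost-beta/mlx';

// A regular float-weight layer.
const dense = new nn.Linear(512, 512);

// Convert it to the quantized equivalent; groupSize and bits are
// assumed to mirror the constructor defaults of 64 and 4.
const quantized = nn.QuantizedLinear.fromLinear(dense, 64, 4);
```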
29 changes: 15 additions & 14 deletions lib/nn/layers/recurrent.ts
@@ -12,22 +12,22 @@ import {Module} from './base';
* * `L` is the sequence length
* * `D` is the input's feature dimension
*
-* Concretely, for each element along the sequence length axis, this
-* layer applies the function:
+* Concretely, for each element along the sequence length axis, this layer
+* applies the function:
*
* ```math
* h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)
* ```
*
-* The hidden state `h` has shape `NH` or `H`, depending on
-* whether the input is batched or not. Returns the hidden state at each
-* time step, of shape `NLH` or `LH`.
+* The hidden state `h` has shape `NH` or `H`, depending on whether the input is
+* batched or not. Returns the hidden state at each time step, of shape `NLH` or
+* `LH`.
*
* @param inputDims - Dimension of the input, `D`.
* @param hiddenDims - Dimension of the hidden state, `H`.
* @param bias - Whether to use a bias. Default: `true`.
-* @param nonlinearity - Non-linearity to use. If `null`,
-* then func:`tanh` is used. Default: `null`.
+* @param nonlinearity - Non-linearity to use. If `null`, then `tanh` is
+* used. Default: `null`.
*/
export class RNN extends Module {
nonlinearity: (x: mx.array) => mx.array;
@@ -102,9 +102,10 @@ export class RNN extends Module {
* h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
* \end{aligned}
* ```
-* The hidden state `h` has shape `NH` or `H` depending on
-* whether the input is batched or not. Returns the hidden state at each
-* time step of shape `NLH` or `LH`.
+*
+* The hidden state `h` has shape `NH` or `H` depending on whether the input is
+* batched or not. Returns the hidden state at each time step of shape `NLH` or
+* `LH`.
*
* @param inputDims - Dimension of the input, `D`.
* @param hiddenDims - Dimension of the hidden state, `H`.
@@ -202,11 +203,11 @@ export class GRU extends Module {
* \end{aligned}
* ```
*
-* The hidden state `h` and cell state `c` have shape `NH`
-* or `H`, depending on whether the input is batched or not.
+* The hidden state `h` and cell state `c` have shape `NH` or `H`, depending on
+* whether the input is batched or not.
*
-* The layer returns two arrays, the hidden state and the cell state at
-* each time step, both of shape `NLH` or `LH`.
+* The layer returns two arrays, the hidden state and the cell state at each
+* time step, both of shape `NLH` or `LH`.
*
* @param inputDims - Dimension of the input, `D`.
* @param hiddenDims - Dimension of the hidden state, `H`.
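A shape-oriented sketch of the documented conventions (dimensions are illustrative, and the layer is assumed to be invoked through `forward` as declared on `Module`):

```typescript
import {core as mx, nn} from '@frost-beta/mlx';

// Batched input: N = 2 sequences, L = 16 time steps, D = 8 features.
const x = mx.random.uniform(0, 1, [2, 16, 8]);

// An RNN with hidden size H = 32 returns the hidden state at every
// step: shape [N, L, H] = [2, 16, 32].
const rnn = new nn.RNN(8, 32);
const h = rnn.forward(x);
```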
5 changes: 1 addition & 4 deletions lib/nn/losses.ts
@@ -1,4 +1,5 @@
import {core as mx} from '../..';
+import {deepEqual} from './layers/pytools';

type Reduction = 'none' | 'mean' | 'sum';

@@ -568,7 +569,3 @@ const reduce = (loss: mx.array, reduction: Reduction = 'none'): mx.array => {
throw new Error("Invalid reduction. Must be 'none', 'mean', or 'sum'.");
}
};

-const deepEqual = (s1: number[], s2: number[]): boolean => {
-return s1.length === s2.length && s1.every((u, i) => u === s2[i]);
-};
21 changes: 11 additions & 10 deletions lib/utils.ts
@@ -181,23 +181,24 @@ export function treeUnflatten(tree: [string, unknown][]): unknown {
// Walkthrough path and collect children.
const children: {[key: string]: [string, unknown][]} = {};
for (let [key, value] of tree) {
-const [currentIndex, ...nextIndices] = key.split('.');
-const nextIndex = nextIndices.length === 0 ? '' : nextIndices[0];
-if (!(currentIndex in children)) {
-children[currentIndex] = [];
-}
-children[currentIndex].push([nextIndex, value]);
+const [index, ...nextIndices] = key.split('.');
+const next = nextIndices?.join('.') ?? '';
+if (!(index in children))
+children[index] = [];
+children[index].push([next, value]);
}

// Recursively map them to the original container.
if (isList) {
-const keys = Object.keys(children).sort().map((idx) => parseInt(idx));
-return keys.map((i: number) => treeUnflatten(children[i]));
+const keys = Object.keys(children).sort().map((idx) => [ parseInt(idx), idx ]);
+const newList = [];
+for (const [i, k] of keys)
+newList[i] = treeUnflatten(children[k]);
+return newList;
} else {
const newTree = {};
-for (let k in children) {
+for (let k in children)
newTree[k] = treeUnflatten(children[k]);
-}
return newTree;
}
}
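The rewrite matters for keys with more than two segments: the old code kept only `nextIndices[0]`, truncating deeper paths, while the new code re-joins the remainder (and the list branch now preserves the numeric indices instead of lexicographic string order). A sketch of the corrected behavior:

```typescript
import {utils} from '@frost-beta/mlx';

// 'layers.0.weight' splits into 'layers' plus the remainder '0.weight',
// which is now passed intact to the recursive call.
const tree = utils.treeUnflatten([
  ['layers.0.weight', 1],
  ['layers.0.bias', 2],
  ['layers.1.weight', 3],
]);
// => {layers: [{weight: 1, bias: 2}, {weight: 3}]}
```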
3 changes: 2 additions & 1 deletion src/io.cc
@@ -9,7 +9,7 @@ std::variant<mx::array,
mx::GGUFLoad>
Load(std::string file,
std::optional<std::string> format,
-std::optional<bool> return_metadata,
+std::optional<bool> return_metadata_arg,
mx::StreamOrDevice s) {
if (!format) {
size_t ext = file.find_last_of('.');
@@ -20,6 +20,7 @@ Load(std::string file,
format = file.substr(ext + 1);
}

+bool return_metadata = return_metadata_arg.value_or(false);
if (return_metadata && (*format == "npy" || *format == "npz")) {
throw std::invalid_argument(
"[load] metadata not supported for format " + *format);