Skip to content

Commit

Permalink
Return the metadata in the gguf pyo3 bindings. (huggingface#729)
Browse files Browse the repository at this point in the history
* Return the metadata in the gguf pyo3 bindings.

* Read the metadata in the quantized llama example.

* Get inference to work on gguf files.
  • Loading branch information
LaurentMazare authored Sep 4, 2023
1 parent 9c61b0f commit 20512ba
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 8 deletions.
39 changes: 35 additions & 4 deletions candle-pyo3/quant-llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def __init__(self, hparams, all_tensors):
self.norm = RmsNorm(all_tensors["norm.weight"])
self.output = all_tensors["output.weight"]
self.layers = []
cos_sin = precompute_freqs_cis(hparams, 10000.)
rope_freq = hparams.get("rope_freq", 10000.)
cos_sin = precompute_freqs_cis(hparams, rope_freq)
for layer_idx in range(hparams["n_layer"]):
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
self.layers.append(layer)
Expand All @@ -133,15 +134,45 @@ def __call__(self, token, index_pos):
x = self.output.matmul_t(x)
return x

def gguf_rename(tensor_name):
if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
if tensor_name == 'output_norm.weight': return 'norm.weight'
tensor_name = tensor_name.replace('blk.', 'layers.')
tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
return tensor_name

def main():
if len(sys.argv) < 2:
raise ValueError("missing weight file argument")
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
all_tensors = candle.load_gguf(sys.argv[1])
hparams = None
vocab = None
all_tensors, metadata = candle.load_gguf(sys.argv[1])
vocab = metadata["tokenizer.ggml.tokens"]
for i, v in enumerate(vocab):
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
print(hparams)
hparams = {
'n_vocab': len(vocab),
'n_embd': metadata['llama.embedding_length'],
'n_mult': 256,
'n_head': metadata['llama.attention.head_count'],
'n_head_kv': metadata['llama.attention.head_count_kv'],
'n_layer': metadata['llama.block_count'],
'n_rot': metadata['llama.rope.dimension_count'],
'rope_freq': metadata['llama.rope.freq_base'],
'ftype': metadata['general.file_type'],
}
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }

else:
all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
print(hparams)
Expand Down
41 changes: 37 additions & 4 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,10 +746,35 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}

#[pyfunction]
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
let v: PyObject = match v {
gguf_file::Value::U8(x) => x.into_py(py),
gguf_file::Value::I8(x) => x.into_py(py),
gguf_file::Value::U16(x) => x.into_py(py),
gguf_file::Value::I16(x) => x.into_py(py),
gguf_file::Value::U32(x) => x.into_py(py),
gguf_file::Value::I32(x) => x.into_py(py),
gguf_file::Value::U64(x) => x.into_py(py),
gguf_file::Value::I64(x) => x.into_py(py),
gguf_file::Value::F32(x) => x.into_py(py),
gguf_file::Value::F64(x) => x.into_py(py),
gguf_file::Value::Bool(x) => x.into_py(py),
gguf_file::Value::String(x) => x.into_py(py),
gguf_file::Value::Array(x) => {
let list = pyo3::types::PyList::empty(py);
for elem in x.iter() {
list.append(gguf_value_to_pyobject(elem, py)?)?;
}
list.into()
}
};
Ok(v)
}
let mut file = std::fs::File::open(path)?;
let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?;
let res = gguf
let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?;
let tensors = gguf
.tensor_infos
.keys()
.map(|key| {
Expand All @@ -758,7 +783,15 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
})
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
Ok(res.into_py_dict(py).to_object(py))
let tensors = tensors.into_py_dict(py).to_object(py);
let metadata = gguf
.metadata
.iter()
.map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
.collect::<PyResult<Vec<_>>>()?
.into_py_dict(py)
.to_object(py);
Ok((tensors, metadata))
}

#[pyfunction]
Expand Down

0 comments on commit 20512ba

Please sign in to comment.