From 403680f17ddc086295fbaee316cbed22d97a519b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 17 Jan 2024 10:27:58 +0100 Subject: [PATCH] Quantized GGUF style (#1523) * Metal quantized modifications proposal. - Add a device param, wherever needed. - Create new QMetal storage thing that implements QuantizedType. - Update everywhere needed. Fix Python. Fixing examples. Fix: fmt + clippy + stub. Moving everything around. Only missing the actual implems. Fixing everything + adding dequantized kernels. More work. Fixing matmul. Fmt + Clippy Some clippy fixes. Working state. Q2K Metal -> Bugged (also present in GGML). Q4K CPU -> Bugged (present previously, new test catch it). Q5K CPU -> Bugged (present previously). Q8_1 Both -> Never really implemented it seems Q8K metal -> Never implemented in metal Fixing Q2K bug (present in ggml). * Cleanup. * Fix the rebase. * Removing the fences speeds everything up and *is* correct this time... * Cleanup the fence. * After rebase. * Bad code removal. * Rebase after phi2 merge + fix replit default to CPU. * Making the CI happy. * More happy tests. --------- Co-authored-by: Nicolas Patry --- candle-core/examples/tensor-tools.rs | 122 +- candle-core/src/metal_backend.rs | 112 +- candle-core/src/quantized/ggml_file.rs | 84 +- candle-core/src/quantized/gguf_file.rs | 28 +- candle-core/src/quantized/metal.rs | 153 + candle-core/src/quantized/mod.rs | 302 +- candle-core/tests/quantized_tests.rs | 573 +- candle-examples/examples/blip/main.rs | 4 +- candle-examples/examples/llama2-c/main.rs | 8 +- candle-examples/examples/mistral/main.rs | 7 +- candle-examples/examples/phi/main.rs | 16 +- candle-examples/examples/quantized-t5/main.rs | 3 +- candle-examples/examples/quantized/main.rs | 16 +- candle-examples/examples/replit-code/main.rs | 13 +- candle-examples/examples/stable-lm/main.rs | 5 +- candle-examples/examples/whisper/main.rs | 6 +- candle-metal-kernels/src/lib.rs | 228 +- candle-metal-kernels/src/quantized.metal | 5107 +++++++++++++++++ candle-metal-kernels/src/tests.rs | 33 +- candle-metal-kernels/src/unary.metal | 2 +- candle-nn/examples/cpu_benchmarks.rs | 5 +- candle-pyo3/py_src/candle/utils/__init__.pyi | 8 +- candle-pyo3/src/lib.rs | 51 +- .../src/models/quantized_llama.rs | 41 +- .../src/models/quantized_mixformer.rs | 4 +- .../src/quantized_var_builder.rs | 12 +- candle-wasm-examples/blip/src/bin/m.rs | 2 +- candle-wasm-examples/phi/src/bin/m.rs | 6 +- .../t5/src/bin/m-quantized.rs | 9 +- candle-wasm-examples/whisper/src/worker.rs | 1 + candle-wasm-tests/tests/quantized_tests.rs | 2 +- 31 files changed, 6447 insertions(+), 516 deletions(-) create mode 100644 candle-core/src/quantized/metal.rs create mode 100644 candle-metal-kernels/src/quantized.metal diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index 337021aa44..eb6ceb1c62 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -1,5 +1,5 @@ -use candle_core::quantized::{gguf_file, k_quants, QTensor}; -use candle_core::{Device, Result, Tensor}; +use candle_core::quantized::{gguf_file, GgmlDType, QTensor}; +use candle_core::{Device, Result}; use clap::{Parser, Subcommand, ValueEnum}; use rayon::prelude::*; @@ -11,12 +11,7 @@ enum QuantizationMode { } impl QuantizationMode { - fn quantize( - &self, - name: &str, - tensor: QTensor, - default: fn(&Tensor) -> Result, - ) -> Result { + fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result { match self { Self::Llama => { // Same behavior as the 
llama.cpp quantization. @@ -24,9 +19,9 @@ impl QuantizationMode { if should_quantize { let tensor = tensor.dequantize(&Device::Cpu)?; if name == "output.weight" { - QTensor::quantize::(&tensor) + QTensor::quantize(&tensor, GgmlDType::Q6K) } else { - default(&tensor) + QTensor::quantize(&tensor, dtype) } } else { Ok(tensor) @@ -60,6 +55,27 @@ enum Quantization { F32, } +impl Quantization { + fn dtype(&self) -> GgmlDType { + match self { + Quantization::Q4_0 => GgmlDType::Q4_0, + Quantization::Q4_1 => GgmlDType::Q4_1, + Quantization::Q5_0 => GgmlDType::Q5_0, + Quantization::Q5_1 => GgmlDType::Q5_1, + Quantization::Q8_0 => GgmlDType::Q8_0, + Quantization::Q8_1 => GgmlDType::Q8_1, + Quantization::Q2k => GgmlDType::Q2K, + Quantization::Q3k => GgmlDType::Q3K, + Quantization::Q4k => GgmlDType::Q4K, + Quantization::Q5k => GgmlDType::Q5K, + Quantization::Q6k => GgmlDType::Q6K, + Quantization::Q8k => GgmlDType::Q8K, + Quantization::F16 => GgmlDType::F16, + Quantization::F32 => GgmlDType::F32, + } + } +} + #[derive(ValueEnum, Debug, Clone)] enum Format { Safetensors, @@ -134,7 +150,12 @@ struct Args { command: Command, } -fn run_ls(file: &std::path::PathBuf, format: Option, verbose: bool) -> Result<()> { +fn run_ls( + file: &std::path::PathBuf, + format: Option, + verbose: bool, + device: &Device, +) -> Result<()> { let format = match format { Some(format) => format, None => match Format::infer(file) { @@ -200,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option, verbose: bool) -> R } Format::Ggml => { let mut file = std::fs::File::open(file)?; - let content = candle_core::quantized::ggml_file::Content::read(&mut file)?; + let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; let mut tensors = content.tensors.into_iter().collect::>(); tensors.sort_by(|a, b| a.0.cmp(&b.0)); for (name, qtensor) in tensors.iter() { @@ -241,37 +262,8 @@ fn run_quantize_safetensors( } println!("tensors: {}", tensors.len()); - let quantize_fn = match q { - Quantization::Q4_0 => QTensor::quantize::, - Quantization::Q4_1 => QTensor::quantize::, - Quantization::Q5_0 => QTensor::quantize::, - Quantization::Q5_1 => QTensor::quantize::, - Quantization::Q8_0 => QTensor::quantize::, - Quantization::Q8_1 => QTensor::quantize::, - Quantization::Q2k => QTensor::quantize::, - Quantization::Q3k => QTensor::quantize::, - Quantization::Q4k => QTensor::quantize::, - Quantization::Q5k => QTensor::quantize::, - Quantization::Q6k => QTensor::quantize::, - Quantization::Q8k => QTensor::quantize::, - Quantization::F16 => QTensor::quantize::, - Quantization::F32 => QTensor::quantize::, - }; - let block_size = match q { - Quantization::Q4_0 => k_quants::QK4_0, - Quantization::Q4_1 => k_quants::QK4_1, - Quantization::Q5_0 => k_quants::QK5_0, - Quantization::Q5_1 => k_quants::QK5_1, - Quantization::Q8_0 => k_quants::QK8_0, - Quantization::Q8_1 => k_quants::QK8_1, - Quantization::Q2k - | Quantization::Q3k - | Quantization::Q4k - | Quantization::Q5k - | Quantization::Q6k - | Quantization::Q8k => k_quants::QK_K, - Quantization::F16 | Quantization::F32 => 1, - }; + let dtype = q.dtype(); + let block_size = dtype.block_size(); let qtensors = tensors .into_par_iter() @@ -279,9 +271,9 @@ fn run_quantize_safetensors( let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; println!(" quantizing {name} {tensor:?} {should_quantize}"); let tensor = if should_quantize { - quantize_fn(&tensor)? + QTensor::quantize(&tensor, dtype)? } else { - QTensor::quantize::(&tensor)? 
+ QTensor::quantize(&tensor, GgmlDType::F32)? }; Ok((name, tensor)) }) @@ -294,13 +286,17 @@ fn run_quantize_safetensors( Ok(()) } -fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> { +fn run_dequantize( + in_file: std::path::PathBuf, + out_file: std::path::PathBuf, + device: &Device, +) -> Result<()> { let mut in_file = std::fs::File::open(in_file)?; let content = gguf_file::Content::read(&mut in_file)?; let mut tensors = std::collections::HashMap::new(); for (tensor_name, _) in content.tensor_infos.iter() { - let tensor = content.tensor(&mut in_file, tensor_name)?; - let tensor = tensor.dequantize(&Device::Cpu)?; + let tensor = content.tensor(&mut in_file, tensor_name, device)?; + let tensor = tensor.dequantize(device)?; tensors.insert(tensor_name.to_string(), tensor); } candle_core::safetensors::save(&tensors, out_file)?; @@ -312,6 +308,7 @@ fn run_quantize( out_file: std::path::PathBuf, q: Quantization, qmode: QuantizationMode, + device: &Device, ) -> Result<()> { if in_files.is_empty() { candle_core::bail!("no specified input files") @@ -337,31 +334,15 @@ fn run_quantize( let content = gguf_file::Content::read(&mut in_)?; println!("tensors: {}", content.tensor_infos.len()); - let quantize_fn = match q { - Quantization::Q4_0 => QTensor::quantize::, - Quantization::Q4_1 => QTensor::quantize::, - Quantization::Q5_0 => QTensor::quantize::, - Quantization::Q5_1 => QTensor::quantize::, - Quantization::Q8_0 => QTensor::quantize::, - Quantization::Q8_1 => QTensor::quantize::, - Quantization::Q2k => QTensor::quantize::, - Quantization::Q3k => QTensor::quantize::, - Quantization::Q4k => QTensor::quantize::, - Quantization::Q5k => QTensor::quantize::, - Quantization::Q6k => QTensor::quantize::, - Quantization::Q8k => QTensor::quantize::, - Quantization::F16 => QTensor::quantize::, - Quantization::F32 => QTensor::quantize::, - }; - + let dtype = q.dtype(); let qtensors = content .tensor_infos .par_iter() .map(|(name, _)| { println!(" quantizing {name}"); let mut in_file = std::fs::File::open(&in_files[0])?; - let tensor = content.tensor(&mut in_file, name)?; - let tensor = qmode.quantize(name, tensor, quantize_fn)?; + let tensor = content.tensor(&mut in_file, name, device)?; + let tensor = qmode.quantize(name, tensor, dtype)?; Ok((name, tensor)) }) .collect::>>()?; @@ -381,6 +362,7 @@ fn run_quantize( fn main() -> anyhow::Result<()> { let args = Args::parse(); + let device = Device::Cpu; match args.command { Command::Ls { files, @@ -392,7 +374,7 @@ fn main() -> anyhow::Result<()> { if multiple_files { println!("--- {file:?} ---"); } - run_ls(file, format.clone(), verbose)? + run_ls(file, format.clone(), verbose, &device)? 
} } Command::Quantize { @@ -400,8 +382,8 @@ fn main() -> anyhow::Result<()> { out_file, quantization, mode, - } => run_quantize(&in_file, out_file, quantization, mode)?, - Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?, + } => run_quantize(&in_file, out_file, quantization, mode, &device)?, + Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?, } Ok(()) } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 5269a89960..dc790ac9da 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -84,13 +84,8 @@ pub struct MetalDevice { command_buffer_index: Arc>, /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) compute_per_buffer: usize, - /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the - /// execution order to be linear. - /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the - /// compute graph. - fence: metal::Fence, /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. - /// Heavily used by [`candle_metal_kernels`], both fences need to match + /// Heavily used by [`candle_metal_kernels`] kernels: Arc, /// Simple allocator struct. /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. @@ -221,10 +216,8 @@ impl MetalDevice { let command_buffer = self.command_buffer()?; command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&self.fence); blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); - blit.update_fence(&self.fence); blit.end_encoding(); // This is necessary, for mmaped safetensors @@ -238,6 +231,27 @@ impl MetalDevice { Ok(real) } + pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { + let buffer = self.allocate_buffer( + size_in_bytes as NSUInteger, + MTLResourceOptions::StorageModePrivate, + "allocate_zeros", + )?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("zeros"); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + Ok(buffer) + } + /// The critical allocator algorithm fn allocate_buffer( &self, @@ -308,35 +322,14 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { - let length = self.buffer.length() as usize; - let size = self.dtype.size_in_bytes(); - if length % size != 0 { - crate::bail!( - "The Metal buffer length is not aligned with dtype {:?}", - self.dtype - ); - } - let buffer = self.device.new_buffer_managed(self.buffer.length())?; - { - let command_buffer = self.device.command_buffer()?; - command_buffer.set_label("to_cpu"); - let blit = command_buffer.new_blit_command_encoder(); - blit.set_label("blit_to_cpu"); - blit.wait_for_fence(&self.device.fence); - blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); - blit.update_fence(&self.device.fence); - blit.end_encoding(); - } - self.device.wait_until_completed()?; - match self.dtype { - DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))), - DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / 
size))), - DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), - DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), - DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), - DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))), - DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), + DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), + DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), + DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), + DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), + DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), + DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), } } @@ -1264,7 +1257,7 @@ impl BackendStorage for MetalStorage { let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; - blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length); blit.end_encoding(); } else { let src_shape = src_l.shape(); @@ -1521,6 +1514,28 @@ impl MetalStorage { command_buffer.set_label("binary"); Ok(Self::new(buffer, device.clone(), dtype)) } + + pub(crate) fn to_cpu(&self) -> Result> { + let length = self.buffer.length() as usize; + let size = self.dtype.size_in_bytes(); + if length % size != 0 { + crate::bail!( + "The Metal buffer length is not aligned with dtype {:?}", + self.dtype + ); + } + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + Ok(read_to_vec(&buffer, length / size)) + } } impl BackendDevice for MetalDevice { @@ -1533,16 +1548,14 @@ impl BackendDevice for MetalDevice { command_buffer.enqueue(); let command_buffer = Arc::new(RwLock::new(command_buffer)); let command_buffer_index = Arc::new(RwLock::new(0)); - let fence = device.new_fence(); - let kernels = Arc::new(Kernels::new(fence.clone())); + let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { Ok(val) => val.parse()?, - _ => 20, + _ => 10, }; Ok(Self { device, - fence, command_queue, command_buffer, command_buffer_index, @@ -1567,21 +1580,8 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; - let command_buffer = self.command_buffer()?; - command_buffer.set_label("zeros"); - let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&self.fence); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.update_fence(&self.fence); - blit.end_encoding(); + let size = shape.elem_count() * dtype.size_in_bytes(); + let buffer = self.allocate_zeros(size)?; Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 
1dd3d9c05e..3823858056 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -1,7 +1,9 @@ //! Support for the GGML file format. -use super::{k_quants, GgmlDType}; -use crate::Result; +#[cfg(feature = "metal")] +use super::metal::load_quantized_metal; +use super::{k_quants, GgmlDType, QStorage}; +use crate::{Device, Result}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; @@ -121,11 +123,22 @@ fn from_raw_data( raw_data: &[u8], size_in_bytes: usize, dims: Vec, + device: &Device, ) -> Result { let raw_data_ptr = raw_data.as_ptr(); let n_blocks = size_in_bytes / std::mem::size_of::(); let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; - super::QTensor::new(data.to_vec(), dims) + let data: QStorage = match device { + Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())), + #[cfg(feature = "metal")] + Device::Metal(metal) => load_quantized_metal(metal, data)?, + #[cfg(not(feature = "metal"))] + Device::Metal(_metal) => { + crate::bail!("Metal backend requires `metal` feature") + } + device => unimplemented!("Implement quantized tensor for device {device:?}"), + }; + super::QTensor::new(data, dims) } /// Creates a [Tensor] from a raw GGML tensor. @@ -133,29 +146,50 @@ pub fn qtensor_from_ggml( ggml_dtype: GgmlDType, raw_data: &[u8], dims: Vec, + device: &Device, ) -> Result { let tensor_elems = dims.iter().product::(); - let blck_size = ggml_dtype.blck_size(); - if tensor_elems % blck_size != 0 { + let block_size = ggml_dtype.block_size(); + if tensor_elems % block_size != 0 { crate::bail!( - "the number of elements {tensor_elems} is not divisible by the block size {blck_size}" + "the number of elements {tensor_elems} is not divisible by the block size {block_size}" ) } - let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size(); + let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size(); match ggml_dtype { - GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q4_0 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q4_1 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q5_0 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q5_1 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q8_0 => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q2K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q3K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q4K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q5K => from_raw_data::(raw_data, size_in_bytes, dims), - GgmlDType::Q6K => from_raw_data::(raw_data, size_in_bytes, dims), + GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q4_0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q4_1 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q5_0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q5_1 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q8_0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q2K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q3K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q4K => { + 
from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q5K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q6K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } @@ -163,6 +197,7 @@ pub fn qtensor_from_ggml( fn read_one_tensor( reader: &mut R, magic: VersionedMagic, + device: &Device, ) -> Result<(String, super::QTensor)> { let n_dims = reader.read_u32::()?; let name_len = reader.read_u32::()?; @@ -183,11 +218,11 @@ fn read_one_tensor( } let dims = dims.iter().map(|&u| u as usize).collect::>(); let tensor_elems = dims.iter().product::(); - let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size(); + let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size(); // TODO: Mmap version to avoid copying the data around? let mut raw_data = vec![0u8; size_in_bytes]; reader.read_exact(&mut raw_data)?; - match qtensor_from_ggml(ggml_dtype, &raw_data, dims) { + match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) { Ok(tensor) => Ok((name, tensor)), Err(e) => crate::bail!("Error creating tensor {name}: {e}"), } @@ -201,7 +236,10 @@ pub struct Content { } impl Content { - pub fn read(reader: &mut R) -> Result { + pub fn read( + reader: &mut R, + device: &Device, + ) -> Result { // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 let last_position = reader.seek(std::io::SeekFrom::End(0))?; reader.seek(std::io::SeekFrom::Start(0))?; @@ -211,7 +249,7 @@ impl Content { let mut tensors = HashMap::new(); while reader.stream_position()? != last_position { - let (name, tensor) = read_one_tensor(reader, magic)?; + let (name, tensor) = read_one_tensor(reader, magic, device)?; tensors.insert(name, tensor); } Ok(Self { diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index 587ffc0f87..b729d4a0fd 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -3,7 +3,7 @@ //! 
Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md use super::{GgmlDType, QTensor}; -use crate::Result; +use crate::{Device, Result}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::collections::HashMap; @@ -59,19 +59,25 @@ impl TensorInfo { &self, reader: &mut R, tensor_data_offset: u64, + device: &Device, ) -> Result { let tensor_elems = self.shape.elem_count(); - let blck_size = self.ggml_dtype.blck_size(); - if tensor_elems % blck_size != 0 { + let block_size = self.ggml_dtype.block_size(); + if tensor_elems % block_size != 0 { crate::bail!( - "the number of elements {tensor_elems} is not divisible by the block size {blck_size}" + "the number of elements {tensor_elems} is not divisible by the block size {block_size}" ) } - let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size(); + let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size(); let mut raw_data = vec![0u8; size_in_bytes]; reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; reader.read_exact(&mut raw_data)?; - super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec()) + super::ggml_file::qtensor_from_ggml( + self.ggml_dtype, + &raw_data, + self.shape.dims().to_vec(), + device, + ) } } @@ -460,12 +466,13 @@ impl Content { &self, reader: &mut R, name: &str, + device: &Device, ) -> Result { let tensor_info = match self.tensor_infos.get(name) { Some(tensor_info) => tensor_info, None => crate::bail!("cannot find tensor info for {name}"), }; - tensor_info.read(reader, self.tensor_data_offset) + tensor_info.read(reader, self.tensor_data_offset, device) } } @@ -517,10 +524,9 @@ pub fn write( "internal error, unexpected current position {tensor_start_pos} {offset} {pos}" ) } - let data_ptr = tensor.as_ptr(); - let size_in_bytes = tensor.storage_size_in_bytes(); - let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; - w.write_all(data)?; + let data = tensor.data()?; + let size_in_bytes = data.len(); + w.write_all(&data)?; let padding = 31 - (31 + size_in_bytes) % 32; w.write_all(&vec![0u8; padding])?; } diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs new file mode 100644 index 0000000000..fe57ce1465 --- /dev/null +++ b/candle-core/src/quantized/metal.rs @@ -0,0 +1,153 @@ +use super::{GgmlDType, QStorage}; +use crate::{DType, MetalDevice, MetalStorage, Result}; +use metal::Buffer; +use std::sync::Arc; + +pub struct QMetalStorage { + dtype: GgmlDType, + device: MetalDevice, + buffer: Arc, +} + +impl QMetalStorage { + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn buffer(&self) -> &Buffer { + &self.buffer + } + + pub fn new(buffer: Arc, device: MetalDevice, dtype: GgmlDType) -> Self { + Self { + device, + buffer, + dtype, + } + } + + pub fn dequantize(&self, elem_count: usize) -> Result { + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + self.device.wait_until_completed()?; + let mut out = vec![0.0; elem_count]; + match self.dtype { + GgmlDType::F32 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + f32::to_float(&vec, &mut out)?; + } + GgmlDType::F16 => { + let vec: Vec = read_to_vec(&buffer, 
elem_count); + use crate::quantized::k_quants::GgmlType; + half::f16::to_float(&vec, &mut out)?; + } + GgmlDType::Q4_0 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q4_1 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q5_0 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q5_1 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q8_0 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q8_1 => { + let vec: Vec = read_to_vec(&buffer, elem_count); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q2K => { + let vec: Vec = + read_to_vec(&buffer, elem_count / self.dtype.block_size()); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; + } + GgmlDType::Q3K => { + let vec: Vec = + read_to_vec(&buffer, elem_count / self.dtype.block_size()); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; + } + GgmlDType::Q4K => { + let vec: Vec = + read_to_vec(&buffer, elem_count / self.dtype.block_size()); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; + } + GgmlDType::Q5K => { + let vec: Vec = + read_to_vec(&buffer, elem_count / self.dtype.block_size()); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; + } + GgmlDType::Q6K => { + let vec: Vec = + read_to_vec(&buffer, elem_count / self.dtype.block_size()); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; + } + GgmlDType::Q8K => { + let vec: Vec = + read_to_vec(&buffer, elem_count / self.dtype.block_size()); + use crate::quantized::k_quants::GgmlType; + crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; + } + } + + let buffer = self.device.new_buffer_with_data(&out)?; + Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32)) + } + + pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> { + // Quantization only happens on CPU for now. 
+ let src = src.to_cpu::()?; + let elem_count = src.len(); + let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + qcpu_storage.quantize(&src)?; + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } +} + +pub fn load_quantized_metal( + device: &MetalDevice, + data: &[T], +) -> Result { + let buffer = device.new_buffer_with_data(data)?; + let device = device.clone(); + Ok(QStorage::Metal(QMetalStorage { + dtype: T::DTYPE, + device, + buffer, + })) +} + +fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 043733ae87..1dc5fe8f74 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,23 +1,125 @@ -use crate::{Device, Result, Shape, Tensor}; +#[cfg(feature = "metal")] +use crate::{backend::BackendStorage, DType}; +use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; +use k_quants::*; +use std::borrow::Cow; #[cfg(target_feature = "avx")] pub mod avx; pub mod ggml_file; pub mod gguf_file; pub mod k_quants; +#[cfg(feature = "metal")] +pub mod metal; #[cfg(target_feature = "neon")] pub mod neon; #[cfg(target_feature = "simd128")] pub mod simd128; pub mod utils; +use half::f16; pub use k_quants::GgmlType; pub struct QTensor { - data: Box, + storage: QStorage, shape: Shape, } +impl Device { + fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result { + match self { + Device::Cpu => { + let storage = dtype.cpu_zeros(elem_count); + Ok(QStorage::Cpu(storage)) + } + #[cfg(feature = "metal")] + Device::Metal(metal) => { + let size = elem_count * dtype.type_size() / dtype.block_size(); + let buffer = metal.allocate_zeros(size)?; + Ok(QStorage::Metal(metal::QMetalStorage::new( + buffer, + metal.clone(), + dtype, + ))) + } + #[cfg(not(feature = "metal"))] + Device::Metal(_metal) => { + crate::bail!("Metal feature not activated"); + } + Device::Cuda(_cuda) => { + crate::bail!("Cuda ggml quantization not supported"); + } + } + } +} + +pub enum QStorage { + Cpu(Box), + #[cfg(feature = "metal")] + Metal(metal::QMetalStorage), +} + +impl QStorage { + fn block_size(&self) -> usize { + match self { + QStorage::Cpu(storage) => storage.block_size(), + #[cfg(feature = "metal")] + QStorage::Metal(storage) => storage.dtype().block_size(), + } + } + + fn dtype(&self) -> GgmlDType { + match self { + QStorage::Cpu(storage) => storage.dtype(), + #[cfg(feature = "metal")] + QStorage::Metal(storage) => storage.dtype(), + } + } + + fn size_in_bytes(&self) -> usize { + match self { + QStorage::Cpu(storage) => storage.storage_size_in_bytes(), + #[cfg(feature = "metal")] + QStorage::Metal(storage) => storage.buffer().length() as usize, + } + } + + fn quantize(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float(src.as_slice::()?)?; + } + #[cfg(feature = "metal")] + (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, + _ => crate::bail!("Invalid dequantize storage locations do not match"), + } + Ok(()) + } + + fn dequantize(&self, elem_count: usize) -> Result { + match self { + QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), + #[cfg(feature = "metal")] + 
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), + } + } + + fn data(&self) -> Result> { + match self { + QStorage::Cpu(storage) => { + let data_ptr = storage.as_ptr(); + let size_in_bytes = storage.storage_size_in_bytes(); + let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; + Ok(Cow::from(data)) + } + #[cfg(feature = "metal")] + QStorage::Metal(_storage) => { + crate::bail!("not implemented"); + } + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum GgmlDType { F32, @@ -77,6 +179,25 @@ impl GgmlDType { } } + /// The block dtype + pub fn cpu_zeros(&self, elem_count: usize) -> Box { + match self { + Self::F32 => Box::new(vec![f32::zeros(); elem_count]), + Self::F16 => Box::new(vec![f16::zeros(); elem_count]), + Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]), + Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]), + Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]), + Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]), + Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]), + Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]), + Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]), + Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]), + Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]), + Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), + Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), + Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + } + } /// The type size for blocks in bytes. pub fn type_size(&self) -> usize { use k_quants::*; @@ -100,7 +221,7 @@ impl GgmlDType { } /// The block size, i.e. the number of elements stored in each block. 
- pub fn blck_size(&self) -> usize { + pub fn block_size(&self) -> usize { match self { Self::F32 => 1, Self::F16 => 1, @@ -119,9 +240,13 @@ impl GgmlDType { pub trait QuantizedType: Send + Sync { fn dtype(&self) -> GgmlDType; fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>; - fn to_float(&self, ys: &mut [f32]) -> Result<()>; + fn dequantize(&self, elem_count: usize) -> Result; fn storage_size_in_bytes(&self) -> usize; fn as_ptr(&self) -> *const u8; + fn block_size(&self) -> usize; + #[allow(clippy::wrong_self_convention)] + fn from_float(&mut self, xs: &[f32]) -> Result<()>; + fn size(&self) -> usize; } impl QuantizedType for Vec { @@ -129,12 +254,26 @@ impl QuantizedType for Vec { k_quants::matmul(mkn, lhs, self.as_slice(), dst) } + fn size(&self) -> usize { + self.len() * core::mem::size_of::() + } + + fn from_float(&mut self, xs: &[f32]) -> Result<()> { + T::from_float(xs, self) + } + fn dtype(&self) -> GgmlDType { T::DTYPE } - fn to_float(&self, ys: &mut [f32]) -> Result<()> { - T::to_float(self.as_slice(), ys) + fn block_size(&self) -> usize { + T::BLCK_SIZE + } + + fn dequantize(&self, elem_count: usize) -> Result { + let mut ys = vec![0.0f32; elem_count]; + T::to_float(self.as_slice(), &mut ys)?; + Ok(CpuStorage::F32(ys)) } fn storage_size_in_bytes(&self) -> usize { @@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor { } } -fn check_shape(shape: &Shape) -> Result<()> { +fn check_shape(shape: &Shape, block_size: usize) -> Result<()> { let dims = shape.dims(); if dims.is_empty() { crate::bail!("scalar tensor cannot be quantized {shape:?}") } - if dims[dims.len() - 1] % T::BLCK_SIZE != 0 { + if dims[dims.len() - 1] % block_size != 0 { crate::bail!( "quantized tensor must have their last dim divisible by block size {shape:?} {}", - T::BLCK_SIZE + block_size ) } Ok(()) } impl QTensor { - pub fn new, T: k_quants::GgmlType + Send + Sync + 'static>( - data: Vec, - shape: S, - ) -> Result { + pub fn new>(storage: QStorage, shape: S) -> Result { let shape = shape.into(); - check_shape::(&shape)?; - Ok(Self { - data: Box::new(data), - shape, - }) + check_shape(&shape, storage.block_size())?; + Ok(Self { storage, shape }) } - pub fn quantize(src: &Tensor) -> Result { + pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result { let shape = src.shape(); - check_shape::(shape)?; - let src = src - .to_dtype(crate::DType::F32)? - .flatten_all()? 
- .to_vec1::()?; - if src.len() % T::BLCK_SIZE != 0 { + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { crate::bail!( "tensor size ({shape:?}) is not divisible by block size {}", - T::BLCK_SIZE + block_size ) } - let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE]; - T::from_float(&src, &mut data)?; + let mut storage = src.device().qzeros(elem_count, dtype)?; + storage.quantize(&src.storage())?; Ok(Self { - data: Box::new(data), + storage, shape: shape.clone(), }) } pub fn dtype(&self) -> GgmlDType { - self.data.dtype() + self.storage.dtype() } pub fn rank(&self) -> usize { @@ -213,21 +345,19 @@ impl QTensor { } pub fn dequantize(&self, device: &Device) -> Result { - let mut f32_data = vec![0f32; self.shape.elem_count()]; - self.data.to_float(&mut f32_data)?; - Tensor::from_vec(f32_data, &self.shape, device) - } - - pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { - self.data.matmul_t(mkn, lhs, dst) + let storage = self.storage.dequantize(self.shape.elem_count())?; + let none = crate::op::BackpropOp::none(); + let is_variable = false; + crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable) + .to_device(device) } pub fn storage_size_in_bytes(&self) -> usize { - self.data.storage_size_in_bytes() + self.storage.size_in_bytes() } - pub fn as_ptr(&self) -> *const u8 { - self.data.as_ptr() + pub fn data(&self) -> Result> { + self.storage.data() } } @@ -294,17 +424,93 @@ impl crate::CustomOp1 for QTensor { } dst_shape.push(n); let dst_shape = Shape::from(dst_shape); - let storage = storage.as_slice::()?; - let storage = - &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; + #[allow(clippy::infallible_destructuring_match)] + let self_storage = match &self.storage { + QStorage::Cpu(storage) => storage, + #[cfg(feature = "metal")] + _ => crate::bail!("Invalid storage"), + }; + let slice = storage.as_slice::()?; + let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; let mut dst_storage = vec![0f32; dst_shape.elem_count()]; - self.matmul_t( - (dst_shape.elem_count() / n, k, n), - storage, - &mut dst_storage, - )?; + self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?; Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &crate::MetalStorage, + layout: &crate::Layout, + ) -> Result<(crate::MetalStorage, Shape)> { + use crate::MetalError; + + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k. 
+ if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let (n, k) = self.shape.dims2()?; + let mut dst_shape = src_shape.dims().to_vec(); + + let (b, m) = match dst_shape.len() { + 3 => (dst_shape[0], dst_shape[1]), + 2 => (1, dst_shape[0]), + n => crate::bail!("Invalid rank {n} for quantized matmul metal"), + }; + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + let device = storage.device().clone(); + let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; + let (buffer, dtype) = match &self.storage { + QStorage::Metal(metal) => (metal.buffer(), metal.dtype()), + _ => unreachable!("Cannot call metal matmul on non metal QTensor"), + }; + let command_buffer = device.command_buffer()?; + candle_metal_kernels::call_quantized_matmul_t( + device.device(), + &command_buffer, + device.kernels(), + dtype.into(), + (b, m, n, k), + storage.buffer(), + layout.start_offset() * storage.dtype().size_in_bytes(), + buffer, + &dst, + ) + .map_err(MetalError::from)?; + let dst_storage = crate::MetalStorage::new(dst, device, DType::F32); + Ok((dst_storage, dst_shape)) + } +} + +#[cfg(feature = "metal")] +impl From for candle_metal_kernels::GgmlDType { + fn from(value: GgmlDType) -> Self { + match value { + GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0, + GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1, + GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0, + GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1, + GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0, + GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1, + GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K, + GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K, + GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K, + GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, + GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, + GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, + GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, + GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + } + } } impl crate::Module for QMatMul { diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index d31e77a7de..a7811ca5a4 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,6 +1,7 @@ use candle_core::{ bail, quantized::{self, GgmlDType}, + test_device, test_utils::to_vec2_round, Device, Module, Result, Tensor, }; @@ -14,16 +15,48 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075; const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040; const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02; -#[test] -fn quantized_matmul() -> Result<()> { - let cpu = &Device::Cpu; +fn test_matmul( + device: &Device, + (b, m, n, k): (usize, usize, usize, usize), + dtype: GgmlDType, +) -> Result<()> { + let lhs = (0..(m * k)) + .map(|v| v as f32 / (m * k) as f32) + .collect::>(); + let rhs = (0..(k * n)) + .map(|v| v as f32 / (n * k) as f32) + .collect::>(); + + let lhs = Tensor::from_slice(&lhs, (m, k), device)?; + let rhs = Tensor::from_slice(&rhs, (k, n), device)?; + let mm = lhs.matmul(&rhs)?; + let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?; + let matmul = quantized::QMatMul::from_qtensor(qtensor)?; + let res = matmul.forward(&lhs)?; + + let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)? 
+ .sum_all()? + .to_scalar()?; + let error = error / (b * m * n) as f32; + assert!( + error <= 0.02, + "Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}" + ); + + Ok(()) +} + +fn quantized_matmul(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } let (m, k, n) = (3, 64, 4); let lhs = (0..(m * k)).map(|v| v as f32).collect::>(); - let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; + let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let mut dst = vec![42.; 3 * 4]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..(k * n)).map(|v| v as f32).collect::>(); - let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; assert_eq!( @@ -33,6 +66,7 @@ fn quantized_matmul() -> Result<()> { 341876.0, 994283.0, 1655709.0, 2301518.0 ] ); + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; let mm = tensor_lhs.matmul(&tensor_rhs)?; assert_eq!( mm.to_vec2::()?, @@ -43,35 +77,49 @@ fn quantized_matmul() -> Result<()> { ] ); - let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; + let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; - assert_eq!( - to_vec2_round(&res, 0)?, - &[ - [85120.0, 214562.0, 345455.0, 474748.0], - [213475.0, 604465.0, 1000686.0, 1388317.0], - [341876.0, 994283.0, 1655709.0, 2301518.0] - ] - ); + match device { + Device::Metal(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [84946.0, 214126.0, 344757.0, 473798.0], + [213458.0, 604350.0, 1000469.0, 1387990.0], + [341970.0, 994574.0, 1656181.0, 2302182.0] + ] + ), + _ => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [85120.0, 214562.0, 345455.0, 474748.0], + [213475.0, 604465.0, 1000686.0, 1388317.0], + [341876.0, 994283.0, 1655709.0, 2301518.0] + ] + ), + } + + test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?; Ok(()) } -#[test] -fn quantized_matmul_neg() -> Result<()> { - let cpu = &Device::Cpu; +fn quantized_matmul_neg(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } let (m, k, n) = (3, 64, 4); let lhs = (0..(m * k)) .map(|v| v as f32 - (m * k) as f32 / 2.0) .collect::>(); - let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; + let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let mut dst = vec![42.; 3 * 4]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..k * n) .map(|v| v as f32 - (k * n) as f32 / 3.0) .collect::>(); - let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; assert_eq!( @@ -91,32 +139,56 @@ fn quantized_matmul_neg() -> Result<()> { ] ); - let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; + let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; - assert_eq!( - to_vec2_round(&res, 0)?, - &[ - [243524.0, -19596.0, -285051.0, -549815.0], - [23777.0, 21651.0, 19398.0, 18367.0], - [-196472.0, 63012.0, 324585.0, 587902.0] - ] - ); + match device { + Device::Metal(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243666.0, -19714.0, -285433.0, -550453.0], + [23782.0, 21654.0, 19400.0, 18369.0], + [-196102.0, 63022.0, 324233.0, 587191.0] + ] + ), + _ => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243524.0, -19596.0, -285051.0, -549815.0], + [23777.0, 21651.0, 19398.0, 18367.0], + [-196472.0, 63012.0, 324585.0, 587902.0] + ] + ), + } Ok(()) } -#[test] -fn quantize_q4_0() -> Result<()> { - use k_quants::BlockQ4_0; - +test_device!( + quantized_matmul, + quantized_matmul_cpu, + quantized_matmul_cuda, + quantized_matmul_metal +); +test_device!( + quantized_matmul_neg, + quantized_matmul_neg_cpu, + quantized_matmul_neg_cuda, + quantized_matmul_neg_metal +); + +fn quantize_q4_0(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ4_0::zeros(); 4]; - BlockQ4_0::from_float(&src, &mut quant)?; - BlockQ4_0::to_float(&quant, dst.as_mut_slice())?; + + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; + let dst = quant.dequantize(device)?; assert_eq!( - dst, + dst.to_vec1::()?, &[ -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625, 11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25, @@ -132,21 +204,21 @@ fn quantize_q4_0() -> Result<()> { 127.0, 127.0 ] ); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q4_1() -> Result<()> { - use k_quants::BlockQ4_1; - +fn quantize_q4_1(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ4_1::zeros(); 4]; - BlockQ4_1::from_float(&src, &mut quant)?; - BlockQ4_1::to_float(&quant, dst.as_mut_slice())?; + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; + let dst = quant.dequantize(device)?; assert_eq!( - round_vector(&dst), + round_vector(&dst.to_vec1::()?), &[ 0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332, 12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73, @@ -162,21 +234,21 @@ fn quantize_q4_1() -> Result<()> { 118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996 ] ); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q5_0() -> Result<()> { - use k_quants::BlockQ5_0; - +fn quantize_q5_0(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ5_0::zeros(); 4]; - BlockQ5_0::from_float(&src, &mut quant)?; - BlockQ5_0::to_float(&quant, dst.as_mut_slice())?; + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; + let dst = quant.dequantize(device)?; assert_eq!( - round_vector(&dst), + round_vector(&dst.to_vec1::()?), &[ -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625, 11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313, @@ -192,21 +264,21 @@ fn quantize_q5_0() -> Result<()> { 119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0 ] ); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q5_1() -> Result<()> { - use k_quants::BlockQ5_1; - +fn quantize_q5_1(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ5_1::zeros(); 4]; - BlockQ5_1::from_float(&src, &mut quant)?; - BlockQ5_1::to_float(&quant, dst.as_mut_slice())?; + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; + let dst = quant.dequantize(device)?; assert_eq!( - dst, + round_vector(&dst.to_vec1::()?), &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, @@ -220,13 +292,11 @@ fn quantize_q5_1() -> Result<()> { 124.0, 125.0, 126.0, 127.0 ] ); - - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps -fn get_test_vector(bound: f32, size: usize) -> (Vec, Vec) { +fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result { assert!( size % crate::quantized::k_quants::QK_K == 0, "size must be a multiple of {}", @@ -236,10 +306,8 @@ fn get_test_vector(bound: f32, size: usize) -> (Vec, Vec) { let src = (0..size) .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.)) .collect::>(); - - let dst = vec![0f32; size]; assert_eq!([src[0], src[size / 2]], [-bound, 0.0]); - (src, dst) + Tensor::from_vec(src, (size,), device) } /// Round a vector @@ -288,11 +356,12 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 { /// Similar to the GGML quantization unit test: /// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 -fn ggml_quantization_error_test(max_error: f32) -> Result<()> { +fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> { let src = create_ggml_like_vector(0.0); - let mut dst = vec![0.0; GGML_TEST_SIZE]; - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; - let error = calculate_rmse(src.as_slice(), dst.as_slice()); + let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let error = calculate_rmse(&src.to_vec1::()?, &dst.to_vec1::()?); if error > max_error { bail!( "Quantization error {} exceeds max error {}", @@ -303,19 +372,19 @@ fn ggml_quantization_error_test(max_error: f32) -> Result<()> { Ok(()) } -fn quantize_roundtrip(src: &[f32], dst: &mut [f32]) -> Result> { - let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE]; - T::from_float(src, &mut quant)?; - T::to_float(&quant, dst)?; - Ok(quant) -} +fn quantize_q2k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } + let dtype = GgmlDType::Q2K; -#[test] -fn quantize_q2k() -> Result<()> { - use k_quants::BlockQ2K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; - let (src, mut dst) = get_test_vector(0.5, 1024); - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; compare_with_error(dst.as_slice(), src.as_slice(), 0.1); // Test some specific values @@ -329,20 +398,30 @@ fn quantize_q2k() -> Result<()> { [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492] ); - let (src_big, mut dst_big) = get_test_vector(128.0, 1024); - let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?; + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?; Ok(()) } -#[test] -fn quantize_q3k() -> Result<()> { - use k_quants::BlockQ3K; +fn quantize_q3k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } + let dtype = GgmlDType::Q3K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; - let (src, mut dst) = get_test_vector(0.5, 1024); - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; compare_with_error(dst.as_slice(), src.as_slice(), 0.03); // Test some specific values @@ -356,20 +435,30 @@ fn quantize_q3k() -> Result<()> { [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492] ); - let (src_big, mut dst_big) = get_test_vector(128.0, 1024); - let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?; + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?; Ok(()) } -#[test] -fn quantize_q4k() -> Result<()> { - use k_quants::BlockQ4K; +fn quantize_q4k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } + let dtype = GgmlDType::Q4K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; - let (src, mut dst) = get_test_vector(0.5, 1024); - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; compare_with_error(dst.as_slice(), src.as_slice(), 0.017); // Test some specific values @@ -383,21 +472,31 @@ fn quantize_q4k() -> Result<()> { [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] ); - let (src_big, mut dst_big) = get_test_vector(128.0, 1024); - let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q5k() -> Result<()> { - use k_quants::BlockQ5K; +fn quantize_q5k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } + let dtype = GgmlDType::Q5K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; - let (src, mut dst) = get_test_vector(0.5, 1024); - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; - compare_with_error(dst.as_slice(), src.as_slice(), 0.008); + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.009); // Test some specific values assert_eq!( @@ -410,21 +509,30 @@ fn quantize_q5k() -> Result<()> { [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] ); - let (src_big, mut dst_big) = get_test_vector(128.0, 1024); - let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; - compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5); + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5); + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q6k() -> Result<()> { - use k_quants::BlockQ6K; +fn quantize_q6k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } + let dtype = GgmlDType::Q6K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; - let (src, mut dst) = get_test_vector(0.5, 1024); - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; compare_with_error(dst.as_slice(), src.as_slice(), 0.008); // Test some specific values @@ -438,22 +546,31 @@ fn quantize_q6k() -> Result<()> { [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5] ); - let (src_big, mut dst_big) = get_test_vector(128.0, 1024); - let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; - compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0); + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0); + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q8k() -> Result<()> { - use k_quants::BlockQ8K; +fn quantize_q8k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } + let dtype = GgmlDType::Q8K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; - let (src, mut dst) = get_test_vector(0.5, 1024); - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; - compare_with_error(dst.as_slice(), src.as_slice(), 0.003); + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.008); // Test some specific values assert_eq!( @@ -466,15 +583,79 @@ fn quantize_q8k() -> Result<()> { [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499] ); - let (src_big, mut dst_big) = get_test_vector(128.0, 1024); - let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; - compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6); + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6); + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } +test_device!( + quantize_q4_0, + quantize_q4_0_cpu, + quantize_q4_0_cuda, + quantize_q4_0_metal +); +test_device!( + quantize_q4_1, + quantize_q4_1_cpu, + quantize_q4_1_cuda, + quantize_q4_1_metal +); +test_device!( + quantize_q5_0, + quantize_q5_0_cpu, + quantize_q5_0_cuda, + quantize_q5_0_metal +); +test_device!( + quantize_q5_1, + quantize_q5_1_cpu, + quantize_q5_1_cuda, + quantize_q5_1_metal +); +test_device!( + quantize_q2k, + quantize_q2k_cpu, + quantize_q2k_cuda, + quantize_q2k_metal +); +test_device!( + quantize_q3k, + quantize_q3k_cpu, + quantize_q3k_cuda, + quantize_q3k_metal +); +test_device!( + quantize_q4k, + quantize_q4k_cpu, + quantize_q4k_cuda, + quantize_q4k_metal +); +test_device!( + quantize_q5k, + 
quantize_q5k_cpu, + quantize_q5k_cuda, + quantize_q5k_metal +); +test_device!( + quantize_q6k, + quantize_q6k_cpu, + quantize_q6k_cuda, + quantize_q6k_metal +); +test_device!( + quantize_q8k, + quantize_q8k_cpu, + quantize_q8k_cuda, + quantize_q8k_metal +); + /// Very simple dot product implementation fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { a.iter().zip(b).map(|(a, b)| a * b).sum() @@ -591,6 +772,112 @@ fn get_random_tensors( Ok((lhs, rhs, mm)) } +#[macro_export] +macro_rules! quantized_matmul { + // TODO: Switch to generating the two last arguments automatically once concat_idents is + // stable. https://github.com/rust-lang/rust/issues/29599 + ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => { + fn $fn_name(device: &Device) -> Result<()> { + if device.is_cuda() { + // TODO Enable Cuda GGML sometime maybe. + return Ok(()); + } + test_matmul(device, (1, 3, 4, 256), $dtype)?; + Ok(()) + } + + test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal); + }; +} + +quantized_matmul!( + quantized_matmul_q4_0_bis, + quantized_matmul_q4_0_cpu, + quantized_matmul_q4_0_cuda, + quantized_matmul_q4_0_metal, + GgmlDType::Q4_0 +); +quantized_matmul!( + quantized_matmul_q4_1_bis, + quantized_matmul_q4_1_cpu, + quantized_matmul_q4_1_cuda, + quantized_matmul_q4_1_metal, + GgmlDType::Q4_1 +); +quantized_matmul!( + quantized_matmul_q5_0_bis, + quantized_matmul_q5_0_cpu, + quantized_matmul_q5_0_cuda, + quantized_matmul_q5_0_metal, + GgmlDType::Q5_0 +); +quantized_matmul!( + quantized_matmul_q5_1_bis, + quantized_matmul_q5_1_cpu, + quantized_matmul_q5_1_cuda, + quantized_matmul_q5_1_metal, + GgmlDType::Q5_1 +); +quantized_matmul!( + quantized_matmul_q8_0_bis, + quantized_matmul_q8_0_cpu, + quantized_matmul_q8_0_cuda, + quantized_matmul_q8_0_metal, + GgmlDType::Q8_0 +); +// Not implemented in Ggml +// quantized_matmul!( +// quantized_matmul_q8_1_bis, +// quantized_matmul_q8_1_cpu, +// quantized_matmul_q8_1_cuda, +// quantized_matmul_q8_1_metal, +// GgmlDType::Q8_1 +// ); +// TODO This is bugged (also bugged in GGML +quantized_matmul!( + quantized_matmul_q2k_bis, + quantized_matmul_q2k_cpu, + quantized_matmul_q2k_cuda, + quantized_matmul_q2k_metal, + GgmlDType::Q2K +); +quantized_matmul!( + quantized_matmul_q3k_bis, + quantized_matmul_q3k_cpu, + quantized_matmul_q3k_cuda, + quantized_matmul_q3k_metal, + GgmlDType::Q3K +); +quantized_matmul!( + quantized_matmul_q4k_bis, + quantized_matmul_q4k_cpu, + quantized_matmul_q4k_cuda, + quantized_matmul_q4k_metal, + GgmlDType::Q4K +); +quantized_matmul!( + quantized_matmul_q5k_bis, + quantized_matmul_q5k_cpu, + quantized_matmul_q5k_cuda, + quantized_matmul_q5k_metal, + GgmlDType::Q5K +); +quantized_matmul!( + quantized_matmul_q6k_bis, + quantized_matmul_q6k_cpu, + quantized_matmul_q6k_cuda, + quantized_matmul_q6k_metal, + GgmlDType::Q6K +); +// Not implemented on metal +// quantized_matmul!( +// quantized_matmul_q8k_bis, +// quantized_matmul_q8k_cpu, +// quantized_matmul_q8k_cuda, +// quantized_matmul_q8k_metal, +// GgmlDType::Q8K +// ); + #[test] fn quantized_matmul_q2k() -> Result<()> { use k_quants::BlockQ2K; @@ -603,7 +890,7 @@ fn quantized_matmul_q2k() -> Result<()> { let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); - let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; @@ 
-629,7 +916,7 @@ fn quantized_matmul_q3k() -> Result<()> { let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); - let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; @@ -655,7 +942,7 @@ fn quantized_matmul_q4k() -> Result<()> { let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); - let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; @@ -681,7 +968,7 @@ fn quantized_matmul_q5k() -> Result<()> { let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); - let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; @@ -708,7 +995,7 @@ fn quantized_matmul_q6k() -> Result<()> { let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); - let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; @@ -733,7 +1020,7 @@ fn quantized_matmul_q8k() -> Result<()> { let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); - let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs index a1051a8eaa..15e36476c9 100644 --- a/candle-examples/examples/blip/main.rs +++ b/candle-examples/examples/blip/main.rs @@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> { let config = blip::Config::image_captioning_large(); + let device = candle_examples::device(args.cpu)?; let (image_embeds, device, mut model) = if args.quantized { let device = Device::Cpu; let image = load_image(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); - let vb = quantized_blip::VarBuilder::from_gguf(model_file)?; + let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?; let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?; let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?; (image_embeds, device, Model::Q(model)) } else { - let device = candle_examples::device(args.cpu)?; let image = load_image(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 0ceb27af7e..9d42dcc822 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .extension() .map_or(false, |v| v == "safetensors"); let (model, config) = if is_gguf { - let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let (_vocab_size, dim) = vb 
.get_no_shape("model.embed_tokens.weight")? .shape() @@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { (config.seq_len, config.head_size() / 2), "rot.freq_cis_real", )? - .dequantize(&candle::Device::Cpu)?; + .dequantize(&device)?; let freq_cis_imag = vb .get( (config.seq_len, config.head_size() / 2), "rot.freq_cis_imag", )? - .dequantize(&candle::Device::Cpu)?; + .dequantize(&device)?; let fake_vb = candle_nn::VarBuilder::from_tensors( [ @@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .into_iter() .collect(), candle::DType::F32, - &candle::Device::Cpu, + &device, ); let cache = model::Cache::new(true, &config, fake_vb)?; let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 5ed5e5cb06..bad860989e 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -244,13 +244,14 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config = Config::config_7b_v0_1(args.use_flash_attn); + let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized { let filename = &filenames[0]; - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; let model = QMistral::new(&config, vb)?; - (Model::Quantized(model), Device::Cpu) + (Model::Quantized(model), device) } else { - let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 } else { diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 69eed84ff2..39f4fd581b 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -307,18 +307,21 @@ fn main() -> Result<()> { WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; - let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; + let device = candle_examples::device(args.cpu)?; + let model = if args.quantized { let config = config(); + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &filenames[0], + &device, + )?; let model = match args.model { WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; - (Model::Quantized(model), Device::Cpu) + Model::Quantized(model) } else { - let device = candle_examples::device(args.cpu)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; - let model = match args.model { + match args.model { WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { let config_filename = repo.get("config.json")?; let config = std::fs::read_to_string(config_filename)?; @@ -334,8 +337,7 @@ fn main() -> Result<()> { let config = config(); Model::MixFormer(MixFormer::new(&config, vb)?) 
} - }; - (model, device) + } }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs index 0ea2e0bd32..ed3f1030a1 100644 --- a/candle-examples/examples/quantized-t5/main.rs +++ b/candle-examples/examples/quantized-t5/main.rs @@ -132,7 +132,8 @@ impl T5ModelBuilder { } pub fn build_model(&self) -> Result { - let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?; + let device = Device::Cpu; + let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?; Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index bfc6de53e7..34c442330a 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -9,7 +9,7 @@ use std::io::Write; use tokenizers::Tokenizer; use candle::quantized::{ggml_file, gguf_file}; -use candle::{Device, Tensor}; +use candle::Tensor; use candle_transformers::generation::LogitsProcessor; use candle_examples::token_output_stream::TokenOutputStream; @@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> { let model_path = args.model()?; let mut file = std::fs::File::open(&model_path)?; let start = std::time::Instant::now(); + let device = candle_examples::device(false)?; let mut model = match model_path.extension().and_then(|v| v.to_str()) { Some("gguf") => { @@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> { for (_, tensor) in model.tensor_infos.iter() { let elem_count = tensor.shape.elem_count(); total_size_in_bytes += - elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size(); + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); } println!( "loaded {:?} tensors ({}) in {:.2}s", @@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> { &format_size(total_size_in_bytes), start.elapsed().as_secs_f32(), ); - ModelWeights::from_gguf(model, &mut file)? + ModelWeights::from_gguf(model, &mut file, &device)? } Some("ggml" | "bin") | Some(_) | None => { - let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let model = ggml_file::Content::read(&mut file, &device) + .map_err(|e| e.with_path(model_path))?; let mut total_size_in_bytes = 0; for (_, tensor) in model.tensors.iter() { let elem_count = tensor.shape().elem_count(); total_size_in_bytes += - elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size(); + elem_count * tensor.dtype().type_size() / tensor.dtype().block_size(); } println!( "loaded {:?} tensors ({}) in {:.2}s", @@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> { let start_prompt_processing = std::time::Instant::now(); let mut next_token = { - let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; let logits = model.forward(&input, 0)?; let logits = logits.squeeze(0)?; logits_processor.sample(&logits)? @@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> { let start_post_prompt = std::time::Instant::now(); let mut sampled = 0; for index in 0..to_sample { - let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?; + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; let logits = model.forward(&input, prompt_tokens.len() + index)?; let logits = logits.squeeze(0)?; let logits = if args.repeat_penalty == 1. 
{ diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs index 0f72b86251..b7f767b96a 100644 --- a/candle-examples/examples/replit-code/main.rs +++ b/candle-examples/examples/replit-code/main.rs @@ -236,16 +236,15 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; let config = Config::replit_code_v1_5_3b(); - let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; - let model = Model::Q(Q::new(&config, vb.pp("transformer"))?); - (model, Device::Cpu) + let model = if args.quantized { + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?; + Model::Q(Q::new(&config, vb.pp("transformer"))?) } else { - let device = candle_examples::device(args.cpu)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; - let model = Model::M(M::new(&config, vb.pp("transformer"))?); - (model, device) + Model::M(M::new(&config, vb.pp("transformer"))?) }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs index 0535aa702b..ccd924a40f 100644 --- a/candle-examples/examples/stable-lm/main.rs +++ b/candle-examples/examples/stable-lm/main.rs @@ -234,13 +234,14 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config = Config::stablelm_3b_4e1t(args.use_flash_attn); + let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized { let filename = &filenames[0]; - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; let model = QStableLM::new(&config, vb)?; (Model::Quantized(model), Device::Cpu) } else { - let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 } else { diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 5be81f2dd2..6ea3461364 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -557,8 +557,10 @@ fn main() -> Result<()> { println!("loaded mel: {:?}", mel.dims()); let mut model = if args.quantized { - let vb = - candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &weights_filename, + &device, + )?; Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) } else { let vb = diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c872dc6057..201af97eac 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -15,6 +15,7 @@ const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); const CONV: &str = include_str!("conv.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); +const QUANTIZED: &str = include_str!("quantized.metal"); /// Most kernels apply similarly across the tensors /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the @@ -62,6 +63,8 @@ macro_rules! 
primitive { }; } primitive!(usize); +primitive!(i64); +primitive!(i32); primitive!(u32); primitive!(f32); @@ -117,6 +120,7 @@ pub enum Source { Reduce, Mfa, Conv, + Quantized, } macro_rules! ops{ @@ -215,17 +219,15 @@ type Pipelines = HashMap<(&'static str, Option), ComputePipeline pub struct Kernels { libraries: RwLock, pipelines: RwLock, - fence: metal::Fence, } impl Kernels { - pub fn new(fence: metal::Fence) -> Self { + pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); let pipelines = RwLock::new(Pipelines::new()); Self { libraries, pipelines, - fence, } } @@ -239,6 +241,7 @@ impl Kernels { Source::Cast => CAST, Source::Reduce => REDUCE, Source::Conv => CONV, + Source::Quantized => QUANTIZED, Source::Mfa => panic!("Invalid lib"), } } @@ -345,7 +348,6 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); @@ -354,7 +356,6 @@ pub fn call_unary_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -376,7 +377,6 @@ pub fn call_unary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -398,7 +398,6 @@ pub fn call_unary_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -417,7 +416,6 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); @@ -428,7 +426,6 @@ pub fn call_binary_contiguous( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -453,7 +450,6 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -478,7 +474,6 @@ pub fn call_binary_strided( encoder.use_resource(right_input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -497,7 +492,6 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); 
encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); @@ -506,7 +500,6 @@ pub fn call_cast_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -526,7 +519,6 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -548,7 +540,6 @@ pub fn call_cast_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -568,7 +559,6 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -597,7 +587,6 @@ pub fn call_reduce_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -619,7 +608,6 @@ pub fn call_reduce_strided( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -655,7 +643,6 @@ pub fn call_reduce_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -674,7 +661,6 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -705,7 +691,6 @@ pub fn call_last_softmax( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -725,7 +710,6 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); @@ -734,7 +718,6 @@ pub fn call_affine( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -757,7 +740,6 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = 
command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -778,7 +760,6 @@ pub fn call_affine_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -797,7 +778,6 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -806,7 +786,6 @@ pub fn call_powf( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -828,7 +807,6 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -848,7 +826,6 @@ pub fn call_powf_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -867,7 +844,6 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -876,7 +852,6 @@ pub fn call_elu( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -898,7 +873,6 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -918,7 +892,6 @@ pub fn call_elu_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -940,7 +913,6 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -969,7 +941,6 @@ pub fn call_where_cond_strided( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -996,7 +967,6 @@ pub fn call_index_select( let encoder = command_buffer.new_compute_command_encoder(); - 
encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1019,7 +989,6 @@ pub fn call_index_select( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1048,7 +1017,6 @@ pub fn call_gather( let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1071,7 +1039,6 @@ pub fn call_gather( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1100,7 +1067,6 @@ pub fn call_scatter_add( let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1123,7 +1089,6 @@ pub fn call_scatter_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1153,7 +1118,6 @@ pub fn call_index_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1177,7 +1141,6 @@ pub fn call_index_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1381,7 +1344,6 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); @@ -1421,12 +1383,10 @@ pub fn call_gemm( height: 1, depth: 1, }; - // println!("grid size {grid_size:?} group size {group_size:?}"); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1451,7 +1411,6 @@ pub fn call_im2col1d_strided( let encoder = command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1471,7 +1430,6 @@ pub fn call_im2col1d_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1501,7 +1459,6 @@ pub fn call_im2col_strided( let encoder = 
command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1523,7 +1480,6 @@ pub fn call_im2col_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1549,7 +1505,6 @@ pub fn call_upsample_nearest_2d( let scale_h = shape[3] as f32 / out_h as f32; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1567,7 +1522,176 @@ pub fn call_upsample_nearest_2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + +#[derive(Debug, Clone, Copy)] +pub enum GgmlDType { + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, + Q8K, + F16, + F32, +} + +pub fn call_quantized_matmul_t( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + dtype: GgmlDType, + (b, m, n, k): (usize, usize, usize, usize), + lhs: &Buffer, + lhs_offset: usize, + rhs: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + // Everything is in reverse + let ne00 = k as i64; + let ne01 = n as i64; + let ne02 = b as i64; + let ne03 = 1 as i64; + + let nb00 = 0i64; + let nb01 = 0 as i64; + let nb02 = 0 as i64; + + let ne10 = k as i64; + let ne11 = m as i64; + let ne12 = b as i64; + let ne13 = 1 as i64; + + let nb10 = 0i64; + let nb11 = 0i64; + let nb12 = 0i64; + + let ne0 = n as i64; + let ne1 = m as i64; + let r2: u32 = (ne12 / ne02) as u32; + let r3: u32 = (ne13 / ne03) as u32; + + let (nth0, nth1, align) = match dtype { + GgmlDType::Q4_0 + | GgmlDType::Q4_1 + | GgmlDType::Q5_0 + | GgmlDType::Q5_1 + | GgmlDType::Q8_0 + | GgmlDType::Q8_1 => { + let nth0 = 8; + let nth1 = 8; + let align = 8; + (nth0, nth1, align) + } + GgmlDType::Q2K => { + // Fixing a bug in Metal for GGML + let nth0 = 4; + let nth1 = 8; + let align = 4; + (nth0, nth1, align) + } + GgmlDType::Q4K => { + let nth0 = 4; + let nth1 = 8; + let align = 4; + (nth0, nth1, align) + } + GgmlDType::Q3K | GgmlDType::Q5K => { + let nth0 = 2; + let nth1 = 32; + let align = 4; + (nth0, nth1, align) + } + GgmlDType::Q6K => { + let nth0 = 2; + let nth1 = 32; + let align = 2; + (nth0, nth1, align) + } + GgmlDType::F16 | GgmlDType::Q8K => { + // Original implem uses rows + let nth0 = 32; + let nth1 = 1; + let align = 8; + (nth0, nth1, align) + } + GgmlDType::F32 => { + let nth0 = 32; + let nth1 = 1; + let align = 8; + (nth0, nth1, align) + } + }; + let thread_groups_count = MTLSize { + width: divide(ne01 as usize, align), + height: ne11 as u64, + depth: (ne12 * ne13) as u64, + }; + let threads_per_threadgroup = MTLSize { + width: nth0, + height: nth1, + depth: 1, + }; + let name = match dtype { + GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32", + GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32", + GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32", + GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32", + GgmlDType::Q8_0 => 
"kernel_mul_mv_q8_0_f32", + GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32", + GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32", + GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32", + GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32", + GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32", + GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32", + GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", + GgmlDType::F16 => "kernel_mul_mv_f16_f32", + GgmlDType::F32 => "kernel_mul_mv_f32_f32", + }; + + let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + rhs, + (lhs, lhs_offset), + output, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3 + ) + ); + encoder.set_threadgroup_memory_length(0, 8192); + encoder.use_resource(lhs, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + + encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal new file mode 100644 index 0000000000..9aa7b502a9 --- /dev/null +++ b/candle-metal-kernels/src/quantized.metal @@ -0,0 +1,5107 @@ +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } + +#define QK4_0 32 +#define QR4_0 2 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; + +#define QK4_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; + +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + +#define QK8_0 32 +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +enum ggml_sort_order { + GGML_SORT_ASC, + GGML_SORT_DESC, +}; + +// general-purpose kernel for addition, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast across all dims +// cons: not very efficient +kernel void kernel_add( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = 
tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_mul( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_div( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device 
float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_mul_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! 
+ // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +kernel void kernel_soft_max( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +kernel void kernel_soft_max_4( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + // MEAN + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float mean = sum[0] / ne00; + + // recenter and VARIANCE + threadgroup_barrier(mem_flags::mem_threadgroup); + device float * y = dst + tgpig*ne00; + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; + sum[tpitg] += y[i00] * y[i00]; + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float variance = sum[0] / ne00; + + const float scale = 1.0f/sqrt(variance + eps); + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + 
uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = all_sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + all_sum = buf[tiisg]; + all_sum = simd_sum(all_sum); + } + + const float mean = all_sum/ne00; + const float scale = 1.0f/sqrt(mean + eps); + + device float4 * y = (device float4 *) (dst + tgpig*ne00); + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int32_t & n_groups, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = ne00*ne01*ne02; + const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 
0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// guard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. 
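+//
+// For reference, a scalar sketch of what block_q_n_dot_y computes for a q4_0 half-block
+// (illustrative only; not called by any kernel below. The function name and the flat y
+// indexing are chosen here for clarity, and it assumes the block_q4_0 layout declared
+// earlier in this file: one half scale `d` followed by 16 packed nibbles in `qs`).
+// The vectorised helper above keeps the nibbles unshifted and instead pre-multiplies the
+// cached y values by 1, 1/256, 1/16 and 1/4096, so masking a packed uint16_t with
+// 0x000F, 0x0F00, 0x00F0 and 0xF000 already yields q*y terms, and the q4_0 offset of -8
+// is folded in once as `sumy * -8.f`.
+inline float block_q4_0_half_dot_ref(device const block_q4_0 * qb,
+                                     thread const float * y, // the 32 y values of this block
+                                     int il) {               // 0 or QK4_0/4, as above
+    const float d = qb->d;
+    float acc  = 0.f;
+    float sumy = 0.f;
+    for (int j = 0; j < 8; ++j) {
+        // byte il+j holds quant il+j in its low nibble and quant il+j+16 in its high nibble
+        const uint8_t b  = qb->qs[il + j];
+        const float   y0 = y[il + j];
+        const float   y1 = y[il + 16 + j];
+        acc  += y0 * (float)(b & 0x0F) + y1 * (float)(b >> 4);
+        sumy += y0 + y1;
+    }
+    return d * (acc - 8.f * sumy); // q4_0 dequantizes as d * (q - 8)
+}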
+template +void mul_vec_q_n_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, uint tiisg, uint sgitg) { + const int nb = ne00/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; + + device const float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + } + + yb += QK4_0 * 16; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant 
uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + + +#define NB_Q8_0 8 + +void kernel_mul_mv_q8_0_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = ne00/QK8_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[NB_Q8_0]; + float sumf[nr]={0.f}; + + const int ix = tiisg/4; + const int il = tiisg%4; + + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + float sumq = 0.f; + for (int iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*x[ib+row*nb].d; + } + + yb += NB_Q8_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + 
constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + +#define N_F32_F32 4 + +void kernel_mul_mv_f32_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F32_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const float * x = (device const float *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const float4 * x4 = (device const float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_f32_f32")]] +kernel void kernel_mul_mv_f32_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F16 4 + +kernel void kernel_mul_mv_f16_f16( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, 
+ constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F16; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + device const half4 * y4 = (device const half4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +void kernel_mul_mv_f16_f32_1row_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_f16_f32_1row")]] +kernel void kernel_mul_mv_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + 
constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F32 4 + +void kernel_mul_mv_f16_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_f16_f32")]] +kernel void kernel_mul_mv_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mv_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + 
constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half4 * x4 = (device const half4 *) (src0 + offset0); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +kernel void kernel_alibi_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & m0, + constant float & m1, + constant int & n_heads_log2_floor, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + const int64_t k = i3*ne3 + i2; + + float m_k; + if (k < n_heads_log2_floor) { + m_k = pow(m0, k + 1); + } else { + m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; + device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + const float src_v = *(device float *)(src_row + i00*nb00); + device float * dst_v = (device float *)(dst_row + i00*nb0); + *dst_v = i00 * m_k + src_v; + } +} + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
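+//
+// Rough summary of the math in rope_yarn below (derived from the implementation; see the
+// code for the exact behaviour). For the channel pair starting at index i0 it produces the
+// angle actually used for rotation, together with a magnitude correction:
+//
+//   theta_interp = freq_scale * theta_extrap
+//   ramp         = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor
+//   theta        = theta_interp * (1 - ramp) + theta_extrap * ramp
+//   mscale      *= 1 + 0.1 * log(1 / freq_scale)        (only applied when ext_factor != 0)
+//   cos_theta    = cos(theta) * mscale,  sin_theta = sin(theta) * mscale
+//
+// Pairs above corr_dims[1] use the interpolated angle (positions scaled by freq_scale),
+// pairs below corr_dims[0] are pushed towards the unscaled, extrapolated angle by
+// ext_factor, and the ones in between are blended, which preserves high-frequency
+// positional detail while stretching the low-frequency dimensions over a longer context.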
+static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template +kernel void kernel_rope( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + const bool is_neox = mode & 2; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, 
beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; + + const float theta_0 = (float)p; + const float inv_ndims = -1.f/n_dims; + + if (!is_neox) { + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + + const float theta = theta_0 * pow(freq_base, inv_ndims*i0); + float cos_theta, sin_theta; + rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const T x0 = src[0]; + const T x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } else { + for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { + if (ic < n_dims) { + const int64_t ib = 0; + + // simplified from `(ib * n_dims + ic) * inv_ndims` + const float cur_rot = inv_ndims*ic - ib; + + const float theta = theta_0 * pow(freq_base, cur_rot); + float cos_theta, sin_theta; + rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const int64_t i0 = ib*n_dims + ic/2; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } + } +} + +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + +kernel void kernel_im2col_f16( + device const float * x, + device half * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; + + const int32_t offset_dst = + (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + 
constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & sf, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1/sf; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = src0_ptr[i0/sf]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols) return; + + device const float * x_row = x + row * ncols; + device int32_t * dst_row = dst + row * ncols; + + // initialize indices + if (col < ncols) { + dst_row[col] = col; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (order == GGML_SORT_ASC ? 
x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant float & slope, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; +} + +kernel void kernel_cpy_f16_f16( + device const half * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + device const half * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + device const float * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + 
constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_q8_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + + device 
block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + +kernel void kernel_cpy_f32_q4_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 
1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q4_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i02 < ne02) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; + src0_ptr += ntg.x*nb00; + } else { + ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; + src1_ptr += ntg.x*nb10; + } + dst_ptr += ntg.x*nb0; + } +} + +//============================================ k-quants ====================================================== + +#ifndef QK_K +#define QK_K 256 +#else +static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); +#endif + +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif + +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_K; +// 84 bytes / block + +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#if QK_K == 64 + uint8_t scales[2]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; + +#if QK_K == 64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_K; +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +#endif + +#if QK_K == 64 +typedef struct { + half d; // super-block scales/mins + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +#else +typedef 
struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +// 176 bytes / block +#endif + +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; +// 210 bytes / block + +//====================================== dot products ========================= + +void kernel_mul_mv_q2_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q2_K) * nb; + +#if QK_K == 256 + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; + } +#else + const int ix = tiisg/2; // 
0...15 + const int it = tiisg%2; // 0...1 + + device const float * y4 = y + ix * QK_K + 8 * it; + + for (int ib = ix; ib < nb; ib += 16) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 16 * QK_K; + } +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device 
const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int ip = tid/4; // 0 or 1 + const int il = 2*((tid%4)/2); // 0 or 2 + const int ir = tid%2; + const int n = 8; + const int l0 = n*ir; + + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; + + const int shift = 2*il; + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + + const int step = sizeof(block_q3_K) * nb / 2; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[2] = {0.f}; + float sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 4) { + + for (int l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 
0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += step; + h += step; + a += step; + dh += step; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + if (tiisg == 0) { + for (int row = 0; row < 2; ++row) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + } + } +} +#else +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const int row = 2 * r0 + sgitg; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 + const int iq = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + float2 sum = {0.f, 0.f}; + + for (int i = ix; i < nb; i += 8) { + + const float d_all = (float)(x[i].d); + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const uint16_t * s = (device const uint16_t *)(x[i].scales); + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); + const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; + const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; + const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; + + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> iq; + sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) + + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) + + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) + + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) + + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 
0 : 65536)); + } + + } + const float sumf = sum[0] + sum[1] * 1.f/256.f; + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } + +} +#endif + +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] 
+= yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#else +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int ix = tiisg/4; // 0...7 + const int it = tiisg%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[8]; + float yh[8]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 8 * it; + + uint16_t sc16[4]; + + for (int ib = ix; ib < nb; ib += 8) { + + float2 sumy = {0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i] = y4[i+ 0]; sumy[0] += yl[i]; + yh[i] = y4[i+32]; sumy[1] += yh[i]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & 0x000f; + sc16[1] = sc[0] & 0x0f00; + sc16[2] = sc[0] & 0x00f0; + sc16[3] = sc[0] & 0xf000; + + float2 acc1 = {0.f, 0.f}; + float2 acc2 = {0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); + acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); + acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); + acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + + (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - + dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); + + qs += step; + sc += step; + dh += step; + } + + y4 += 8 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#endif + +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + 
constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q5_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf[2]={0.f}; + + const int step = sizeof(block_q5_K) * nb; + +#if QK_K == 256 +# + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int iq = tid/4; + const int ir = tid%4; + const int n = 8; + + const int l0 = n*ir; + const int q_offset = 32*iq + l0; + const int y_offset = 64*iq + l0; + + const uint8_t hm1 = 1u << (2*iq); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? 
yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; + + } + + y1 += 4 * QK_K; + + } +#else + float yl[8], yh[8]; + + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 + const int ix = tiisg%8; + const int iq = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + device const float * y = yy + ix*QK_K + il; + + for (int i = ix; i < nb; i += 8) { + + for (int l = 0; l < 4; ++l) { + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; + } + + device const half * dh = &x[i].d; + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const int8_t * s = x[i].scales; + + for (int row = 0; row < 2; ++row) { + + const float d = dh[0]; + + float2 acc = {0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> iq; + acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); + acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) + + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); + } + sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); + + q += step; + h += step; + s += step; + dh += step/2; + + } + + y += 8 * QK_K; + } +#endif + + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q6_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int row = 2 * r0 + sgitg; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q6_K * x = (device 
const block_q6_K *) src0 + row * nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + +#if QK_K == 256 + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; + const int n = 4; + const int l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += 2) { + + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = x[i].qh + q_offset_h; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + y_offset; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + +#else + const int ix = tiisg/4; + const int il = 4*(tiisg%4); + + for (int i = ix; i < nb; i += 8) { + device const float * y = yy + i * QK_K + il; + device const uint8_t * ql = x[i].ql + il; + device const uint8_t * qh = x[i].qh + il; + device const int8_t * s = x[i].scales; + + const float d = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); + } + +#endif + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} + +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +//============================= templates and their specializations ============================= + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device 
half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + +#if QK_K == 256 + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; +#endif + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); + const half ml = 4.h * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +#else + float kcoef = il&1 ? 1.f/16.f : 1.f; + uint16_t kmask = il&1 ? 0xF0 : 0x0F; + float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); + float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + uint8_t m = 1<<(il*2); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); + } +#endif +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + +#if QK_K == 256 + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; +#else + q = q + 16 * (il&1); + device const uint8_t * s = xb->scales; + device const half2 * dh = (device const half2 *)xb->d; + const float2 d = (float2)dh[0]; + const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; + const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); +#endif + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + +#if QK_K == 256 + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? 
qh_val : 0)) - ml; + } +#else + q = q + 16 * (il&1); + device const int8_t * s = xb->scales; + const float dl = xb->d * s[il]; + uint8_t m = 1<<(il*2); + const float coef = il<2 ? 1.f : 1.f/16.f; + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); + } +#endif +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + half sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; +#else + ql = ql + 16 * (il&1); + half sc = scales[il]; +#endif + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const half coef = il>1 ? 1.f/16.h : 1.h; + const half ml = d_all * sc * 32.h; + const half dl = d_all * sc * coef; + for (int i = 0; i < 16; ++i) { + const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } +} + +template +kernel void kernel_get_rows( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + //const int64_t i = tgpig; + //const int64_t r = ((device int32_t *) src1)[i]; + + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + float4x4 temp; + dequantize_func( + ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +kernel void kernel_get_rows_f32( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant 
uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +void kernel_mul_mm_impl(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg == 0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids +template +void kernel_mul_mm_id_impl( + device const uchar * src0, + device const uchar * src1, + thread short * src1ids, + device float * dst, + constant int64_t & ne00, + constant 
int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + int64_t ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + + if (r1 * BLOCK_SIZE_N >= ne1) return; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? 
x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + { + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; + if (sgitg == 0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +template +kernel void kernel_mul_mm(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mm_impl( + src0, + src1, + dst, + ne00, + ne02, + nb01, + nb02, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + shared_memory, + tgpig, + tiitg, + sgitg); +} + +template +kernel void kernel_mul_mm_id( + device const uchar * ids, + device const uchar * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const uchar * src00, + device const uchar * src01, + device const uchar * src02, + device const uchar * src03, + device const uchar * src04, + device const uchar * src05, + device const uchar * src06, + device const uchar * src07, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + // expert id + const int32_t id = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + // row indices of src1 for expert id + int64_t _ne1 = 0; + short src1ids[512]; + + for (int64_t i1 = 0; i1 < ne1; i1++) 
{ + if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { + src1ids[_ne1++] = i1; + } + } + + kernel_mul_mm_id_impl( + src0s[id], + src1, + src1ids, + dst, + ne00, + ne02, + nb01, + nb02, + ne12, + nb10, + nb11, + nb12, + ne0, + _ne1, + r2, + r3, + shared_memory, + tgpig, + tiitg, + sgitg); +} + +#if QK_K == 256 +#define QK_NL 16 +#else +#define QK_NL 4 +#endif + +// +// get rows +// + +typedef void (get_rows_t)( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3, uint, uint3); + +//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; + +// +// matrix-matrix multiplication +// + +typedef void (mat_mm_t)( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar *, + uint3, uint, uint); + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; + +// +// indirect matrix-matrix multiplication +// + +typedef void (mat_mm_id_t)( + device const uchar * ids, + device const uchar * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + 
constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const uchar * src00, + device const uchar * src01, + device const uchar * src02, + device const uchar * src03, + device const uchar * src04, + device const uchar * src05, + device const uchar * src06, + device const uchar * src07, + threadgroup uchar *, + uint3, uint, uint); + +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + +// +// matrix-vector multiplication +// + +[[host_name("kernel_mul_mv_id_f32_f32")]] +kernel void kernel_mul_mv_id_f32_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f32_f32_impl( + src0[id], + src1 + bid*nb11, + dst + bid*ne0, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_f16_f32")]] +kernel void kernel_mul_mv_id_f16_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & 
nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f16_f32_impl( + src0[id], + src1 + bid*nb11, + dst + bid*ne0, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_q8_0_f32")]] +kernel void kernel_mul_mv_id_q8_0_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q8_0_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_0_f32")]] +kernel void kernel_mul_mv_id_q4_0_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + 
device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_1_f32")]] +kernel void kernel_mul_mv_id_q4_1_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_0_f32")]] +kernel void kernel_mul_mv_id_q5_0_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid 
= tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_1_f32")]] +kernel void kernel_mul_mv_id_q5_1_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q2_K_f32")]] +kernel void kernel_mul_mv_id_q2_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q2_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q3_K_f32")]] +kernel void kernel_mul_mv_id_q3_K_f32( + device const char * ids, + device const char * src1, 
+ device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q3_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_K_f32")]] +kernel void kernel_mul_mv_id_q4_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q4_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_K_f32")]] +kernel void kernel_mul_mv_id_q5_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & 
ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q5_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q6_K_f32")]] +kernel void kernel_mul_mv_id_q6_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q6_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 87f8ac45e6..787a7d45ed 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -37,8 +37,7 @@ fn approx_bf16(v: Vec, digits: i32) -> Vec { fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -60,8 +59,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -96,8 +94,7 @@ fn 
run_strided( let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); let output = new_buffer(&device, v); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); call_unary_strided( &device, command_buffer, @@ -278,8 +275,7 @@ fn binary_ops_bf16() { fn cast(v: &[T], name: &'static str) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -409,8 +405,7 @@ fn it_cast_f16_bf16() { fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -445,8 +440,7 @@ fn run_affine_strided( add: f64, ) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -595,8 +589,7 @@ fn run_index_select( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); call_index_select( &device, &command_buffer, @@ -631,8 +624,7 @@ fn cos_f16() { fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -662,8 +654,7 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec(v: &[T], last_dim: usize, name: &'static str) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -782,8 +773,7 @@ fn run_where_cond( name: &'static str, ) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -859,8 +849,7 @@ fn run_gemm( rhs_offset: usize, ) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index dcf803d870..7add58fd3f 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -117,7 +117,6 @@ UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) UNARY_OP(relu) - UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -136,6 +135,7 @@ BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) 
BFLOAT_UNARY_OP(log) BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(abs) BFLOAT_UNARY_OP(ceil) BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 68d384a6df..001be11681 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -222,7 +222,10 @@ impl Benchmark for QMatMul { type RunResult = Tensor; fn preprocess() -> Result { let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32]; - let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?; + let mm = candle::quantized::QTensor::new( + candle::quantized::QStorage::Cpu(Box::new(zeros)), + (4096, 11008), + )?; let mm = candle::quantized::QMatMul::from_qtensor(mm)?; let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?; Ok((mm, arg)) diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index 4ee51c290b..c9a9f9f3c1 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -33,7 +33,9 @@ def has_mkl() -> bool: pass @staticmethod -def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: +def load_ggml( + path: Union[str, PathLike], device: Optional[Device] = None +) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: """ Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. @@ -41,7 +43,9 @@ def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, pass @staticmethod -def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: +def load_gguf( + path: Union[str, PathLike], device: Optional[Device] = None +) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: """ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, and the second maps metadata keys to metadata values. 
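The stub change above mirrors the device-aware loading path that the candle-pyo3 hunks below wire up on the Rust side. As a rough illustration of the post-patch API, here is a minimal Rust sketch of reading one quantized tensor from a GGUF file onto a chosen device and re-quantizing a float tensor through GgmlDType. The helper name, file path and shapes are placeholders; the calls themselves follow the signatures shown elsewhere in this patch (gguf_file::Content::read, Content::tensor(.., &Device), QTensor::quantize(.., GgmlDType), QMatMul::from_qtensor).

use candle_core::quantized::{gguf_file, GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Result, Tensor};

// Hypothetical helper: read a single tensor from a GGUF file directly onto `device`.
fn load_qtensor(path: &str, name: &str, device: &Device) -> Result<QTensor> {
    let mut file = std::fs::File::open(path)?;
    // Parse the header and tensor metadata; tensor data is read per tensor on demand.
    let content = gguf_file::Content::read(&mut file)?;
    // Post-patch, `tensor` takes the target device and materializes the storage there.
    content.tensor(&mut file, name, device)
}

fn demo(device: &Device) -> Result<Tensor> {
    // "token_embd.weight" is the name used by the quantized_llama loader below;
    // "model.gguf" is a placeholder path.
    let qtensor = load_qtensor("model.gguf", "token_embd.weight", device)?;
    // Dequantize back to f32 on the same device, e.g. for an embedding lookup.
    let full = qtensor.dequantize(device)?;

    // Quantization now selects the format at runtime via `GgmlDType`
    // instead of a compile-time block-type parameter.
    let requantized = QTensor::quantize(&full, GgmlDType::Q4_0)?;
    let _mm = QMatMul::from_qtensor(requantized)?;

    Ok(full)
}

On a Metal device the same calls should produce the new Metal-backed quantized storage introduced in candle-core/src/quantized/metal.rs, while the CPU path keeps the block-based QStorage::Cpu storage that the cpu_benchmarks and candle-wasm-tests hunks construct explicitly.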
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 90826b98b0..ca40687607 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1074,20 +1074,20 @@ impl PyTensor { fn quantize(&self, quantized_dtype: &str) -> PyResult { use ::candle::quantized; let res = match quantized_dtype.to_lowercase().as_str() { - "q2k" => quantized::QTensor::quantize::(self), - "q3k" => quantized::QTensor::quantize::(self), - "q4_0" => quantized::QTensor::quantize::(self), - "q4_1" => quantized::QTensor::quantize::(self), - "q4k" => quantized::QTensor::quantize::(self), - "q5_0" => quantized::QTensor::quantize::(self), - "q5_1" => quantized::QTensor::quantize::(self), - "q5k" => quantized::QTensor::quantize::(self), - "q6k" => quantized::QTensor::quantize::(self), - "q8_0" => quantized::QTensor::quantize::(self), - "q8_1" => quantized::QTensor::quantize::(self), - "q8k" => quantized::QTensor::quantize::(self), - "f16" => quantized::QTensor::quantize::(self), - "f32" => quantized::QTensor::quantize::(self), + "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K), + "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K), + "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0), + "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1), + "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K), + "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0), + "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1), + "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K), + "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K), + "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0), + "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1), + "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K), + "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16), + "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32), dt => { return Err(PyErr::new::(format!( "unknown quantized-dtype {dt}" @@ -1278,13 +1278,19 @@ fn save_safetensors( } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike])")] +#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] -fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { +fn load_ggml( + path: &str, + device: Option, + py: Python<'_>, +) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; - let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let ggml = + ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?; let tensors = ggml .tensors .into_iter() @@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike])")] +#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] /// Loads a GGUF file. 
Returns a tuple of two dictionaries: the first maps tensor names to tensors, /// and the second maps metadata keys to metadata values. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]] -fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { +fn load_gguf( + path: &str, + device: Option, + py: Python<'_>, +) -> PyResult<(PyObject, PyObject)> { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; use ::candle::quantized::gguf_file; fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult { let v: PyObject = match v { @@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { .tensor_infos .keys() .map(|key| { - let qtensor = gguf.tensor(&mut file, key)?; + let qtensor = gguf.tensor(&mut file, key, &device)?; Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))) }) .collect::<::candle::Result>>() diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 1fb2d9e26f..8aa0608822 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -356,6 +356,7 @@ impl ModelWeights { pub fn from_gguf( ct: gguf_file::Content, reader: &mut R, + device: &Device, ) -> Result { let cpu = &Device::Cpu; let md_get = |s: &str| match ct.metadata.get(s) { @@ -383,21 +384,28 @@ impl ModelWeights { .unwrap_or(10000f32); let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; - let output = ct.tensor(reader, "output.weight")?; + let norm = RmsNorm::new( + ct.tensor(reader, "output_norm.weight", device)?, + rms_norm_eps, + )?; + let output = ct.tensor(reader, "output.weight", device)?; let mut layers = Vec::with_capacity(block_count); for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = + ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; let mlp_or_moe = if n_expert <= 1 { - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; MlpOrMoe::Mlp(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -405,15 +413,15 @@ impl ModelWeights { }) } else { let feed_forward_gate_inp = - 
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?; let mut experts = Vec::with_capacity(n_expert); for i in 0..n_expert { let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; experts.push(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -426,8 +434,9 @@ impl ModelWeights { experts, } }; - let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let attention_norm = + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index 1a3cd4ac61..882f4cf8fa 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -311,7 +311,7 @@ impl MixFormerSequentialForCausalLM { let mut blocks = Vec::new(); for i in 0..cfg.n_layer { let block = ParallelBlock::new(cfg, vb.pp(i + 1))?; - blocks.push(block) + blocks.push(block); } let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?; Ok(Self { @@ -332,7 +332,7 @@ impl MixFormerSequentialForCausalLM { Some(get_mask(seq_len, xs.device())?) }; for block in self.blocks.iter_mut() { - xs = block.forward(&xs, mask.as_ref())? 
+ xs = block.forward(&xs, mask.as_ref())?; } xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1) } diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 63101f4cdb..bfd0629f22 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -10,33 +10,33 @@ pub struct VarBuilder { } impl VarBuilder { - pub fn from_gguf>(p: P) -> Result { + pub fn from_gguf>(p: P, device: &Device) -> Result { let mut file = std::fs::File::open(p)?; let content = candle::quantized::gguf_file::Content::read(&mut file)?; let mut data = std::collections::HashMap::new(); for tensor_name in content.tensor_infos.keys() { - let tensor = content.tensor(&mut file, tensor_name)?; + let tensor = content.tensor(&mut file, tensor_name, device)?; data.insert(tensor_name.to_string(), Arc::new(tensor)); } Ok(Self { data: Arc::new(data), path: Vec::new(), - device: Device::Cpu, + device: device.clone(), }) } - pub fn from_gguf_buffer(buffer: &[u8]) -> Result { + pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result { let mut cursor = std::io::Cursor::new(buffer); let content = candle::quantized::gguf_file::Content::read(&mut cursor)?; let mut data = std::collections::HashMap::new(); for tensor_name in content.tensor_infos.keys() { - let tensor = content.tensor(&mut cursor, tensor_name)?; + let tensor = content.tensor(&mut cursor, tensor_name, device)?; data.insert(tensor_name.to_string(), Arc::new(tensor)); } Ok(Self { data: Arc::new(data), path: Vec::new(), - device: Device::Cpu, + device: device.clone(), }) } diff --git a/candle-wasm-examples/blip/src/bin/m.rs b/candle-wasm-examples/blip/src/bin/m.rs index 660bb71743..e2ba4fed48 100644 --- a/candle-wasm-examples/blip/src/bin/m.rs +++ b/candle-wasm-examples/blip/src/bin/m.rs @@ -61,7 +61,7 @@ impl Model { let start = Date::now(); let model: SelectedModel = if quantized { - let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?; + let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?; let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?; SelectedModel::Q(model) } else { diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs index 999f276df7..859e58cbbc 100644 --- a/candle-wasm-examples/phi/src/bin/m.rs +++ b/candle-wasm-examples/phi/src/bin/m.rs @@ -41,6 +41,7 @@ impl Model { ) -> Result { console_error_panic_hook::set_once(); console_log!("loading model"); + let device = Device::Cpu; let name: ModelName = serde_json::from_slice(&config)?; let config: Config = serde_json::from_slice(&config)?; @@ -50,8 +51,9 @@ impl Model { let start = Date::now(); console_log!("weights len: {:?}", weights.len()); let model = if quantized { - let vb = - candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer( + &weights, &device, + )?; console_log!("weights loaded"); if name._name_or_path == "microsoft/phi-2" { let model = QMixFormer::new_v2(&config, vb)?; diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs index 2f490b84d2..3b99a27510 100644 --- a/candle-wasm-examples/t5/src/bin/m-quantized.rs +++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs @@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{ use candle_wasm_example_t5::console_log; use tokenizers::Tokenizer; use 
wasm_bindgen::prelude::*; +const DEVICE: Device = Device::Cpu; #[wasm_bindgen] pub struct ModelEncoder { @@ -31,7 +32,7 @@ impl ModelConditionalGeneration { ) -> Result { console_error_panic_hook::set_once(); console_log!("loading model"); - let vb = VarBuilder::from_gguf_buffer(&weights)?; + let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?; let mut config: Config = serde_json::from_slice(&config)?; let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; @@ -46,7 +47,7 @@ impl ModelConditionalGeneration { pub fn decode(&mut self, input: JsValue) -> Result { let input: ConditionalGenerationParams = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; - let device = &Device::Cpu; + let device = &DEVICE; self.model.clear_kv_cache(); let mut output_token_ids = [self.config.pad_token_id as u32].to_vec(); let prompt = input.prompt; @@ -128,7 +129,7 @@ impl ModelEncoder { ) -> Result { console_error_panic_hook::set_once(); console_log!("loading model"); - let vb = VarBuilder::from_gguf_buffer(&weights)?; + let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?; let mut config: Config = serde_json::from_slice(&config)?; config.use_cache = false; let tokenizer = @@ -138,7 +139,7 @@ impl ModelEncoder { } pub fn decode(&mut self, input: JsValue) -> Result { - let device = &Device::Cpu; + let device = &DEVICE; let input: DecoderParams = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index fd91fa8c11..898996a77b 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -315,6 +315,7 @@ impl Decoder { let model = if md.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer( &md.weights, + &device, )?; Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) } else { diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index e5fa7dec23..fc107e611d 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> { ] ); - let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; + let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; assert_eq!(