diff --git a/Dockerfile b/Dockerfile index 6c5988e..4736d5b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ FROM nvidia/cuda:11.7.0-devel-ubuntu22.04 AS builder RUN apt-get update && apt-get install -y python3 python3-pip git -RUN pip3 install --upgrade pip +RUN pip3 install --upgrade pip # Some of the requirements expect some python packages in their setup.py, just install them first. RUN --mount=type=cache,target=/root/.cache/pip pip install --user torch==2.0.0 @@ -15,7 +15,9 @@ RUN --mount=type=cache,target=/root/.cache/pip pip install --user semantic-versi # The docker build environment has trouble detecting CUDA version, build for all reasonable archs ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" COPY requirements.txt requirements.txt -RUN --mount=type=cache,target=/root/.cache pip install --user -r requirements.txt +COPY setup.py setup.py +COPY src src +RUN --mount=type=cache,target=/root/.cache pip install --user . # ------------------------------- @@ -61,14 +63,14 @@ RUN cd text-generation-webui-tmp && python download-model.py --text-only decapod # Get LoRA RUN cd text-generation-webui-tmp && python download-model.py samwit/alpaca7b-lora && mv loras/samwit_alpaca7b-lora ../alpaca7b_lora -COPY *.py . +COPY src src COPY text-generation-webui text-generation-webui -COPY monkeypatch text-generation-webui/monkeypatch +COPY src/alpaca_lora_4bit/monkeypatch text-generation-webui/monkeypatch RUN mv -f text-generation-webui-tmp/* text-generation-webui/ # Symlink for monkeypatch -RUN cd text-generation-webui && ln -s ../autograd_4bit.py ./autograd_4bit.py && ln -s ../matmul_utils_4bit.py . +RUN cd text-generation-webui && ln -s ../src/alpaca_lora_4bit/autograd_4bit.py ./autograd_4bit.py && ln -s ../src/alpaca_lora_4bit/matmul_utils_4bit.py . && ln -s ../src/alpaca_lora_4bit/models.py . 
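For context: the build now installs the repository as a Python package (`pip install --user .`) rather than only its requirements, so downstream scripts import the relocated modules by package path instead of symlinking loose files. A minimal sketch of that usage, assuming the package is installed or `src/` is on `sys.path` as finetune.py and inference.py now arrange; illustrative only, not part of the patch:

```python
# Illustrative sketch, not part of the diff: package-path imports used by the
# updated scripts once `pip install --user .` (or the sys.path fallback) is in place.
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
    replace_peft_model_with_int4_lora_model,
)
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram

# Patch PEFT before any LoRA model is built, as inference.py does.
replace_peft_model_with_int4_lora_model()

# Paths taken from inference.py; the loader is assumed to return (model, tokenizer).
model, tokenizer = load_llama_model_4bit_low_ram('./llama-13b-4bit/', './llama-13b-4bit.pt')
```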
# Swap to the 7bn parameter model RUN sed -i 's/llama-13b-4bit/llama-7b-4bit/g' text-generation-webui/custom_monkey_patch.py && sed -i 's/alpaca13b_lora/alpaca7b_lora/g' text-generation-webui/custom_monkey_patch.py diff --git a/finetune.py b/finetune.py index 142692b..1253e86 100644 --- a/finetune.py +++ b/finetune.py @@ -16,21 +16,28 @@ } ] """ +import os +import sys +# set src so alpaca_lora_4bit package is available without installing +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +src_dir = os.path.join(project_root, "src") +sys.path.insert(0, src_dir) + # Early load config to replace attn if needed -from arg_parser import get_config +from alpaca_lora_4bit.arg_parser import get_config ft_config = get_config() -from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model -replace_peft_model_with_gptq_lora_model() +from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model +replace_peft_model_with_int4_lora_model() if ft_config.flash_attention: - from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + from alpaca_lora_4bit.monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() elif ft_config.xformers: - from monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention + from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention hijack_llama_attention() -import autograd_4bit +from alpaca_lora_4bit import autograd_4bit if ft_config.backend.lower() == 'triton': autograd_4bit.switch_backend_to('triton') else: @@ -44,11 +51,11 @@ import torch import transformers -from autograd_4bit import load_llama_model_4bit_low_ram +from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict # ! 
Config -import train_data +from alpaca_lora_4bit import train_data # * Show loaded parameters if ft_config.local_rank == 0: @@ -92,8 +99,8 @@ # Scales to half print('Fitting 4bit scales and zeros to half') for n, m in model.named_modules(): - if '4bit' in str(type(m)): - if m.is_v1_model: + if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)): + if hasattr(m, "is_v1_model") and m.is_v1_model: m.zeros = m.zeros.half() m.scales = m.scales.half() @@ -120,7 +127,7 @@ # Use gradient checkpointing if ft_config.gradient_checkpointing: print('Applying gradient checkpointing ...') - from gradient_checkpointing import apply_gradient_checkpointing + from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio) # Disable Trainer's DataParallel for multigpu diff --git a/inference.py b/inference.py index 9e290cd..3a32929 100644 --- a/inference.py +++ b/inference.py @@ -1,10 +1,15 @@ import os import sys +# set src so alpaca_lora_4bit package is available without installing +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +src_dir = os.path.join(project_root, "src") +sys.path.insert(0, src_dir) + import time import torch -from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear -from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model -replace_peft_model_with_gptq_lora_model() +from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear +from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model +replace_peft_model_with_int4_lora_model() config_path = './llama-13b-4bit/' model_path = './llama-13b-4bit.pt' diff --git a/requirements.txt b/requirements.txt index e536fe5..aafe3c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,5 @@ sentencepiece safetensors einops colorama -git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2 -git+https://github.com/huggingface/transformers.git -git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit +peft @ git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2 +transformers @ git+https://github.com/huggingface/transformers.git diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..793b566 --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +import sys +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +install_requires = [] +with open("./requirements.txt", "r") as requirements_file: + reqs = [r.strip() for r in requirements_file.readlines()] + for r in reqs: + install_requires.append(r) + +quant_cuda_module = CUDAExtension( + 'alpaca_lora_4bit.quant_cuda', + sources=[ + 'src/alpaca_lora_4bit/quant_cuda/quant_cuda.cpp', + 'src/alpaca_lora_4bit/quant_cuda/quant_cuda_kernel.cu' + ]) + +setup( + name='alpaca_lora_4bit', + version='0.1', + description='Alpaca LoRA 4-bit', + package_dir={'alpaca_lora_4bit': 'src/alpaca_lora_4bit'}, + packages=['alpaca_lora_4bit', 'alpaca_lora_4bit.monkeypatch', 'alpaca_lora_4bit.quant_cuda'], + install_requires=install_requires, + extras_require={ + 'triton': 'triton', + }, + ext_modules=[quant_cuda_module], + cmdclass={'build_ext': BuildExtension}, +) diff --git a/Finetune4bConfig.py b/src/alpaca_lora_4bit/Finetune4bConfig.py similarity index 100% rename from Finetune4bConfig.py rename to 
src/alpaca_lora_4bit/Finetune4bConfig.py diff --git a/src/alpaca_lora_4bit/__init__.py b/src/alpaca_lora_4bit/__init__.py new file mode 100644 index 0000000..5ffa53a --- /dev/null +++ b/src/alpaca_lora_4bit/__init__.py @@ -0,0 +1,12 @@ +from . import monkeypatch +from . import amp_wrapper +from . import arg_parser +from . import autograd_4bit +from . import custom_autotune +from . import Finetune4bConfig +from . import gradient_checkpointing +from . import models +from . import train_data +# We don't import these automatically as it is dependent on whether we need cuda or triton +# from . import matmul_utils_4bit +# from . import triton_utils diff --git a/amp_wrapper.py b/src/alpaca_lora_4bit/amp_wrapper.py similarity index 100% rename from amp_wrapper.py rename to src/alpaca_lora_4bit/amp_wrapper.py diff --git a/arg_parser.py b/src/alpaca_lora_4bit/arg_parser.py similarity index 93% rename from arg_parser.py rename to src/alpaca_lora_4bit/arg_parser.py index 8ab0fe1..0bddb74 100644 --- a/arg_parser.py +++ b/src/alpaca_lora_4bit/arg_parser.py @@ -1,6 +1,6 @@ import os import argparse -from Finetune4bConfig import Finetune4bConfig +from .Finetune4bConfig import Finetune4bConfig def parse_commandline(): parser = argparse.ArgumentParser( @@ -8,12 +8,12 @@ def parse_commandline(): description="Produce LoRA in 4bit training", usage="%(prog)s [config] [training]\n\nAll arguments are optional" ) - + parser.add_argument("dataset", nargs="?", - default="./dataset.json", + default="./dataset.json", help="Path to dataset file. Default: %(default)s" ) - + parser_config = parser.add_argument_group("config") parser_training = parser.add_argument_group("training") @@ -60,14 +60,14 @@ def parse_commandline(): # Data args parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.") parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use eos token instead if padding with 0. 
enable with 1, disable with 0.") - + # V2 model support parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model") parser_training.add_argument("--v1", action="store_true", help="Use V1 model") # Multi GPU Support parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch") - + # Flash Attention parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use") parser_training.add_argument("--xformers", action="store_true", help="enables xformers memory efficient attention, can improve performance and reduce VRAM use") @@ -81,20 +81,20 @@ def parse_commandline(): def get_config() -> Finetune4bConfig: args = parse_commandline() return Finetune4bConfig( - dataset=args["dataset"], - ds_type=args["ds_type"], - lora_out_dir=args["lora_out_dir"], + dataset=args["dataset"], + ds_type=args["ds_type"], + lora_out_dir=args["lora_out_dir"], lora_apply_dir=args["lora_apply_dir"], resume_checkpoint=args["resume_checkpoint"], llama_q4_config_dir=args["llama_q4_config_dir"], llama_q4_model=args["llama_q4_model"], mbatch_size=args["mbatch_size"], batch_size=args["batch_size"], - epochs=args["epochs"], + epochs=args["epochs"], lr=args["lr"], cutoff_len=args["cutoff_len"], - lora_r=args["lora_r"], - lora_alpha=args["lora_alpha"], + lora_r=args["lora_r"], + lora_alpha=args["lora_alpha"], lora_dropout=args["lora_dropout"], val_set_size=args["val_set_size"], gradient_checkpointing=args["grad_chckpt"], diff --git a/autograd_4bit.py b/src/alpaca_lora_4bit/autograd_4bit.py similarity index 78% rename from autograd_4bit.py rename to src/alpaca_lora_4bit/autograd_4bit.py index 5a47921..72085a9 100644 --- a/autograd_4bit.py +++ b/src/alpaca_lora_4bit/autograd_4bit.py @@ -1,40 +1,67 @@ -import matmul_utils_4bit as mm4b +import logging + import torch import torch.nn as nn import time import math from torch.cuda.amp import custom_bwd, custom_fwd from colorama import init, Fore, Back, Style +from huggingface_hub.utils._validators import HFValidationError init(autoreset=True) -class AutogradMatmul4bitCuda(torch.autograd.Function): +gptq_backend_loaded = False +triton_backend_loaded = False + +class AutogradMatmul4bitNotImplemented(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq): - ctx.save_for_backward(qweight, scales, zeros, g_idx) - if g_idx is None: - output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros) - else: - output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx) - output = output.clone() - return output + raise NotImplementedError() @staticmethod @custom_bwd def backward(ctx, grad_output): - qweight, scales, zeros, g_idx = ctx.saved_tensors - if ctx.needs_input_grad[0]: + raise NotImplementedError() + + +try: + from . 
import matmul_utils_4bit as mm4b + + class AutogradMatmul4bitCuda(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq): + ctx.save_for_backward(qweight, scales, zeros, g_idx) if g_idx is None: - grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) + output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros) else: - grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, g_idx, transpose=True) - return grad, None, None, None, None, None, None + output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx) + output = output.clone() + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + qweight, scales, zeros, g_idx = ctx.saved_tensors + if ctx.needs_input_grad[0]: + if g_idx is None: + grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) + else: + grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, g_idx, transpose=True) + return grad, None, None, None, None, None, None + + + gptq_backend_loaded = True +except ImportError: + print('quant_cuda not found. Please run "pip install alpaca_lora_4bit[cuda]".') try: - import triton_utils as tu + from . import triton_utils as tu + class AutogradMatmul4bitTriton(torch.autograd.Function): @@ -46,7 +73,7 @@ def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq): ctx.bits, ctx.maxq = bits, maxq output = output.clone() return output - + @staticmethod @custom_bwd def backward(ctx, grad_output): @@ -57,25 +84,45 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[0]: grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) return grad_input, None, None, None, None, None, None - + + + triton_backend_loaded = True except ImportError: print('Triton not found. Please run "pip install triton".') -AutogradMatmul4bit = AutogradMatmul4bitCuda -backend = 'cuda' +def is_triton_backend_available(): + return 'AutogradMatmul4bitTriton' in globals() + + +def is_gptq_backend_available(): + return 'AutogradMatmul4bitCuda' in globals() + + +AutogradMatmul4bit = AutogradMatmul4bitNotImplemented +backend = None +if is_gptq_backend_available(): + AutogradMatmul4bit = AutogradMatmul4bitCuda + backend = 'cuda' +elif is_triton_backend_available(): + AutogradMatmul4bit = AutogradMatmul4bitTriton + backend = 'triton' +else: + logging.warning("Neither gptq/cuda or triton backends are available.") def switch_backend_to(to_backend): global AutogradMatmul4bit global backend if to_backend == 'cuda': + if not is_gptq_backend_available(): + raise ValueError('quant_cuda not found. Please reinstall with pip install .') AutogradMatmul4bit = AutogradMatmul4bitCuda backend = 'cuda' print(Style.BRIGHT + Fore.GREEN + 'Using CUDA implementation.') elif to_backend == 'triton': # detect if AutogradMatmul4bitTriton is defined - if 'AutogradMatmul4bitTriton' not in globals(): + if not is_triton_backend_available(): raise ValueError('Triton not found. 
Please install triton') AutogradMatmul4bit = AutogradMatmul4bitTriton backend = 'triton' @@ -211,7 +258,10 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa if half: model_to_half(model) - tokenizer = LlamaTokenizer.from_pretrained(config_path) + try: + tokenizer = LlamaTokenizer.from_pretrained(config_path) + except HFValidationError as e: + tokenizer = LlamaTokenizer.from_pretrained(model) tokenizer.truncation_side = 'left' print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") @@ -248,7 +298,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path if lora_path is not None: from peft import PeftModel - from monkeypatch.peft_tuners_lora_monkey_patch import Linear4bitLt + from .models import Linear4bitLt model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path)) diff --git a/custom_autotune.py b/src/alpaca_lora_4bit/custom_autotune.py similarity index 100% rename from custom_autotune.py rename to src/alpaca_lora_4bit/custom_autotune.py diff --git a/gradient_checkpointing.py b/src/alpaca_lora_4bit/gradient_checkpointing.py similarity index 100% rename from gradient_checkpointing.py rename to src/alpaca_lora_4bit/gradient_checkpointing.py diff --git a/matmul_utils_4bit.py b/src/alpaca_lora_4bit/matmul_utils_4bit.py similarity index 93% rename from matmul_utils_4bit.py rename to src/alpaca_lora_4bit/matmul_utils_4bit.py index 2aaa0ad..0cd9605 100644 --- a/matmul_utils_4bit.py +++ b/src/alpaca_lora_4bit/matmul_utils_4bit.py @@ -1,7 +1,12 @@ +import logging + import torch import numpy as np -from gptq_llama import quant_cuda - +try: + from alpaca_lora_4bit import quant_cuda +except (ImportError, ModuleNotFoundError) as e: + logging.exception("Please run: `pip install alpaca_lora_4bit[cuda]`") + raise e # Global Buffer buffer_mat_dic = {} diff --git a/monkeypatch/peft_tuners_lora_monkey_patch.py b/src/alpaca_lora_4bit/models.py similarity index 78% rename from monkeypatch/peft_tuners_lora_monkey_patch.py rename to src/alpaca_lora_4bit/models.py index 4e6b677..1c3cea9 100644 --- a/monkeypatch/peft_tuners_lora_monkey_patch.py +++ b/src/alpaca_lora_4bit/models.py @@ -1,4 +1,3 @@ -import math import re import torch import warnings @@ -6,17 +5,16 @@ from peft.tuners import lora from peft.tuners.lora import is_bnb_available, Linear, Linear8bitLt, LoraLayer -from peft.utils import _get_submodules, PeftType -from torch import nn +from peft.utils import _get_submodules from transformers.pytorch_utils import Conv1D -from autograd_4bit import Autograd4bitQuantLinear +from alpaca_lora_4bit.autograd_4bit import Autograd4bitQuantLinear class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer): - # Lora implemented in a dense layer - def __init__( + # Lora implemented in a dense layer + def __init__( self, adapter_name, in_features, @@ -27,62 +25,62 @@ def __init__( lora_alpha: int = 1, lora_dropout: float = 0.0, **kwargs, - ): - Autograd4bitQuantLinear.__init__( - self, - in_features, - out_features, - groupsize, - is_v1_model - ) - LoraLayer.__init__(self, in_features=in_features, out_features=out_features) + ): + Autograd4bitQuantLinear.__init__( + self, + in_features, + out_features, + groupsize, + is_v1_model + ) + LoraLayer.__init__(self, in_features=in_features, out_features=out_features) + + # Freezing the pre-trained weight matrix + self.qweight.requires_grad = False + self.scales.requires_grad 
= False + if self.is_v1_model: + self.zeros.requires_grad = False + else: + self.qzeros.requires_grad = False + self.g_idx.requires_grad = False + self.bias.requires_grad = False - # Freezing the pre-trained weight matrix - self.qweight.requires_grad = False - self.scales.requires_grad = False - if self.is_v1_model: - self.zeros.requires_grad = False - else: - self.qzeros.requires_grad = False - self.g_idx.requires_grad = False - self.bias.requires_grad = False - - init_lora_weights = kwargs.pop("init_lora_weights", True) - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name - - def forward(self, x: torch.Tensor): - result = super().forward(x) - - if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): - return result - elif self.r[self.active_adapter] > 0: - if not torch.is_autocast_enabled(): - expected_dtype = result.dtype - - if x.dtype != torch.float32: - x = x.float() - output = ( + init_lora_weights = kwargs.pop("init_lora_weights", True) + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) + self.active_adapter = adapter_name + + def forward(self, x: torch.Tensor): + result = super().forward(x) + + if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): + return result + elif self.r[self.active_adapter] > 0: + if not torch.is_autocast_enabled(): + expected_dtype = result.dtype + + if x.dtype != torch.float32: + x = x.float() + output = ( self.lora_B[self.active_adapter]( self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) ).to(expected_dtype) * self.scaling[self.active_adapter] - ) - else: - output = ( + ) + else: + output = ( self.lora_B[self.active_adapter]( self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) ) * self.scaling[self.active_adapter] - ) - result += output - return result - - @property - def weight(self): - class WeightDeviceClass: - device = self.qweight.device - return WeightDeviceClass() + ) + result += output + return result + + @property + def weight(self): + class WeightDeviceClass: + device = self.qweight.device + return WeightDeviceClass() class GPTQLoraModel(lora.LoraModel): @@ -200,8 +198,3 @@ def _replace_module(self, parent_module, child_name, new_module, old_module): for name, module in new_module.named_modules(): if "lora_" in name: module.to(old_module.weight.device) - - -def replace_peft_model_with_gptq_lora_model(): - import peft.peft_model - peft.peft_model.PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel diff --git a/monkeypatch/__init__.py b/src/alpaca_lora_4bit/monkeypatch/__init__.py similarity index 100% rename from monkeypatch/__init__.py rename to src/alpaca_lora_4bit/monkeypatch/__init__.py diff --git a/monkeypatch/llama_attn_hijack_xformers.py b/src/alpaca_lora_4bit/monkeypatch/llama_attn_hijack_xformers.py similarity index 100% rename from monkeypatch/llama_attn_hijack_xformers.py rename to src/alpaca_lora_4bit/monkeypatch/llama_attn_hijack_xformers.py diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/src/alpaca_lora_4bit/monkeypatch/llama_flash_attn_monkey_patch.py similarity index 100% rename from monkeypatch/llama_flash_attn_monkey_patch.py rename to src/alpaca_lora_4bit/monkeypatch/llama_flash_attn_monkey_patch.py diff --git a/src/alpaca_lora_4bit/monkeypatch/peft_tuners_lora_monkey_patch.py b/src/alpaca_lora_4bit/monkeypatch/peft_tuners_lora_monkey_patch.py new file mode 100644 index 0000000..31c61ae --- /dev/null +++ 
b/src/alpaca_lora_4bit/monkeypatch/peft_tuners_lora_monkey_patch.py @@ -0,0 +1,5 @@ +def replace_peft_model_with_int4_lora_model(): + import peft.peft_model + from peft import PeftType + from ..models import GPTQLoraModel + peft.peft_model.PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel diff --git a/src/alpaca_lora_4bit/quant_cuda/quant_cuda.cpp b/src/alpaca_lora_4bit/quant_cuda/quant_cuda.cpp new file mode 100644 index 0000000..7a23b50 --- /dev/null +++ b/src/alpaca_lora_4bit/quant_cuda/quant_cuda.cpp @@ -0,0 +1,163 @@ +#include <torch/all.h> +#include <torch/python.h> +#include <c10/cuda/CUDAGuard.h> + +void vecquant2matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +); + +void vecquant2matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_cuda(vec, mat, mul, scales, zeros,groupsize); +} + +void vecquant3matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +); + +void vecquant3matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant3matmul_cuda(vec, mat, mul, scales, zeros, groupsize); +} + +void vecquant4matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +); + +void vecquant4matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, groupsize); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, groupsize); +} + +void vecquant2matmul_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize, int vec_height +); + +void vecquant2matmul_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize, int vec_height +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_faster_cuda(vec, mat, mul, scales, zeros, groupsize, vec_height); +} + +void vecquant3matmul_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize, int vec_height +); + +void vecquant3matmul_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + int groupsize, int vec_height +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros, groupsize, vec_height); +} + +void vecquant4matmul_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx, int vec_height +); + +void vecquant4matmul_faster( + torch::Tensor
vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx, int vec_height +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_faster_cuda(vec, mat, mul, scales, zeros, g_idx, vec_height); +} + +void vecquant4recons_v1_cuda( + torch::Tensor mat, torch::Tensor res, torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4recons_v1( + torch::Tensor mat, torch::Tensor res, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(scales)); + vecquant4recons_v1_cuda(mat, res, scales, zeros); +} + +void vecquant4recons_v2_cuda( + torch::Tensor mat, torch::Tensor res, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant4recons_v2( + torch::Tensor mat, torch::Tensor res, torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(scales)); + vecquant4recons_v2_cuda(mat, res, scales, zeros, g_idx); +} + +void vecquant4matmul_v1_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_v1_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_v1_faster_cuda(vec, mat, mul, scales, zeros); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)"); + m.def("vecquant2matmul_faster", &vecquant2matmul_faster, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version"); + m.def("vecquant3matmul_faster", &vecquant3matmul_faster, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version"); + m.def("vecquant4matmul_faster", &vecquant4matmul_faster, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version"); + + // V1 Support for vecquant4matmul_faster + m.def("vecquant4matmul_v1_faster", &vecquant4matmul_v1_faster, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version, v1 support"); + + // Reconstruction Kernel + m.def("vecquant4recons_v1", &vecquant4recons_v1, "Vector 4-bit Quantized Matrix Reconstruction (CUDA)"); + m.def("vecquant4recons_v2", &vecquant4recons_v2, "Vector 4-bit Quantized Matrix Reconstruction (CUDA) with group-size support"); +} \ No newline at end of file diff --git a/src/alpaca_lora_4bit/quant_cuda/quant_cuda_kernel.cu b/src/alpaca_lora_4bit/quant_cuda/quant_cuda_kernel.cu new file mode 100644 index 0000000..4f8396b --- /dev/null +++ b/src/alpaca_lora_4bit/quant_cuda/quant_cuda_kernel.cu @@ -0,0 +1,1118 @@ +#include +#include +#include +#include +#include + +// atomicAdd for double-precision floating-point numbers on hardware with +// compute capability < 6.0 from: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 +__device__ double atomicAdd( + double* address, + double val +) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = 
*address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS( + address_as_ull, + assumed, + __double_as_longlong(val + __longlong_as_double(assumed)) + ); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant2MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant3MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant4MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +const int BLOCKWIDTH = 256; +const int BLOCKHEIGHT2 = 16; +const int BLOCKHEIGHT3 = 24; +const int BLOCKHEIGHT4 = 32; +const int BLOCKHEIGHT8 = 64; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +void vecquant2matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + int groupsize +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant2matmul_cuda", ([&] { + VecQuant2MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, vec_height, height, width, zero_width, groupsize + ); + }) + ); +} + 
+template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +) { + int b = blockIdx.z; + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + scalar_t res = 0; + int i = width * h + w; + int g_h = h * 16; + int k = 0; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + unsigned int tmp; + + while (k < BLOCKWIDTH) { + tmp = as_unsigned(mat[i]); + + int g = (g_h + k) / groupsize; + scalar_t scale = scales[g * width + w]; + scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + + res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9]; + res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10]; + res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11]; + res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12]; + res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13]; + res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14]; + res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15]; + + i += width; + k += 16; + } + + atomicAdd(&mul[b * width + w], res); +} + +void vecquant3matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + int groupsize +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant3matmul_cuda", ([&] { + VecQuant3MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, vec_height, height, width, zero_width, groupsize + ); + }) + ); +} + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +) { + int b = blockIdx.z; + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + scalar_t res = 0; + int i = width * h + w; + int g_h = (h / 
3) * 32; + int k = 0; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + unsigned int tmp1; + unsigned int tmp2; + unsigned int tmp; + unsigned int z_tmp; + + while (k < BLOCKWIDTH) { + tmp1 = as_unsigned(mat[i]); + + int g = (g_h + k) / groupsize; + scalar_t scale = scales[g * width + w]; + scalar_t zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = scale * scalar_t((z_tmp) + 1); + } else if (z_mod == 21){ + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = scale * scalar_t((z_tmp) + 1); + } else { + zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); + } + + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; + + i += width; + tmp2 = as_unsigned(mat[i]); + tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); + tmp2 >>= 1; + res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; + k += 11; + + res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; + + i += width; + tmp1 = as_unsigned(mat[i]); + tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); + tmp1 >>= 2; + res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; + k += 11; + + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * 
blockvec[k + 8]; + res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; + + i += width; + k += 10; + } + + atomicAdd(&mul[b * width + w], res); +} + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + int groupsize +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, vec_height, height, width, zero_width, groupsize + ); + }) + ); +} + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +) { + int b = blockIdx.z; + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + scalar_t res = 0; + int i = width * h + w; + int g_h = h * 8; + int k = 0; + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + unsigned int tmp; + + while (k < BLOCKWIDTH) { + tmp = as_unsigned(mat[i]); + + int g = (g_h + k) / groupsize; + scalar_t scale = scales[g * width + w]; + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + + res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7]; + + i += width; + k += 8; + } + + atomicAdd(&mul[b * width + w], res); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + int groupsize +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, vec_height, height, width, zero_width, groupsize + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +) { + int b = blockIdx.z; + int h = 
BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + scalar_t res = 0; + int i = width * h + w; + int g_h = h * 4; + int k = 0; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + unsigned int tmp; + + while (k < BLOCKWIDTH) { + tmp = as_unsigned(mat[i]); + + int g = (g_h + k) / groupsize; + scalar_t scale = scales[g * width + w]; + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3]; + + i += width; + k += 4; + } + + atomicAdd(&mul[b * width + w], res); +} + + +void vecquant2matmul_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + int groupsize, + int vec_height +) { + int batch = vec.size(0); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + VecQuant2MatMulKernelFaster<<>>( + (half2*) vec.data_ptr(), + mat.data_ptr(), + mul.data_ptr(), + scales.data_ptr(), + zeros.data_ptr(), + batch, vec_height, height, width, zero_width, groupsize + ); +} + +__global__ void VecQuant2MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +) { + const int blockwidth2 = BLOCKWIDTH / 2; + int b = blockIdx.z; + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ half2 blockvec[blockwidth2]; + if (threadIdx.x < blockwidth2) + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x]; + + __shared__ half2 deq2[16][16]; + int val = threadIdx.x / 16; + int off = threadIdx.x % 16; + for (; val < 16; val += BLOCKWIDTH / 16) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0x3), __int2half_rn(val >> 2) + ); + } + + int i = width * h + w; + int g_h = h * 16; + int k = 0; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + float res = 0; + half2 res2; + + unsigned int tmp; + + __syncthreads(); + + while (k < blockwidth2) { + int g = (g_h + (k * 2)) / groupsize; + float scale_f = scales[g * width + w]; + half2 scale = __float2half2_rn(scale_f); + half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1))); + + res2 = {}; + tmp = as_unsigned(mat[i]); + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xf][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 12) & 0xf][off], scale, zero), blockvec[k + 3], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xf][off], scale, zero), blockvec[k + 4], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 20) & 0xf][off], scale, zero), blockvec[k + 5], res2); + res2 
= __hfma2(__hfma2(deq2[(tmp >> 24) & 0xf][off], scale, zero), blockvec[k + 6], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2); + i += width; + k += 8; + res += __half2float(res2.x) + __half2float(res2.y); + } + + atomicAdd(&mul[b * width + w], res); +} + +void vecquant3matmul_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + int groupsize, + int vec_height +) { + int batch = vec.size(0); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + VecQuant3MatMulKernelFaster<<>>( + (half2*) vec.data_ptr(), + mat.data_ptr(), + mul.data_ptr(), + scales.data_ptr(), + zeros.data_ptr(), + batch, vec_height, height, width, zero_width, groupsize + ); +} + +__global__ void VecQuant3MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +) { + const int blockwidth2 = BLOCKWIDTH / 2; + int b = blockIdx.z; + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ half2 blockvec[blockwidth2]; + if (threadIdx.x < blockwidth2) + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x]; + + __shared__ half2 deq2[64][32]; + int val = threadIdx.x / 32; + int off = threadIdx.x % 32; + for (; val < 64; val += BLOCKWIDTH / 32) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0x7), __int2half_rn(val >> 3) + ); + } + + int i = width * h + w; + int g_h = (h / 3) * 32; + int k = 0; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + float res = 0; + half2 res2; + + unsigned int tmp1; + unsigned int tmp2; + unsigned int tmp; + unsigned int z_tmp; + + __syncthreads(); + + while (k < blockwidth2) { + int g = (g_h + (k * 2)) / groupsize; + float scale_f = scales[g * width + w]; + half2 scale = __float2half2_rn(scale_f); + half2 zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1))); + } else if (z_mod == 21){ + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1))); + } else { + zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1))); + } + + res2 = {}; + tmp1 = as_unsigned(mat[i]); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), 
blockvec[k + 4], res2); + i += width; + tmp2 = as_unsigned(mat[i]); + tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c); + res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2); + tmp2 >>= 4; + k += 6; + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); + i += width; + tmp1 = as_unsigned(mat[i]); + tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30); + res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2); + tmp1 >>= 2; + k += 5; + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); + res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); + i += width; + k += 5; + res += __half2float(res2.x) + __half2float(res2.y); + } + + atomicAdd(&mul[b * width + w], res); +} + +void vecquant4matmul_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx, + int vec_height +) { + int batch = vec.size(0); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + VecQuant4MatMulKernelFaster<<>>( + (half2*) vec.data_ptr(), + mat.data_ptr(), + mul.data_ptr(), + scales.data_ptr(), + zeros.data_ptr(), + g_idx.data_ptr(), + batch, vec_height, height, width, zero_width + ); +} + +__global__ void VecQuant4MatMulKernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + const int blockwidth2 = BLOCKWIDTH / 2; + int b = blockIdx.z; + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ half2 blockvec[blockwidth2]; + if (threadIdx.x < blockwidth2) + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x]; + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCKWIDTH / 8) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0xF), __int2half_rn(val >> 4) + ); + } + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float res = 0; + half2 res2; + + unsigned int tmp; + + __syncthreads(); + + while (k < blockwidth2) { + int g = as_int(g_idx[g_h + (k * 2)]); + float scale_f = scales[g * width + w]; + half2 scale = __float2half2_rn(scale_f); + half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1))); + + res2 = {}; + tmp = as_unsigned(mat[i]); + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 
0xff][off], scale, zero), blockvec[k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scale, zero), blockvec[k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2); + i += width; + k += 4; + res += __half2float(res2.x) + __half2float(res2.y); + } + + atomicAdd(&mul[b * width + w], res); +} + +template +__global__ void VecQuant4ReconsV1Kernel( + const int* __restrict__ mat, + scalar_t* __restrict__ res, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int height, + int width +) { + int b = blockIdx.z; + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int n_rows = h * 8 + b; + int n_cols = w; + scalar_t scale = scales[w]; + scalar_t zero = zeros[w]; + int i = width * h + width * (b / 8) + w; + int shift = b % 8 * 4; + unsigned int tmp = as_unsigned(mat[i]); + scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero); + res[n_rows * width + n_cols] = result; +} + +void vecquant4recons_v1_cuda( + torch::Tensor mat, + torch::Tensor res, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = BLOCKWIDTH; + int height = mat.size(0); + int width = mat.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + scales.type(), "vecquant4recons_v1_cuda", ([&] { + VecQuant4ReconsV1Kernel<<>>( + mat.data(), res.data(), + scales.data(), zeros.data(), + height, width + ); + }) + ); +} + +template +__global__ void VecQuant4ReconsV2Kernel( + const int* __restrict__ mat, + scalar_t* __restrict__ res, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int height, + int width, + int zero_width +) { + int b = blockIdx.z; + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int n_rows = h * 8 + b; + int n_cols = w; + int z_rows = as_int(g_idx[n_rows]); + int z_cols = n_cols / 8; + int z_shift = (n_cols % 8) * 4; + scalar_t scale = scales[z_rows * width + n_cols]; + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[z_rows * zero_width + z_cols]) >> z_shift) & 0xF) + 1); + int i = width * h + width * (b / 8) + w; + int shift = b % 8 * 4; + unsigned int tmp = as_unsigned(mat[i]); + scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero); + res[n_rows * width + n_cols] = result; +} + +void vecquant4recons_v2_cuda( + torch::Tensor mat, + torch::Tensor res, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = BLOCKWIDTH; + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH, + batch + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + scales.type(), "vecquant4recons_v2_cuda", ([&] { + VecQuant4ReconsV2Kernel<<>>( + mat.data(), res.data(), + scales.data(), zeros.data(), + g_idx.data(), height, width, zero_width + ); + }) + ); +} + +__global__ void VecQuant4MatMulV1KernelFaster( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const float* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width +) { + const int blockwidth2 = BLOCKWIDTH / 2; + int b = blockIdx.z; + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + 
+
+__global__ void VecQuant4MatMulV1KernelFaster(
+    const half2* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ scales,
+    const float* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width
+) {
+  const int blockwidth2 = BLOCKWIDTH / 2;
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ half2 blockvec[blockwidth2];
+  if (threadIdx.x < blockwidth2)
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];
+
+  __shared__ half2 deq2[256][8];
+  int val = threadIdx.x / 8;
+  int off = threadIdx.x % 8;
+  for (; val < 256; val += BLOCKWIDTH / 8) {
+    deq2[val][off] = __halves2half2(
+      __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
+    );
+  }
+
+  int i = width * h + w;
+  int k = 0;
+
+  float res = 0;
+  half2 res2;
+
+  unsigned int tmp;
+
+  __syncthreads();
+
+  while (k < blockwidth2) {
+    float scale_f = scales[w];
+    float zero_f = zeros[w];
+    half2 scale = __float2half2_rn(scale_f);
+    half2 zero = __float2half2_rn(-zero_f);
+
+    res2 = {};
+    tmp = as_unsigned(mat[i]);
+    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scale, zero), blockvec[k + 2], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
+    i += width;
+    k += 4;
+    res += __half2float(res2.x) + __half2float(res2.y);
+  }
+
+  atomicAdd(&mul[b * width + w], res);
+}
+
+void vecquant4matmul_v1_faster_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1) / 2;
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant4MatMulV1KernelFaster<<<blocks, threads>>>(
+    (half2*) vec.data_ptr(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    scales.data_ptr<float>(),
+    zeros.data_ptr<float>(),
+    batch, vec_height, height, width
+  );
+}
\ No newline at end of file
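All of the *Faster kernels in this file lean on the same shared-memory trick: deq2 maps every possible byte of a packed word to the two 4-bit values it contains (as a half2), so the inner loop dequantizes and multiply-accumulates two weights per __hfma2. The small Python rendering below only makes the byte-to-nibbles ordering concrete; it is an illustration, not code from the repository.

# deq2[byte] = (low nibble, high nibble), the same table the kernels build in shared memory
deq2 = [(val & 0xF, val >> 4) for val in range(256)]

def unpack_word(word):
    """Split one packed 32-bit word into its 8 quantized 4-bit values, low byte first."""
    out = []
    for byte_idx in range(4):
        lo, hi = deq2[(word >> (8 * byte_idx)) & 0xFF]
        out.extend([lo, hi])
    return out

# nibble order matches the kernels: (tmp >> 0) & 0xff covers weights 0-1, (tmp >> 8) & 0xff weights 2-3, ...
assert unpack_word(0x76543210) == [0, 1, 2, 3, 4, 5, 6, 7]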
diff --git a/train_data.py b/src/alpaca_lora_4bit/train_data.py
similarity index 100%
rename from train_data.py
rename to src/alpaca_lora_4bit/train_data.py
diff --git a/triton_utils.py b/src/alpaca_lora_4bit/triton_utils.py
similarity index 96%
rename from triton_utils.py
rename to src/alpaca_lora_4bit/triton_utils.py
index 7afcf46..ea40f40 100644
--- a/triton_utils.py
+++ b/src/alpaca_lora_4bit/triton_utils.py
@@ -1,7 +1,7 @@
 import triton
 import triton.language as tl
 import torch
-import custom_autotune
+from . import custom_autotune
 
 # code based https://github.com/fpgaminer/GPTQ-triton
 
@@ -44,7 +44,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     C is of shape (M, N) float16
     scales is of shape (G, N) float16
     zeros is of shape (G, N) float16
-    g_ptr is of shape (K) int32 
+    g_ptr is of shape (K) int32
     """
 
     infearure_per_bits = 32 // bits
@@ -69,22 +69,22 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     g_ptrs = g_ptr + offs_k
     # shifter is used to extract the N bits of each element in the 32-bit word from B
     scales_ptrs = scales_ptr + offs_bn[None, :]
-    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) 
-    
+    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
+
     shifter = (offs_k % infearure_per_bits) * bits
     zeros_shifter = (offs_bn % infearure_per_bits) * bits
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    
+
     for k in range(0, num_pid_k):
         g_idx = tl.load(g_ptrs)
 
         # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
         scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
         zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
-        
+
         zeros = (zeros >> zeros_shifter[None, :]) & maxq
         zeros = (zeros + 1)
-        
+
         a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
         b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -104,7 +104,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
     c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
     tl.store(c_ptrs, c, mask=c_mask)
-    
+

 # code based https://github.com/fpgaminer/GPTQ-triton
 @custom_autotune.autotune(
@@ -146,7 +146,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     C is of shape (M, K) float16
     scales is of shape (G, N) float16
     zeros is of shape (G, N) float16
-    g_ptr is of shape (K) int32 
+    g_ptr is of shape (K) int32
     """
 
     infearure_per_bits = 32 // bits
@@ -170,23 +170,23 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
     g_ptrs = g_ptr + offs_bk
     g_idx = tl.load(g_ptrs)
-    
+
     # shifter is used to extract the N bits of each element in the 32-bit word from B
     scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
     zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
-    
+
     shifter = (offs_bk % infearure_per_bits) * bits
     zeros_shifter = (offs_n % infearure_per_bits) * bits
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
-    
+
     for k in range(0, num_pid_n):
         # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
         scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
         zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
-        
+
         zeros = (zeros >> zeros_shifter[None, :]) & maxq
         zeros = (zeros + 1)
-        
+
         a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
         b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -203,13 +203,13 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
         b_ptrs += BLOCK_SIZE_N
         scales_ptrs += BLOCK_SIZE_N
         zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
-    
+
     c = accumulator.to(tl.float16)
     c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
     c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
     tl.store(c_ptrs, c, mask=c_mask)
-    
-    
+
+
 def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
     assert input.shape[-1] == qweight.shape[0] * 32 // bits
     outshape = input.shape[:-1] + (qweight.shape[1],)
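With the move into src/, the Triton backend is reached through the package rather than as a top-level module. A hedged sketch of how a 4-bit linear forward would call it is below; triton_matmul and its argument order come from this file, while the layer attribute names (qweight, scales, qzeros, g_idx, bias) and maxq = 2**bits - 1 are the usual GPTQ conventions and are assumptions here.

# Hedged sketch: calling the Triton 4-bit matmul through the new package path.
# Requires a CUDA GPU with triton installed; the attribute names are assumed.
import torch
from alpaca_lora_4bit.triton_utils import triton_matmul

def quant_forward_triton(x, layer, bits=4):
    # x: (..., in_features) fp16; layer.qweight: (in_features * bits // 32, out_features) int32
    maxq = 2 ** bits - 1
    out = triton_matmul(x, layer.qweight, layer.scales, layer.qzeros, layer.g_idx, bits, maxq)
    if getattr(layer, "bias", None) is not None:
        out = out + layer.bias
    return out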
diff --git a/text-generation-webui/custom_monkey_patch.py b/text-generation-webui/custom_monkey_patch.py
index 28bb53e..32c03ff 100644
--- a/text-generation-webui/custom_monkey_patch.py
+++ b/text-generation-webui/custom_monkey_patch.py
@@ -3,8 +3,9 @@
 import autograd_4bit
 from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 from peft import PeftModel
-from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model, Linear4bitLt
-replace_peft_model_with_gptq_lora_model()
+from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+from models import Linear4bitLt
+replace_peft_model_with_int4_lora_model()
 
 patch_encode_func = False
 
@@ -16,12 +17,12 @@ def load_model_llama(*args, **kwargs):
     print("Loading {} ...".format(model_path))
     t0 = time.time()
-    
+
     model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, is_v1_model=True)
-    
+
     model = PeftModel.from_pretrained(model, lora_path, device_map={'': 0}, torch_dtype=torch.float32)
     print('{} Lora Applied.'.format(lora_path))
-    
+
     print('Apply auto switch and half')
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
@@ -31,13 +32,13 @@ def load_model_llama(*args, **kwargs):
             m.bias = m.bias.half()
     autograd_4bit.use_new = True
     autograd_4bit.auto_switch = True
-    
+
     return model, tokenizer
 
 # Monkey Patch
-from modules import models
+from modules import models as modelz
 from modules import shared
-models.load_model = load_model_llama
+modelz.load_model = load_model_llama
 shared.args.model = 'llama-13b-4bit'
 shared.settings['name1'] = 'You'
 shared.settings['name2'] = 'Assistant'
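Two small things are worth noting in this last hunk: Linear4bitLt now comes from the repository's own models module instead of the monkeypatch module, and modules.models is imported under the alias modelz, presumably to keep it distinct from that local models module. Activation is unchanged: importing custom_monkey_patch rebinds the webui's model loader before any model is loaded. A hedged sketch of that flow follows; the load_model signature mirrors the text-generation-webui of this era and is an assumption, not something this diff specifies.

# Hedged sketch: what consuming code effectively does once the patch module is imported.
import custom_monkey_patch  # noqa: F401  (import alone rebinds modules.models.load_model)

from modules import models, shared

# shared.args.model was set to 'llama-13b-4bit' by the patch; loading now routes
# through load_model_llama, which applies the LoRA and half()s the 4-bit layers.
model, tokenizer = models.load_model(shared.args.model)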