From 299f2f084ea8e9a92cf5368217208268bbd19fba Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Fri, 20 Sep 2024 04:45:02 -0700
Subject: [PATCH] support streaming

---
 .github/workflows/ruff.yml              |  1 +
 vptq/__main__.py                        | 45 ++++++++++++++++++-------
 vptq/layers/model_base.py               |  6 ++--
 vptq/layers/{qlinear.py => vqlinear.py} |  2 +-
 4 files changed, 38 insertions(+), 16 deletions(-)
 rename vptq/layers/{qlinear.py => vqlinear.py} (99%)

diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index 41c4c44..2818e4b 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -38,4 +38,5 @@ jobs:
         isort . --check-only
     - name: Running yapf
       run: |
+        echo "please run \" yapf --recursive .\" if errors"
         yapf --diff --recursive .
\ No newline at end of file
diff --git a/vptq/__main__.py b/vptq/__main__.py
index 957ed51..8aa1cba 100644
--- a/vptq/__main__.py
+++ b/vptq/__main__.py
@@ -4,6 +4,7 @@
 # --------------------------------------------------------------------------
 
 import argparse
+import os
 
 import transformers
 
@@ -28,21 +29,41 @@ def define_basic_args():
     return parser
 
 
-def chat_loop(model, tokenizer):
+def eval_prompt(model, tokenizer, args):
+    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
+    streamer = transformers.TextStreamer(tokenizer)
+    model.generate(**inputs, streamer=streamer, max_new_tokens=100, pad_token_id=2)
+
+
+def chat_loop(model, tokenizer, args):
+    if not args.chat:
+        eval_prompt(model, tokenizer, args)
+        return
+
+    if getattr(tokenizer, "chat_template", None) is None:
+        print("warning: this tokenizer didn't provide chat_template.!!!")
+        eval_prompt(model, tokenizer, args)
+        return
+
     print("============================chat with the model============================")
     print("Press 'exit' to quit")
     messages = [{"role": "system", "content": "you are a math teacher."}]
+
     while True:
         text = input("You: ")
         if text == "exit":
             break
         messages.append({"role": "user", "content": text})
         encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
-        model_inputs = encodeds.to("cuda")
-        generated_ids = model.generate(model_inputs, pad_token_id=2, max_new_tokens=500, do_sample=True)
+        model_inputs = encodeds.to(model.device)
+        streamer = transformers.TextStreamer(tokenizer, skip_prompt=True, decode_kwargs={"skip_special_tokens": True})
+        generated_ids = model.generate(model_inputs,
+                                       streamer=streamer,
+                                       pad_token_id=2,
+                                       max_new_tokens=500,
+                                       do_sample=True)
         decoded = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[-1]:], skip_special_tokens=True)
         messages.append({"role": "assistant", "content": decoded[0]})
-        print("assistant:", decoded[0])
 
 
 def get_valid_args(parser):
@@ -55,15 +76,15 @@ def main():
     args = get_valid_args(parser)
     print(args)
 
-    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto").half()
-    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer or args.model)
+    hf_args = {}
+    token = os.getenv("HF_TOKEN", None)
+    if token is not None:
+        hf_args["token"] = token
 
-    if args.chat:
-        chat_loop(model, tokenizer)
-        return
-    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
-    out = model.generate(**inputs, max_new_tokens=100, pad_token_id=2)
-    print(tokenizer.decode(out[0], skip_special_tokens=False))
+    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto", **hf_args).half()
+    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer or args.model, **hf_args)
+
+    chat_loop(model, tokenizer, args)
 
 
 main()
diff --git a/vptq/layers/model_base.py b/vptq/layers/model_base.py
index bc8fa91..47071ce 100644
--- a/vptq/layers/model_base.py
+++ b/vptq/layers/model_base.py
@@ -13,7 +13,7 @@
 import transformers
 from tqdm import tqdm
 
-from .qlinear import QuantLinear
+from .vqlinear import VQuantLinear
 
 
 def set_op_by_name(layer, name, new_module):
@@ -57,7 +57,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         with transformers.utils.generic.ContextManagers(init_contexts):
             model = cls.from_config(auto_conf, *model_args, **cls_kwargs)
 
-        target_layer = QuantLinear
+        target_layer = VQuantLinear
         quant_config = auto_conf.quant_config
 
         # replace linear layers with quantized linear layers
@@ -95,7 +95,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             max_memory=max_memory,
             no_split_module_classes=no_split_module_classes[0],
             dtype=auto_conf.torch_dtype,
-            # preload_module_classes=["QuantLinear"]
+            # preload_module_classes=["VQuantLinear"]
         )
 
         # weight_bins = glob.glob(str(Path(pretrained_model_name_or_path).absolute() / '*.safetensors'))
diff --git a/vptq/layers/qlinear.py b/vptq/layers/vqlinear.py
similarity index 99%
rename from vptq/layers/qlinear.py
rename to vptq/layers/vqlinear.py
index bea3b3e..572b3b6 100644
--- a/vptq/layers/qlinear.py
+++ b/vptq/layers/vqlinear.py
@@ -12,7 +12,7 @@
 from torch.nn.parameter import Parameter
 
 
-class QuantLinear(nn.Module):
+class VQuantLinear(nn.Module):
 
     def __init__(
         self,
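
Note (not part of the patch): the streaming behaviour this patch adds to eval_prompt() and chat_loop() follows the standard transformers pattern, sketched below for reference. TextStreamer decodes and prints each new token as model.generate() produces it. The checkpoint name and prompt are placeholders chosen here for illustration; any causal LM checkpoint works the same way.

    # Minimal streaming sketch, assuming `transformers` and `torch` are installed.
    import transformers

    model_id = "facebook/opt-125m"  # placeholder checkpoint, not part of the patch
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Explain vector quantization in one sentence.", return_tensors="pt")
    # TextStreamer prints each newly generated token to stdout as soon as generate() emits it.
    streamer = transformers.TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    model.generate(**inputs, streamer=streamer, max_new_tokens=50)

In the patched CLI the same path is reached through the package entry point, roughly: HF_TOKEN=<token> python -m vptq --model <quantized model id> --chat (flag names are inferred from the args attributes used in main(), since define_basic_args() is not shown in this patch).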