Commit 299f2f0
support streaming
wejoncy committed Sep 20, 2024
1 parent a43019a commit 299f2f0
Showing 4 changed files with 38 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ruff.yml
@@ -38,4 +38,5 @@ jobs:
         isort . --check-only
     - name: Running yapf
       run: |
+        echo "please run \" yapf --recursive .\" if errors"
         yapf --diff --recursive .
45 changes: 33 additions & 12 deletions vptq/__main__.py
@@ -4,6 +4,7 @@
 # --------------------------------------------------------------------------
 
 import argparse
+import os
 
 import transformers

@@ -28,21 +29,41 @@ def define_basic_args():
     return parser
 
 
-def chat_loop(model, tokenizer):
+def eval_prompt(model, tokenizer, args):
+    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
+    streamer = transformers.TextStreamer(tokenizer)
+    model.generate(**inputs, streamer=streamer, max_new_tokens=100, pad_token_id=2)
+
+
+def chat_loop(model, tokenizer, args):
+    if not args.chat:
+        eval_prompt(model, tokenizer, args)
+        return
+
+    if getattr(tokenizer, "chat_template", None) is None:
+        print("warning: this tokenizer didn't provide chat_template.!!!")
+        eval_prompt(model, tokenizer, args)
+        return
+
     print("============================chat with the model============================")
     print("Press 'exit' to quit")
     messages = [{"role": "system", "content": "you are a math teacher."}]
 
     while True:
         text = input("You: ")
         if text == "exit":
             break
         messages.append({"role": "user", "content": text})
         encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
-        model_inputs = encodeds.to("cuda")
-        generated_ids = model.generate(model_inputs, pad_token_id=2, max_new_tokens=500, do_sample=True)
+        model_inputs = encodeds.to(model.device)
+        streamer = transformers.TextStreamer(tokenizer, skip_prompt=True, decode_kwargs={"skip_special_tokens": True})
+        generated_ids = model.generate(model_inputs,
+                                       streamer=streamer,
+                                       pad_token_id=2,
+                                       max_new_tokens=500,
+                                       do_sample=True)
         decoded = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[-1]:], skip_special_tokens=True)
         messages.append({"role": "assistant", "content": decoded[0]})
         print("assistant:", decoded[0])


@@ -55,15 +76,15 @@ def main():
     args = get_valid_args(parser)
     print(args)
 
-    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto").half()
-    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer or args.model)
+    hf_args = {}
+    token = os.getenv("HF_TOKEN", None)
+    if token is not None:
+        hf_args["token"] = token
 
-    if args.chat:
-        chat_loop(model, tokenizer)
-        return
-    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
-    out = model.generate(**inputs, max_new_tokens=100, pad_token_id=2)
-    print(tokenizer.decode(out[0], skip_special_tokens=False))
+    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto", **hf_args).half()
+    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer or args.model, **hf_args)
+
+    chat_loop(model, tokenizer, args)
 
 
 main()
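
Both eval_prompt and the chat loop rely on transformers.TextStreamer, which prints decoded tokens to stdout as generate() produces them instead of waiting for the whole sequence. A minimal, self-contained sketch of the pattern, assuming a generic causal LM (the checkpoint name and prompt are placeholders, not part of this commit):

import transformers

# Placeholder checkpoint; any causal LM streams the same way.
tok = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Hello, world", return_tensors="pt").to(model.device)

# skip_prompt=True suppresses echoing of the prompt; extra keyword
# arguments are forwarded to tokenizer.decode() for each printed chunk.
streamer = transformers.TextStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() still returns the full token ids; the streamer only adds
# incremental printing as a side effect.
model.generate(**inputs, streamer=streamer, max_new_tokens=50)

The main() change also reads an optional HF_TOKEN environment variable and forwards it as token= to both from_pretrained calls, so gated checkpoints load without hard-coding credentials.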
6 changes: 3 additions & 3 deletions vptq/layers/model_base.py
@@ -13,7 +13,7 @@
 import transformers
 from tqdm import tqdm
 
-from .qlinear import QuantLinear
+from .vqlinear import VQuantLinear
 
 
 def set_op_by_name(layer, name, new_module):
@@ -57,7 +57,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
     with transformers.utils.generic.ContextManagers(init_contexts):
         model = cls.from_config(auto_conf, *model_args, **cls_kwargs)
 
-    target_layer = QuantLinear
+    target_layer = VQuantLinear
     quant_config = auto_conf.quant_config
 
     # replace linear layers with quantized linear layers
@@ -95,7 +95,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         max_memory=max_memory,
         no_split_module_classes=no_split_module_classes[0],
         dtype=auto_conf.torch_dtype,
-        # preload_module_classes=["QuantLinear"]
+        # preload_module_classes=["VQuantLinear"]
     )
 
     # weight_bins = glob.glob(str(Path(pretrained_model_name_or_path).absolute() / '*.safetensors'))
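
For context on the rename: during from_pretrained, model_base.py walks the model and swaps every linear layer for target_layer (now VQuantLinear) using set_op_by_name. A rough sketch of that replacement idea; the set_op_by_name body and the VQuantLinear constructor call here are assumptions for illustration, not the actual code:

import torch.nn as nn

def set_op_by_name(layer, name, new_module):
    # Minimal stand-in for the helper whose signature appears above:
    # follow the dotted path and re-attach the new module at the leaf.
    parts = name.split(".")
    for p in parts[:-1]:
        layer = getattr(layer, p)
    setattr(layer, parts[-1], new_module)

def replace_linear_layers(model, target_layer, quant_config):
    # Swap each nn.Linear for its quantized counterpart in place.
    # Hypothetical constructor call: the real signature lives in
    # vptq/layers/vqlinear.py and is driven by quant_config.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            new_module = target_layer(module.in_features, module.out_features)
            set_op_by_name(model, name, new_module)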
2 changes: 1 addition & 1 deletion vptq/layers/qlinear.py → vptq/layers/vqlinear.py
@@ -12,7 +12,7 @@
 from torch.nn.parameter import Parameter
 
 
-class QuantLinear(nn.Module):
+class VQuantLinear(nn.Module):
 
     def __init__(
         self,
