[Neo][vLLM] Accept quant options for awq, fp8 (#2382)
a-ys committed Sep 13, 2024
1 parent a7cb6c2 commit f037b32
Showing 2 changed files with 46 additions and 40 deletions.
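
Taken together, the two changed files replace the old AUTOFP8_CONFIG environment-variable interface with option.-prefixed entries in the model's properties file. For orientation, here is a hypothetical pair of property maps, shaped as load_properties() would return them, exercising every option this commit reads; the values are illustrative, not recommendations:

    # Hypothetical property maps; every key below is read by partition.py
    # after this commit, but the values are made up for illustration.
    awq_properties = {
        "option.quantize": "awq",
        "option.awq_zero_point": "true",         # default "true"
        "option.awq_block_size": "128",          # default "128"
        "option.awq_weight_bit_width": "4",      # default "4"
        "option.awq_mm_version": "GEMM",         # default "GEMM"
        "option.awq_ignore_layers": "lm_head",   # optional, comma-separated
    }
    fp8_properties = {
        "option.quantize": "fp8",
        "option.fp8_activation_scheme": "static",              # default "static"
        "option.fp8_kv_cache_quant_targets": "k_proj,v_proj",  # optional
        "option.fp8_ignore_patterns": "re:.*lm_head",          # optional
        "option.calib_size": "512",  # calibration samples, static scheme only
    }
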
serving/docker/partition/partition.py (63 changes: 43 additions & 20 deletions)
@@ -18,7 +18,6 @@
 import argparse
 import subprocess
 
-from typing import Optional
 from pathlib import Path
 
 import utils
@@ -243,7 +242,7 @@ def load_the_generated_checkpoints(self):
         if entry_point_file:
             os.remove(os.path.join(saved_checkpoints_dir, 'model.py'))
 
-    def run_quantization(self, autofp8_config: Optional[dict] = None):
+    def run_quantization(self):
         quant_method = self.properties['option.quantize']
         if quant_method == 'awq':
             logging.info("Running AutoAWQ quantization")
@@ -252,7 +251,7 @@ def run_quantization(self, autofp8_config: Optional[dict] = None):
             self.upload_checkpoints_to_s3()
         elif quant_method == 'fp8':
             logging.info("Running AutoFP8 quantization")
-            self.autofp8_quantize(autofp8_config)
+            self.autofp8_quantize()
             self.properties_manager.generate_properties_file()
             self.upload_checkpoints_to_s3()
         else:
@@ -263,16 +262,27 @@ def autoawq_quantize(self):
         Quantizes model using AutoAWQ. Saves output to save_mp_checkpoint_path.
         """
         hf_configs, tokenizer = load_hf_config_and_tokenizer(self.properties)
-        logging.info(f"Model loading kwargs: {hf_configs.kwargs}")
 
-        # Hard-coding these options for now. If vLLM continues to prioritize
-        # AutoAWQ we will expose these options to customers in the future.
         quant_config = {
-            "zero_point": True,
-            "q_group_size": 128,
-            "w_bit": 4,
-            "version": "GEMM"
+            "zero_point":
+            self.properties.get("option.awq_zero_point",
+                                "true").lower() == 'true',
+            "q_group_size":
+            int(self.properties.get("option.awq_block_size", "128")),
+            "w_bit":
+            int(self.properties.get("option.awq_weight_bit_width", "4")),
+            "version":
+            self.properties.get("option.awq_mm_version", "GEMM")
         }
+        logging.info(f"Model loading kwargs: {hf_configs.kwargs}")
+        if self.properties.get("option.awq_ignore_layers"):
+            quant_config["modules_to_not_convert"] = [
+                s.strip() for s in self.properties.get(
+                    "option.awq_ignore_layers").split(',')
+            ]
+        logging.info(
+            f"Using the following configurations for AWQ quantization: {quant_config}"
+        )
         try:
             from awq import AutoAWQForCausalLM
             awq_model = AutoAWQForCausalLM.from_pretrained(
@@ -289,38 +299,51 @@ def autoawq_quantize(self):
             raise ImportError(
                 "AutoAWQ is not installed. Failing during quantization.")
 
-    def autofp8_quantize(self, config: Optional[dict] = None):
+    def autofp8_quantize(self):
         """
         Quantizes model using AutoFP8.
-        :param config: Dictionary containing values to construct auto_fp8.BaseQuantizeConfig
         """
+        # initialize configs
         hf_configs, tokenizer = load_hf_config_and_tokenizer(self.properties)
         if not tokenizer.pad_token:
             tokenizer.pad_token = tokenizer.eos_token
 
-        config = {
-            k: v
-            for k, v in config.items() if v is not None
-        } if config else {}
-        if config.get("activation_scheme") == "dynamic":
+        quant_config = {
+            "activation_scheme":
+            self.properties.get("option.fp8_activation_scheme", "static"),
+        }
+        if self.properties.get("option.fp8_kv_cache_quant_targets"):
+            quant_config["kv_cache_quant_targets"] = tuple([
+                s.strip() for s in self.properties.get(
+                    "option.fp8_kv_cache_quant_targets").split(',')
+            ])
+        if self.properties.get("option.fp8_ignore_patterns"):
+            quant_config["ignore_patterns"] = [
+                s.strip() for s in self.properties.get(
+                    "option.fp8_ignore_patterns").split(',')
+            ]
+
+        # create samples for calibrating scaling factors
+        if quant_config["activation_scheme"] == "dynamic":
             # If using dynamic activation scales, a calibration dataset is not required
             examples = []
         else:
+            calib_size = self.properties.get("option.calib_size", 512)
             # Tokenize dataset for calibrating static activation scales
             ds = load_dataset("abisee/cnn_dailymail",
                               "3.0.0",
                               split="validation").shuffle(seed=42).select(
-                                  range(512))
+                                  range(calib_size))
             examples = [batch["article"] for batch in ds]
             examples = tokenizer(examples,
                                  padding=True,
                                  truncation=True,
                                  return_tensors="pt").to("cuda")
 
         # quantization
         try:
             from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
-            quantize_config = BaseQuantizeConfig(**config)
+            quantize_config = BaseQuantizeConfig(**quant_config)
             logging.info(
                 f"Using the following configurations for fp8 quantization: {vars(quantize_config)}"
             )
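
To make the AWQ option handling above easier to trace, here is a minimal, self-contained sketch of the same parsing logic outside the PartitionService class; the property values are hypothetical:

    # Standalone sketch of the option.awq_* parsing added above.
    # Note that option.awq_block_size maps onto AutoAWQ's q_group_size.
    properties = {
        "option.quantize": "awq",
        "option.awq_zero_point": "false",
        "option.awq_block_size": "64",
        "option.awq_ignore_layers": "lm_head, model.embed_tokens",
    }

    quant_config = {
        "zero_point":
        properties.get("option.awq_zero_point", "true").lower() == 'true',
        "q_group_size": int(properties.get("option.awq_block_size", "128")),
        "w_bit": int(properties.get("option.awq_weight_bit_width", "4")),
        "version": properties.get("option.awq_mm_version", "GEMM"),
    }
    if properties.get("option.awq_ignore_layers"):
        # Comma-separated layer names pass through as modules_to_not_convert.
        quant_config["modules_to_not_convert"] = [
            s.strip() for s in properties["option.awq_ignore_layers"].split(',')
        ]

    print(quant_config)
    # {'zero_point': False, 'q_group_size': 64, 'w_bit': 4, 'version': 'GEMM',
    #  'modules_to_not_convert': ['lm_head', 'model.embed_tokens']}
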
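The FP8 path builds its BaseQuantizeConfig keyword arguments the same way. A similar sketch, again with hypothetical values; the kv_cache_quant_targets tuple and ignore_patterns list follow the shapes AutoFP8 documents for BaseQuantizeConfig:

    # Standalone sketch of the option.fp8_* parsing added above.
    properties = {
        "option.fp8_activation_scheme": "static",
        "option.fp8_kv_cache_quant_targets": "k_proj, v_proj",
        "option.fp8_ignore_patterns": "re:.*lm_head",
    }

    quant_config = {
        "activation_scheme":
        properties.get("option.fp8_activation_scheme", "static"),
    }
    if properties.get("option.fp8_kv_cache_quant_targets"):
        quant_config["kv_cache_quant_targets"] = tuple(
            s.strip() for s in
            properties["option.fp8_kv_cache_quant_targets"].split(','))
    if properties.get("option.fp8_ignore_patterns"):
        quant_config["ignore_patterns"] = [
            s.strip()
            for s in properties["option.fp8_ignore_patterns"].split(',')
        ]

    # The committed code reads option.calib_size without an int() cast, so a
    # value set in a properties file would reach range() as a string; this
    # sketch casts defensively.
    calib_size = int(properties.get("option.calib_size", 512))

    print(quant_config, calib_size)
    # {'activation_scheme': 'static',
    #  'kv_cache_quant_targets': ('k_proj', 'v_proj'),
    #  'ignore_patterns': ['re:.*lm_head']} 512
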
serving/docker/partition/sm_neo_quantize.py (23 changes: 3 additions & 20 deletions)
@@ -17,17 +17,14 @@
 from types import SimpleNamespace
 from typing import Final
 import torch
-import json
 
-from sm_neo_utils import (OptimizationFatalError, InputConfiguration,
-                          write_error_to_file, get_neo_env_vars,
-                          update_dataset_cache_location)
+from sm_neo_utils import (OptimizationFatalError, write_error_to_file,
+                          get_neo_env_vars, update_dataset_cache_location)
 from utils import (extract_python_jar, load_properties)
 from properties_manager import PropertiesManager
 from partition import PartitionService
 
 PYTHON_CACHE_DIR = '/tmp/djlserving/cache'
-AUTOFP8_CONFIG_ENVVAR = 'AUTOFP8_CONFIG'
 
 
 class NeoQuantizationService():
@@ -48,7 +45,6 @@ def __init__(self):

         self.customer_properties: dict = load_properties(
             self.INPUT_MODEL_DIRECTORY)
-        self.autofp8_config = None
 
     def initialize_partition_args_namespace(self):
         """
@@ -87,26 +83,14 @@ def construct_properties_manager(self):
         self.properties_manager = PropertiesManager(
             self.args, addl_properties=addl_properties)
 
-    def parse_autofp8_config(self) -> dict:
-        autofp8_config = os.environ.get(AUTOFP8_CONFIG_ENVVAR, {})
-        if autofp8_config:
-            try:
-                autofp8_config = json.loads(autofp8_config)
-                if not isinstance(autofp8_config, dict):
-                    raise ValueError("Parsed JSON is not a dictionary")
-                self.autofp8_config = autofp8_config
-            except Exception as exc:
-                raise InputConfiguration(
-                    f"Failed to parse AutoFP8 configuration: {exc}")
-
     def run_quantization(self) -> str:
         """
         :return: the output of the partition command captured from stdout
         """
         partition_service = PartitionService(self.properties_manager)
         extract_python_jar(PYTHON_CACHE_DIR)
         try:
-            return partition_service.run_quantization(self.autofp8_config)
+            return partition_service.run_quantization()
         except Exception as exc:
             raise OptimizationFatalError(
                 f"Encountered an error during quantization: {exc}")
@@ -150,7 +134,6 @@ def neo_quantize(self):
         update_dataset_cache_location(self.HF_CACHE_LOCATION)
         self.initialize_partition_args_namespace()
         self.construct_properties_manager()
-        self.parse_autofp8_config()
         self.run_quantization()
         self.write_properties()
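
For anyone tracking the interface change in sm_neo_quantize.py: FP8 settings no longer arrive as JSON in the AUTOFP8_CONFIG environment variable, they travel through the properties file like every other option. A small before/after sketch with a hypothetical payload:

    import json
    import os

    # Old interface (removed by this commit): a JSON dict in AUTOFP8_CONFIG,
    # parsed by the deleted parse_autofp8_config() method.
    os.environ["AUTOFP8_CONFIG"] = '{"activation_scheme": "dynamic"}'
    old_config = json.loads(os.environ["AUTOFP8_CONFIG"])

    # New interface: an option.* entry in the model's properties, picked up
    # by PartitionService via load_properties().
    new_properties = {"option.fp8_activation_scheme": "dynamic"}

    assert old_config["activation_scheme"] == \
        new_properties["option.fp8_activation_scheme"]
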
