Skip to content

Commit

Permalink
TorchFX: GPTQ accuracy fix (#26294)
Browse files Browse the repository at this point in the history
### Details:
- Fix for the accuracy issues discovered in Llama2 GPTQ with
aot_autograd

### Tickets:
 - [CVS-149032](https://jira.devtools.intel.com/browse/CVS-149032)

---------

Co-authored-by: Maxim Vafin <[email protected]>
  • Loading branch information
cavusmustafa and mvafin authored Oct 18, 2024
1 parent 62183ab commit 43df0b6
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 52 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/job_pytorch_models_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ jobs:
TEST_DEVICE: CPU
USE_SYSTEM_CACHE: False

- name: TorchFX GPTQ Pattern Test
if: ${{ inputs.model_scope == 'precommit' }}
# install torch 2.3.1 as newer is not yet supported by openvino backend
run: |
export PYTHONPATH=${MODEL_HUB_TESTS_INSTALL_DIR}:$PYTHONPATH
python3 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --upgrade --index-url https://download.pytorch.org/whl/cpu
python3 -m pytest ${MODEL_HUB_TESTS_INSTALL_DIR}/transformation_tests/test_gptq_torchfx_transformations.py -m precommit --html=${INSTALL_TEST_DIR}/TEST-torch_gptqpattern_tests.html --self-contained-html -v --tb=short
env:
TEST_DEVICE: CPU
USE_SYSTEM_CACHE: False

- name: Reformat unsupported ops file
if: ${{ inputs.model_scope != 'precommit' && !cancelled()}}
run: |
Expand Down
188 changes: 136 additions & 52 deletions src/frontends/pytorch/src/transforms/torchfx_gptq_pattern_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,6 @@ uint32_t read_u4_data(const void* array, size_t index) {
return val;
};

void write_u4_data(void* array, size_t index, uint32_t data) {
auto arr_u32 = reinterpret_cast<uint32_t*>(array);
size_t idx_u32 = index / 8;
size_t offset_u32 = index % 8;
uint32_t old_val = arr_u32[idx_u32];
data = data << (offset_u32 * 4);
uint32_t mask = 15;
mask = ~(mask << (offset_u32 * 4));
uint32_t new_val = (old_val & mask) | data;
arr_u32[idx_u32] = new_val;
};

GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
const auto& const_1 = wrap_type<v0::Constant>();
const auto& const_2 = wrap_type<v0::Constant>();
Expand All @@ -73,61 +61,157 @@ GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
const auto& convert_2 = wrap_type<v0::Convert>({const_6});
const auto& bitwise_and = wrap_type<ov::op::v13::BitwiseAnd>({add_or_convert, convert_2});

ov::matcher_pass_callback callback = [unsqueeze_1](Matcher& m) {
ov::matcher_pass_callback callback = [=](Matcher& m) {
auto bitwise_and = m.get_match_root();
if (!bitwise_and) {
return false;
}
const auto& pattern_map = m.get_pattern_value_map();
const auto& input_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
auto weights_u32 = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(0));
auto axis = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(1));
auto axis_data = axis->get_data_ptr<uint32_t>();

auto u8_shape = weights_u32->get_shape();
auto src = weights_u32->get_data_ptr<uint32_t>();

ov::Shape u4_shape;
bool dim_added = false;
size_t stride = 1;
size_t size_y = 1;
for (size_t i = 0; i < u8_shape.size(); i++) {
if (axis_data[0] == i) {
u4_shape.push_back(8);
dim_added = true;
}
if (axis_data[0] <= i) {
stride *= u8_shape[i];
} else {
size_y *= u8_shape[i];
}
u4_shape.push_back(u8_shape[i]);
auto unsqueeze_1_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
auto unsqueeze_1_in0_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(0));
auto unsqueeze_1_in1_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(1));
auto abs_node = pattern_map.at(abs).get_node_shared_ptr();
auto abs_in_const = std::dynamic_pointer_cast<v0::Constant>(abs_node->get_input_node_shared_ptr(0));
auto broadcast_node = pattern_map.at(broadcast).get_node_shared_ptr();
auto unsqueeze_2_node = pattern_map.at(unsqueeze_2).get_node_shared_ptr();
auto unsqueeze_2_in0_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(0));
auto unsqueeze_2_in1_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(1));

OutputVector outputs_1(unsqueeze_1_node->get_output_size());
OutputVector unsqueeze_1_inputs(2);
unsqueeze_1_inputs[0] = unsqueeze_1_in0_const->outputs()[0];
unsqueeze_1_inputs[1] = unsqueeze_1_in1_const->outputs()[0];
if (!unsqueeze_1_node->constant_fold(outputs_1, unsqueeze_1_inputs)) {
return false;
}
if (!dim_added) {
u4_shape.push_back(8);

OutputVector outputs_2(abs_node->get_output_size());
if (!abs_node->constant_fold(outputs_2, abs_in_const->outputs())) {
return false;
}

auto new_const = std::make_shared<v0::Constant>(element::u4, u4_shape);
auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_const->get_data_ptr()));
OutputVector outputs_3(broadcast_node->get_output_size());
OutputVector broadcast_inputs(2);
broadcast_inputs[0] = outputs_1[0];
broadcast_inputs[1] = outputs_2[0];
if (!broadcast_node->constant_fold(outputs_3, broadcast_inputs)) {
return false;
}

OutputVector outputs_4(unsqueeze_2_node->get_output_size());
OutputVector unsqueeze_2_inputs(2);
unsqueeze_2_inputs[0] = unsqueeze_2_in0_const->outputs()[0];
unsqueeze_2_inputs[1] = unsqueeze_2_in1_const->outputs()[0];
if (!unsqueeze_2_node->constant_fold(outputs_4, unsqueeze_2_inputs)) {
return false;
}
const int32_t* rs_in0 =
std::dynamic_pointer_cast<v0::Constant>(outputs_3[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
const int32_t* rs_in1 =
std::dynamic_pointer_cast<v0::Constant>(outputs_4[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
auto shifted_const = std::make_shared<v0::Constant>(element::i32, outputs_3[0].get_shape());
auto dst = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(shifted_const->get_data_ptr()));
if (!dst)
return false;

size_t in_idx = 0;
for (size_t y = 0; y < size_y; y++) {
size_t offset = y * stride * 8;
for (size_t x = 0; x < stride; x++) {
for (size_t z = 0; z < 8; z++) {
uint32_t val = read_u4_data(src, in_idx);
write_u4_data(dst, (offset + x + stride * z), val);
in_idx++;
}
// TODO: Bitwise right shift operation below might need to be
// optimized to reduce FIL.
size_t rs_in0_shape_size = shape_size(outputs_3[0].get_shape());
const auto& rs_in0_shape = outputs_3[0].get_shape();
const auto& rs_in1_shape = outputs_4[0].get_shape();
int shift_dim = -1;
size_t shift_offset = 1;
for (size_t i = 0; i < rs_in1_shape.size(); ++i) {
size_t dim = rs_in1_shape[i];
if (dim != 1 && dim != rs_in0_shape[i]) {
return false;
}
if (shift_dim != -1) {
shift_offset *= rs_in0_shape[i];
}
if (dim == rs_in0_shape[i]) {
shift_dim = static_cast<int>(i);
}
}
if (shift_dim == -1)
return false;
for (size_t k = 0; k < rs_in0_shape_size; ++k) {
size_t shift_idx = (k / shift_offset) % rs_in1_shape[shift_dim];
int32_t shift_val = rs_in1[shift_idx];
dst[k] = (rs_in0[k] >> shift_val);
}

std::shared_ptr<ov::Node> convert_1_node = nullptr;
OutputVector outputs_7;
if (pattern_map.find(convert_1) != pattern_map.end()) {
convert_1_node = pattern_map.at(convert_1).get_node_shared_ptr();
outputs_7.resize(convert_1_node->get_output_size());
if (!convert_1_node->constant_fold(outputs_7, shifted_const->outputs())) {
return false;
}
} else {
auto convert_3_node = pattern_map.at(convert_3).get_node_shared_ptr();
auto convert_4_node = pattern_map.at(convert_4).get_node_shared_ptr();
auto convert_4_in_const =
std::dynamic_pointer_cast<v0::Constant>(convert_4_node->get_input_node_shared_ptr(0));
auto add_node = pattern_map.at(add).get_node_shared_ptr();
OutputVector outputs_5(convert_3_node->get_output_size());
if (!convert_3_node->constant_fold(outputs_5, shifted_const->outputs())) {
return false;
}
OutputVector outputs_6(convert_4_node->get_output_size());
if (!convert_4_node->constant_fold(outputs_6, convert_4_in_const->outputs())) {
return false;
}
outputs_7.resize(add_node->get_output_size());
OutputVector add_inputs(2);
add_inputs[0] = outputs_5[0];
add_inputs[1] = outputs_6[0];
if (!add_node->constant_fold(outputs_7, add_inputs)) {
return false;
}
}

copy_runtime_info_and_name(weights_u32, {new_const}, {weights_u32, bitwise_and});
auto convert_2_node = pattern_map.at(convert_2).get_node_shared_ptr();
auto convert_2_in_const = std::dynamic_pointer_cast<v0::Constant>(convert_2_node->get_input_node_shared_ptr(0));

OutputVector outputs_8(convert_2_node->get_output_size());
if (!convert_2_node->constant_fold(outputs_8, convert_2_in_const->outputs())) {
return false;
}

OutputVector outputs_9(bitwise_and->get_output_size());

const int8_t* and_in0 =
std::dynamic_pointer_cast<v0::Constant>(outputs_7[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
const int8_t* and_in1 =
std::dynamic_pointer_cast<v0::Constant>(outputs_8[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
auto masked_const = std::make_shared<v0::Constant>(element::i8, outputs_7[0].get_shape());
auto masked_dst = const_cast<int8_t*>(reinterpret_cast<const int8_t*>(masked_const->get_data_ptr()));
if (!masked_dst)
return false;

size_t and_in0_shape_size = shape_size(outputs_7[0].get_shape());
// TODO: Bitwise and operation below might need to be
// optimized to reduce FIL.
int8_t mask = and_in1[0];
for (size_t k = 0; k < and_in0_shape_size; ++k) {
masked_dst[k] = (and_in0[k] & mask);
}

auto convert_to_u4 = std::make_shared<v0::Convert>(masked_const, element::u4);
OutputVector outputs_10(convert_to_u4->get_output_size());
if (!convert_to_u4->constant_fold(outputs_10, masked_const->outputs())) {
return false;
}

auto new_convert = std::make_shared<v0::Convert>(new_const, bitwise_and->get_output_element_type(0));
copy_runtime_info_and_name(bitwise_and, {new_convert}, {input_node});
auto new_convert =
std::make_shared<v0::Convert>(outputs_10[0].get_node_shared_ptr(), bitwise_and->get_output_element_type(0));
copy_runtime_info_and_name(bitwise_and, {new_convert}, {unsqueeze_1_node});
replace_node(bitwise_and, new_convert);
return true;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit,https://huggingface.co/atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import hashlib
from openvino.frontend.pytorch.torchdynamo.execute import compiled_cache
import models_hub_common.utils as utils
import pytest
import os

def patch_gptq(config):
do_gptq_patching = False
config_dict = config.to_dict()
quantization_config = config_dict.get("quantization_config", None)
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
orig_cuda_check = torch.cuda.is_available
orig_post_init_model = None
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
torch.cuda.is_available = lambda: False

from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model

GPTQQuantizer.post_init_model = post_init_model
return orig_cuda_check, orig_post_init_model

def run_gptq_torchfx(tmp_path, model_id, model_link, prompt_result_pair):
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
cuda, post_init = patch_gptq(config)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
config=config,
device_map='cpu',
torch_dtype=torch.float32
)

pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=4,
do_sample=True,
temperature=0.01,
top_p=0.01,
top_k=1,
repetition_penalty=1.1,
num_beams=1,
)

prompt = prompt_result_pair["prompt"]
expected_md5 = prompt_result_pair["result_md5"]

model.model.forward = torch.compile(model.model.forward, backend="openvino", dynamic=True, fullgraph=True, options={'aot_autograd': True})

result_ov = pipe(prompt)
md5_ov = hashlib.new("md5", result_ov[0]['generated_text'].encode(), usedforsecurity=False).hexdigest()

u4_ops = ["FullyConnected",]
num_u4_ops = 0
num_u4_ops_supported = 0
for pid in compiled_cache:
for op in compiled_cache[pid].get_runtime_model().get_ordered_ops():
if (str(op.get_rt_info()["layerType"].get()) in u4_ops):
u4_exec = (str(op.get_rt_info()["runtimePrecision"].get()) == "u4")
if u4_exec:
num_u4_ops_supported += 1
num_u4_ops += 1

assert(expected_md5 == md5_ov), "Output does not match with the expected output"
assert((num_u4_ops > 0) and (num_u4_ops == num_u4_ops_supported)), "Runtime precision is not u4"

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "gptq-torchfx-models-precommit")))
@pytest.mark.parametrize('prompt_result_pair', ([
{"prompt" : "Tell me about AI", "result_md5" : "4385ccbce14627ae91f846b4c8a3f145"},
]))
def test_gptq_torchfx_precommit(tmp_path, model_name, model_link, mark, reason, prompt_result_pair, ie_device):
assert mark is None or mark == 'skip' or mark == 'xfail', \
"Incorrect test case: {}, {}".format(model_name, model_link)
if mark == 'skip':
pytest.skip(reason)
elif mark == 'xfail':
pytest.xfail(reason)
run_gptq_torchfx(tmp_path, model_name, model_link, prompt_result_pair)

0 comments on commit 43df0b6

Please sign in to comment.