Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama65B patch for int4 fp32 #1769

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 100 additions & 85 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"--model_name",
type=str,
default="vicuna",
choices=["vicuna", "llama2_7b", "llama2_70b"],
choices=["vicuna", "llama_65b", "llama2_7b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
Expand Down Expand Up @@ -161,7 +161,7 @@ class VicunaBase(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_model_path="elinas/llama-65b-hf-transformers-4.29",
max_num_tokens=512,
device="cpu",
precision="int8",
Expand Down Expand Up @@ -433,7 +433,7 @@ class ShardedVicuna(VicunaBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_model_path="elinas/llama-65b-hf-transformers-4.29",
max_num_tokens=512,
device="cuda",
precision="fp32",
Expand Down Expand Up @@ -1212,7 +1212,7 @@ class UnshardedVicuna(VicunaBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_model_path="elinas/llama-65b-hf-transformers-4.29",
hf_auth_token: str = None,
max_num_tokens=512,
device="cpu",
Expand All @@ -1232,7 +1232,9 @@ def __init__(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
if self.model_name == "llama_65b":
self.hf_model_path = "elinas/llama-65b-hf-transformers-4.29"
elif self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
elif self.model_name == "llama2_70b":
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
Expand Down Expand Up @@ -1423,21 +1425,21 @@ def compile(self, download_vmfb=False):
else:
compilation_prompt = "".join(["0" for _ in range(17)])

if Path(f"first_{self.precision}.mlir").exists():
print(f"loading first_{self.precision}.mlir")
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
first_module = f.read()
if Path(f"second_{self.precision}.mlir").exists():
print(f"loading second_{self.precision}.mlir")
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
second_module = f.read()
else:
# generate first vicuna
compilation_input_ids = self.tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
# generate second vicuna
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
pkv = tuple(
(torch.zeros([1, 64, 19, 128], dtype=torch.float32))
for _ in range(160)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
Expand All @@ -1447,27 +1449,33 @@ def compile(self, download_vmfb=False):
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False, False],
f16_input_mask=[False] + [True] * 160,
mlir_type="torchscript",
)
del model
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[
0
] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)

firstVicunaCompileInput = tuple(firstVicunaCompileInput)
first_module = None
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 160,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
first_module = torch_mlir.compile(
second_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=[
"brevitas.matmul_rhs_group_quant"
Expand All @@ -1478,47 +1486,58 @@ def compile(self, download_vmfb=False):
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
first_module,
second_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
first_module = torch_mlir.compile(
second_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
from contextlib import redirect_stdout
print("Writing : second_llama_65b_linalg_ir_before_dynamic ELIDED")
with open('second_llama_65b_linalg_ir_before_dynamic_elided.mlir', 'w') as f:
with redirect_stdout(f):
print(second_module.operation.get_asm(large_elements_limit=4))
print("FINISHED")
del ts_graph
del firstVicunaCompileInput
del secondVicunaCompileInput
gc.collect()

print(
"[DEBUG] successfully generated first vicuna linalg mlir"
"[DEBUG] successfully generated second vicuna linalg mlir"
)
first_module = self.write_in_dynamic_inputs0(
str(first_module), dynamic_input_size=19
second_module = self.write_in_dynamic_inputs1(
str(second_module)
)
if self.cache_vicunas:
with open(f"first_{self.precision}.mlir", "w+") as f:
f.write(first_module)
print("Writing : second_llama_65b_linalg_ir_after_dynamic ELIDED")
with open('second_llama_65b_linalg_ir_after_dynamic_elided.mlir', 'w') as f:
with redirect_stdout(f):
print(second_module.operation.get_asm(large_elements_limit=4))
print("FINISHED")
# if self.cache_vicunas:
print("Writing : second_llama_65b_linalg_ir_after_dynamic")
with open(f"second_{self.precision}.mlir", "w+") as f:
f.write(second_module)

if Path(f"second_{self.precision}.mlir").exists():
print(f"loading second_{self.precision}.mlir")
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
second_module = f.read()
if Path(f"first_{self.precision}.mlir").exists():
print(f"loading first_{self.precision}.mlir")
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
first_module = f.read()
else:
# generate second vicuna
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(
# generate first vicuna
compilation_input_ids = self.tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
Expand All @@ -1528,33 +1547,27 @@ def compile(self, download_vmfb=False):
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 64,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[
0
] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)

firstVicunaCompileInput = tuple(firstVicunaCompileInput)
first_module = None
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
second_module = torch_mlir.compile(
first_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=[
"brevitas.matmul_rhs_group_quant"
Expand All @@ -1565,30 +1578,31 @@ def compile(self, download_vmfb=False):
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
second_module,
first_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
second_module = torch_mlir.compile(
first_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del secondVicunaCompileInput
del firstVicunaCompileInput
gc.collect()

print(
"[DEBUG] successfully generated second vicuna linalg mlir"
"[DEBUG] successfully generated first vicuna linalg mlir"
)
second_module = self.write_in_dynamic_inputs1(
str(second_module)
first_module = self.write_in_dynamic_inputs0(
str(first_module), dynamic_input_size=19
)
if self.cache_vicunas:
with open(f"second_{self.precision}.mlir", "w+") as f:
f.write(second_module)
with open(f"first_{self.precision}.mlir", "w+") as f:
f.write(first_module)

combined_module = self.combine_mlir_scripts(
first_module, second_module, self.vicuna_mlir_path
Expand Down Expand Up @@ -1752,6 +1766,7 @@ def autocomplete(self, prompt):

# Maps each --model_name choice to a "<model_name>=><hf_model_path>" spec.
# Every value must keep the "name=>path" shape used by the sibling entries;
# presumably the consumer splits on "=>" to recover both parts -- confirm
# against the caller. The llama_65b entry previously held only the bare
# Hugging Face path, which broke that convention.
model_list = {
    "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
    "llama_65b": "llama_65b=>elinas/llama-65b-hf-transformers-4.29",
    "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
    "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
}
Expand Down
Loading