Separates out runtime (nod-ai#237)
- Adds model_runner.py for setting up the IREE runtime config, to reduce
repeated code across the runner scripts (interface sketched just below)
- Creates runner scripts for llama, clip, unet, and vae to clean up and
simplify the larger export scripts
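
model_runner.py itself is one of the changed files but is not expanded on this page. The following is a minimal sketch of the interface it appears to provide, inferred from the call sites in llm_runner.py below (runner.config, runner.ctx) and from the runtime-setup code this commit deletes from clip.py; the committed file may differ:

# Hypothetical reconstruction of python/turbine_models/model_runner.py,
# assembled from how vmfbRunner is used below; not the committed source.
from iree import runtime as ireert


class vmfbRunner:
    def __init__(self, device, vmfb_path, external_weight_path=None):
        self.config = ireert.Config(device)

        # mmap the compiled VM flatbuffer so it is not copied into memory
        mod = ireert.VmModule.mmap(self.config.vm_instance, vmfb_path)
        vm_modules = [
            mod,
            ireert.create_hal_module(self.config.vm_instance, self.config.device),
        ]

        # If weights were externalized at compile time, expose them to the
        # module as IO parameters.
        if external_weight_path:
            index = ireert.ParameterIndex()
            index.load(external_weight_path)
            param_module = ireert.create_io_parameters_module(
                self.config.vm_instance, index.create_provider(scope="model")
            )
            vm_modules.insert(0, param_module)

        self.ctx = ireert.SystemContext(vm_modules=vm_modules, config=self.config)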
IanNod authored Dec 13, 2023
1 parent 10c8ed1 commit 115c667
Showing 13 changed files with 720 additions and 543 deletions.
138 changes: 138 additions & 0 deletions python/turbine_models/custom_models/llm_runner.py
@@ -0,0 +1,138 @@
import argparse
from turbine_models.model_runner import vmfbRunner
from transformers import AutoTokenizer
from iree import runtime as ireert
import torch

parser = argparse.ArgumentParser()

# TODO: move common runner flags to a generic flag file
parser.add_argument(
    "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module"
)
parser.add_argument(
    "--external_weight_path",
    type=str,
    default="",
    help="path to external weight parameters if model compiled without them",
)
parser.add_argument(
    "--compare_vs_torch",
    action="store_true",
    help="Runs both turbine vmfb and a torch model to compare results",
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="meta-llama/Llama-2-7b-chat-hf",
)
parser.add_argument(
    "--hf_auth_token",
    type=str,
    help="The Hugging Face auth token, required for some models",
)
parser.add_argument(
    "--device",
    type=str,
    default="local-task",
    help="local-sync, local-task, cuda, vulkan, rocm",
)
parser.add_argument(
    "--prompt",
    type=str,
    default="""<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
""",
    help="prompt for llm model",
)


def run_llm(
    device, prompt, vmfb_path, hf_model_name, hf_auth_token, external_weight_path
):
    runner = vmfbRunner(
        device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path
    )

    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_name,
        use_fast=False,
        token=hf_auth_token,
    )
    initial_input = tokenizer(prompt, return_tensors="pt")
    example_input_id = initial_input.input_ids
    inputs = [ireert.asdevicearray(runner.config.device, example_input_id)]
    # The exported module is named state_update; run_initialize consumes the
    # prompt tokens and returns the first generated token.
    results = runner.ctx.modules.state_update["run_initialize"](*inputs)

    def format_out(results):
        # Pull the generated token id back to the host as a scalar tensor.
        return torch.tensor(results.to_host()[0][0])

    turbine_results = []
    turbine_results.append(format_out(results))
    while format_out(results) != 2:  # 2 is the Llama-2 EOS token id
        results = runner.ctx.modules.state_update["run_forward"](results)
        # uncomment to see tokens as they are emitted
        # print(f"turbine: {tokenizer.decode(format_out(results))}")
        turbine_results.append(format_out(results))

    return tokenizer.decode(turbine_results)


def run_torch_llm(hf_model_name, hf_auth_token, prompt):
    from turbine_models.model_builder import HFTransformerBuilder
    from transformers import AutoModelForCausalLM

    model_builder = HFTransformerBuilder(
        example_input=None,
        hf_id=hf_model_name,
        auto_model=AutoModelForCausalLM,
        hf_auth_token=hf_auth_token,
        auto_tokenizer=AutoTokenizer,
    )
    model_builder.build_model()

    def get_token_from_logits(logits):
        # Greedy decoding: argmax over the vocabulary at the last position.
        return torch.argmax(logits[:, -1, :], dim=1)

    initial_input = model_builder.tokenizer(prompt, return_tensors="pt")
    example_input_id = initial_input.input_ids

    model_results = model_builder.model.forward(example_input_id)
    model_token = get_token_from_logits(model_results.logits)

    pkv = model_results.past_key_values

    torch_results = []
    torch_results.append(int(model_token))
    while model_token != 2:  # 2 is the Llama-2 EOS token id
        model_results = model_builder.model.forward(
            torch.unsqueeze(model_token, 0), past_key_values=pkv
        )
        model_token = get_token_from_logits(model_results.logits)
        pkv = model_results.past_key_values
        torch_results.append(int(model_token[0]))

    return model_builder.tokenizer.decode(torch_results)


if __name__ == "__main__":
    args = parser.parse_args()
    print("generating turbine output: ")
    turbine_output = run_llm(
        args.device,
        args.prompt,
        args.vmfb_path,
        args.hf_model_name,
        args.hf_auth_token,
        args.external_weight_path,
    )
    print(turbine_output)
    if args.compare_vs_torch:
        print("generating torch output: ")
        torch_output = run_torch_llm(
            args.hf_model_name, args.hf_auth_token, args.prompt
        )
        print(torch_output)
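
For reference, a hedged usage sketch of the new runner, calling run_llm directly instead of through the CLI; the vmfb and weight file names are placeholders for artifacts produced by the corresponding export script, not outputs guaranteed by this commit:

# Assumes the model was already exported and compiled to a vmfb; the two
# file paths below are illustrative placeholders.
from turbine_models.custom_models.llm_runner import run_llm

output = run_llm(
    device="local-task",
    prompt="<s>[INST] hi what are you? [/INST]",
    vmfb_path="Llama_2_7b_chat_hf.vmfb",
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",
    hf_auth_token=None,
    external_weight_path="Llama_2_7b_chat_hf.safetensors",
)
print(output)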
114 changes: 18 additions & 96 deletions python/turbine_models/custom_models/sd_inference/clip.py
@@ -29,10 +29,8 @@
     help="HF model name",
     default="CompVis/stable-diffusion-v1-4",
 )
-parser.add_argument("--run_vmfb", action="store_true")
 parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
-parser.add_argument("--external_weight_file", type=str, default="")
-parser.add_argument("--vmfb_path", type=str, default="")
+parser.add_argument("--external_weight_path", type=str, default="")
 parser.add_argument(
     "--external_weights",
     type=str,
@@ -49,15 +47,13 @@
 )
 parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
 
-prompt = ["a photograph of an astronaut riding a horse"]
-
 
 def export_clip_model(
     hf_model_name,
     hf_auth_token=None,
     compile_to="torch",
     external_weights=None,
-    external_weight_file=None,
+    external_weight_path=None,
     device=None,
     target_triple=None,
     max_alloc=None,
Expand All @@ -76,7 +72,7 @@ def export_clip_model(

mapper = {}
utils.save_external_weights(
mapper, text_encoder_model, external_weights, external_weight_file
mapper, text_encoder_model, external_weights, external_weight_path
)

class CompiledClip(CompiledModule):
@@ -104,94 +100,20 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
     utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
 
 
-def run_clip_vmfb_comparison(args):
-    config = ireert.Config(args.device)
-
-    if args.external_weight_file:
-        index = ireert.ParameterIndex()
-        index.load(args.external_weight_file)
-
-    safe_name = utils.create_safe_name(args.hf_model_name, "-clip")
-    if args.vmfb_path:
-        mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
-    elif os.path.exists(f"{safe_name}.vmfb"):
-        mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
-    else:
-        sys.exit("no vmfb_path provided, required for run_vmfb")
-
-    vm_modules = [
-        mod,
-        ireert.create_hal_module(config.vm_instance, config.device),
-    ]
-    if args.external_weight_file:
-        param_module = ireert.create_io_parameters_module(
-            config.vm_instance, index.create_provider(scope="model")
-        )
-        vm_modules.insert(0, param_module)
-
-    ctx = ireert.SystemContext(
-        vm_modules=vm_modules,
-        config=config,
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        args.hf_model_name,
-        subfolder="tokenizer",
-        token=args.hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    inp = text_input.input_ids
-    device_inputs = [ireert.asdevicearray(config.device, inp)]
-
-    # Turbine output
-    ModuleCompiled = ctx.modules.compiled_clip
-    turbine_outputs = ModuleCompiled["main"](*device_inputs)
-    turbine_output = turbine_outputs[0]
-    print(
-        "TURBINE OUTPUT:",
-        turbine_output.to_host(),
-        turbine_output.to_host().shape,
-        turbine_output.to_host().dtype,
-    )
-
-    # Torch output
-    text_encoder_model = CLIPTextModel.from_pretrained(
-        args.hf_model_name,
-        subfolder="text_encoder",
-        token=args.hf_auth_token,
-    )
-    torch_output = text_encoder_model.forward(inp)[0]
-    np_torch_output = torch_output.detach().cpu().numpy()
-    print(
-        "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype
-    )
-
-    err = utils.largest_error(np_torch_output, turbine_output)
-    print("LARGEST ERROR:", err)
-    assert err < 9e-5
-
-
 if __name__ == "__main__":
     args = parser.parse_args()
-    if args.run_vmfb:
-        run_clip_vmfb_comparison(args)
-    else:
-        mod_str, _ = export_clip_model(
-            args.hf_model_name,
-            args.hf_auth_token,
-            args.compile_to,
-            args.external_weights,
-            args.external_weight_file,
-            args.device,
-            args.iree_target_triple,
-            args.vulkan_max_allocation,
-        )
-        safe_name = utils.create_safe_name(args.hf_model_name, "-clip")
-        with open(f"{safe_name}.mlir", "w+") as f:
-            f.write(mod_str)
-        print("Saved to", safe_name + ".mlir")
+    mod_str, _ = export_clip_model(
+        args.hf_model_name,
+        args.hf_auth_token,
+        args.compile_to,
+        args.external_weights,
+        args.external_weight_path,
+        args.device,
+        args.iree_target_triple,
+        args.vulkan_max_allocation,
+    )
+    safe_name = args.hf_model_name.split("/")[-1].strip()
+    safe_name = re.sub("-", "_", safe_name)
+    with open(f"{safe_name}.mlir", "w+") as f:
+        f.write(mod_str)
+    print("Saved to", safe_name + ".mlir")
(The remaining 11 changed files are collapsed and not shown here.)
