From 115c6678d475d86c3007118891c73541c47fcfa5 Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Wed, 13 Dec 2023 15:30:03 -0800 Subject: [PATCH] Seperates out runtime (#237) - Adds model_runner.py for setting up iree runtime config to reduce repeated code - Creates runner script for llama, clip, unet, and vae to cleanup and simplify the larger scripts --- .../custom_models/llm_runner.py | 138 +++++++++++ .../custom_models/sd_inference/clip.py | 114 ++------- .../custom_models/sd_inference/clip_runner.py | 130 ++++++++++ .../custom_models/sd_inference/unet.py | 111 ++------- .../custom_models/sd_inference/unet_runner.py | 151 ++++++++++++ .../custom_models/sd_inference/vae.py | 99 ++------ .../custom_models/sd_inference/vae_runner.py | 120 ++++++++++ .../custom_models/stateless_llama.py | 13 - python/turbine_models/model_builder.py | 4 +- python/turbine_models/model_runner.py | 30 +++ python/turbine_models/tests/sd_test.py | 93 ++++++- .../tests/stateless_llama_test.py | 34 +-- .../turbine_models/tests/vmfb_comparison.py | 226 ------------------ 13 files changed, 720 insertions(+), 543 deletions(-) create mode 100644 python/turbine_models/custom_models/llm_runner.py create mode 100644 python/turbine_models/custom_models/sd_inference/clip_runner.py create mode 100644 python/turbine_models/custom_models/sd_inference/unet_runner.py create mode 100644 python/turbine_models/custom_models/sd_inference/vae_runner.py create mode 100644 python/turbine_models/model_runner.py delete mode 100644 python/turbine_models/tests/vmfb_comparison.py diff --git a/python/turbine_models/custom_models/llm_runner.py b/python/turbine_models/custom_models/llm_runner.py new file mode 100644 index 000000000..f3e84acc8 --- /dev/null +++ b/python/turbine_models/custom_models/llm_runner.py @@ -0,0 +1,138 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from transformers import AutoTokenizer +from iree import runtime as ireert +import torch + +parser = argparse.ArgumentParser() + +# TODO move common runner flags to generic flag file +parser.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) +parser.add_argument( + "--external_weight_path", + type=str, + default="", + help="path to external weight parameters if model compiled without them", +) +parser.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="meta-llama/Llama-2-7b-chat-hf", +) +parser.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging face auth token, required for some models", +) +parser.add_argument( + "--device", + type=str, + default="local-task", + help="local-sync, local-task, cuda, vulkan, rocm", +) +parser.add_argument( + "--prompt", + type=str, + default="""[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? 
[/INST] +""", + help="prompt for llm model", +) + + +def run_llm( + device, prompt, vmfb_path, hf_model_name, hf_auth_token, external_weight_path +): + runner = vmfbRunner( + device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path + ) + + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, + use_fast=False, + token=hf_auth_token, + ) + initial_input = tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + inputs = [ireert.asdevicearray(runner.config.device, example_input_id)] + results = runner.ctx.modules.state_update["run_initialize"]( + *inputs + ) # example_input_id) + + def format_out(results): + return torch.tensor(results.to_host()[0][0]) + + turbine_results = [] + turbine_results.append(format_out(results)) + while format_out(results) != 2: + results = runner.ctx.modules.state_update["run_forward"](results) + # uncomment to see tokens as they are emitted + # print(f"turbine: {tokenizer.decode(format_out(results))}") + turbine_results.append(format_out(results)) + + return tokenizer.decode(turbine_results) + + +def run_torch_llm(hf_model_name, hf_auth_token, prompt): + from turbine_models.model_builder import HFTransformerBuilder + from transformers import AutoModelForCausalLM + + model_builder = HFTransformerBuilder( + example_input=None, + hf_id=hf_model_name, + auto_model=AutoModelForCausalLM, + hf_auth_token=hf_auth_token, + auto_tokenizer=AutoTokenizer, + ) + model_builder.build_model() + + def get_token_from_logits(logits): + return torch.argmax(logits[:, -1, :], dim=1) + + initial_input = model_builder.tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + + model_results = model_builder.model.forward(example_input_id) + model_token = get_token_from_logits(model_results.logits) + + pkv = model_results.past_key_values + + torch_results = [] + torch_results.append(int(model_token)) + while model_token != 2: + model_results = model_builder.model.forward( + torch.unsqueeze(model_token, 0), past_key_values=pkv + ) + model_token = get_token_from_logits(model_results.logits) + pkv = model_results.past_key_values + torch_results.append(int(model_token[0])) + + return model_builder.tokenizer.decode(torch_results) + + +if __name__ == "__main__": + args = parser.parse_args() + print("generating turbine output: ") + turbine_output = run_llm( + args.device, + args.prompt, + args.vmfb_path, + args.hf_model_name, + args.hf_auth_token, + args.external_weight_path, + ) + print(turbine_output) + if args.compare_vs_torch: + print("generating torch output: ") + torch_output = run_torch_llm( + args.hf_model_name, args.hf_auth_token, args.prompt + ) + print(torch_output) diff --git a/python/turbine_models/custom_models/sd_inference/clip.py b/python/turbine_models/custom_models/sd_inference/clip.py index 4b640617f..996d5fb83 100644 --- a/python/turbine_models/custom_models/sd_inference/clip.py +++ b/python/turbine_models/custom_models/sd_inference/clip.py @@ -29,10 +29,8 @@ help="HF model name", default="CompVis/stable-diffusion-v1-4", ) -parser.add_argument("--run_vmfb", action="store_true") parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") -parser.add_argument("--external_weight_file", type=str, default="") -parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument("--external_weight_path", type=str, default="") parser.add_argument( "--external_weights", type=str, @@ -49,15 +47,13 @@ ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") 
-prompt = ["a photograph of an astronaut riding a horse"] - def export_clip_model( hf_model_name, hf_auth_token=None, compile_to="torch", external_weights=None, - external_weight_file=None, + external_weight_path=None, device=None, target_triple=None, max_alloc=None, @@ -76,7 +72,7 @@ def export_clip_model( mapper = {} utils.save_external_weights( - mapper, text_encoder_model, external_weights, external_weight_file + mapper, text_encoder_model, external_weights, external_weight_path ) class CompiledClip(CompiledModule): @@ -104,94 +100,20 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)): utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) -def run_clip_vmfb_comparison(args): - config = ireert.Config(args.device) - - if args.external_weight_file: - index = ireert.ParameterIndex() - index.load(args.external_weight_file) - - safe_name = utils.create_safe_name(args.hf_model_name, "-clip") - if args.vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - sys.exit("no vmfb_path provided, required for run_vmfb") - - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device), - ] - if args.external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) - - ctx = ireert.SystemContext( - vm_modules=vm_modules, - config=config, - ) - tokenizer = CLIPTokenizer.from_pretrained( - args.hf_model_name, - subfolder="tokenizer", - token=args.hf_auth_token, - ) - text_input = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - inp = text_input.input_ids - device_inputs = [ireert.asdevicearray(config.device, inp)] - - # Turbine output - ModuleCompiled = ctx.modules.compiled_clip - turbine_outputs = ModuleCompiled["main"](*device_inputs) - turbine_output = turbine_outputs[0] - print( - "TURBINE OUTPUT:", - turbine_output.to_host(), - turbine_output.to_host().shape, - turbine_output.to_host().dtype, - ) - - # Torch output - text_encoder_model = CLIPTextModel.from_pretrained( - args.hf_model_name, - subfolder="text_encoder", - token=args.hf_auth_token, - ) - torch_output = text_encoder_model.forward(inp)[0] - np_torch_output = torch_output.detach().cpu().numpy() - print( - "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype - ) - - err = utils.largest_error(np_torch_output, turbine_output) - print("LARGEST ERROR:", err) - assert err < 9e-5 - - if __name__ == "__main__": args = parser.parse_args() - if args.run_vmfb: - run_clip_vmfb_comparison(args) - else: - mod_str, _ = export_clip_model( - args.hf_model_name, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_file, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - ) - safe_name = utils.create_safe_name(args.hf_model_name, "-clip") - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") + mod_str, _ = export_clip_model( + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with 
open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/clip_runner.py b/python/turbine_models/custom_models/sd_inference/clip_runner.py new file mode 100644 index 000000000..b7f046e2e --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/clip_runner.py @@ -0,0 +1,130 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from iree import runtime as ireert +import torch + +parser = argparse.ArgumentParser() + +# TODO move common runner flags to generic flag file +parser.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) +parser.add_argument( + "--external_weight_path", + type=str, + default="", + help="path to external weight parameters if model compiled without them", +) +parser.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging face auth token, required for some models", +) +parser.add_argument( + "--device", + type=str, + default="local-task", + help="local-sync, local-task, cuda, vulkan, rocm", +) + +parser.add_argument( + "--prompt", + type=str, + default="a photograph of an astronaut riding a horse", + help="prompt for clip model", +) + + +def run_clip( + device, prompt, vmfb_path, hf_model_name, hf_auth_token, external_weight_path +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + text_input = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + example_input = text_input.input_ids + inp = [ireert.asdevicearray(runner.config.device, example_input)] + + results = runner.ctx.modules.compiled_clip["main"](*inp) + return results + + +def run_torch_clip(hf_model_name, hf_auth_token, prompt): + # TODO: Integrate with HFTransformerBuilder + from transformers import CLIPTextModel + + model = CLIPTextModel.from_pretrained( + hf_model_name, + subfolder="text_encoder", + token=hf_auth_token, + ) + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + text_input = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + example_input = text_input.input_ids + + results = model.forward(example_input)[0] + np_torch_output = results.detach().cpu().numpy() + return np_torch_output + + +if __name__ == "__main__": + args = parser.parse_args() + turbine_output = run_clip( + args.device, + args.prompt, + args.vmfb_path, + args.hf_model_name, + args.hf_auth_token, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_output[0].to_host(), + turbine_output[0].to_host().shape, + turbine_output[0].to_host().dtype, + ) + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + torch_output = run_torch_clip( + args.hf_model_name, args.hf_auth_token, args.prompt + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + err = 
utils.largest_error(torch_output, turbine_output[0]) + print("Largest Error: ", err) + assert err < 9e-5 + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output = None diff --git a/python/turbine_models/custom_models/sd_inference/unet.py b/python/turbine_models/custom_models/sd_inference/unet.py index 3372e3e05..4045b572e 100644 --- a/python/turbine_models/custom_models/sd_inference/unet.py +++ b/python/turbine_models/custom_models/sd_inference/unet.py @@ -36,10 +36,8 @@ "--height", type=int, default=512, help="Height of Stable Diffusion" ) parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") -parser.add_argument("--run_vmfb", action="store_true") parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") parser.add_argument("--external_weight_file", type=str, default="") -parser.add_argument("--vmfb_path", type=str, default="") parser.add_argument( "--external_weights", type=str, @@ -133,100 +131,27 @@ def main( utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) -def run_unet_vmfb_comparison(unet_model, args): - config = ireert.Config(args.device) - - if args.external_weight_file: - index = ireert.ParameterIndex() - index.load(args.external_weight_file) - - safe_name = utils.create_safe_name(args.hf_model_name, "-unet") - if args.vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - sys.exit("no vmfb_path provided, required for run_vmfb") - - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device), - ] - if args.external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) - - ctx = ireert.SystemContext( - vm_modules=vm_modules, - config=config, - ) - sample = torch.rand( - args.batch_size, - unet_model.unet.in_channels, - args.height // 8, - args.width // 8, - dtype=torch.float32, - ) - timestep = torch.zeros(1, dtype=torch.float32) - if args.hf_model_name == "CompVis/stable-diffusion-v1-4": - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) - elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) - - device_inputs = [ - ireert.asdevicearray(config.device, sample), - ireert.asdevicearray(config.device, timestep), - ireert.asdevicearray(config.device, encoder_hidden_states), - ] - - # Turbine output - ModuleCompiled = ctx.modules.compiled_unet - turbine_output = ModuleCompiled["main"](*device_inputs) - print( - "TURBINE OUTPUT:", - turbine_output.to_host(), - turbine_output.to_host().shape, - turbine_output.to_host().dtype, - ) - - # Torch output - torch_output = unet_model.forward(sample, timestep, encoder_hidden_states) - np_torch_output = torch_output.detach().cpu().numpy() - print( - "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype - ) - - err = utils.largest_error(np_torch_output, turbine_output) - print("LARGEST ERROR:", err) - assert err < 9e-5 - - if __name__ == "__main__": args = parser.parse_args() unet_model = UnetModel( args.hf_model_name, args.hf_auth_token, ) - if args.run_vmfb: - run_unet_vmfb_comparison(unet_model, args) - else: - mod_str = export_unet_model( - unet_model, - args.hf_model_name, - args.batch_size, - args.height, - 
args.width, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_file, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - ) - safe_name = utils.create_safe_name(args.hf_model_name, "-unet") - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") + mod_str = export_unet_model( + unet_model, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = utils.create_safe_name(args.hf_model_name, "-unet") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/unet_runner.py b/python/turbine_models/custom_models/sd_inference/unet_runner.py new file mode 100644 index 000000000..2f73493a2 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/unet_runner.py @@ -0,0 +1,151 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from iree import runtime as ireert +import torch + +parser = argparse.ArgumentParser() + +# TODO move common runner flags to generic flag file +parser.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) +parser.add_argument( + "--external_weight_path", + type=str, + default="", + help="path to external weight parameters if model compiled without them", +) +parser.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging face auth token, required for some models", +) +parser.add_argument( + "--device", + type=str, + default="local-task", + help="local-sync, local-task, cuda, vulkan, rocm", +) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") + + +def run_unet( + device, + sample, + timestep, + encoder_hidden_states, + vmfb_path, + hf_model_name, + hf_auth_token, + external_weight_path, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + inputs = [ + ireert.asdevicearray(runner.config.device, sample), + ireert.asdevicearray(runner.config.device, timestep), + ireert.asdevicearray(runner.config.device, encoder_hidden_states), + ] + results = runner.ctx.modules.compiled_unet["main"](*inputs) + return results + + +def run_torch_unet( + hf_model_name, hf_auth_token, sample, timestep, encoder_hidden_states +): + from diffusers import UNet2DConditionModel + + class UnetModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + token=hf_auth_token, + ) + self.guidance_scale = 7.5 + + def forward(self, sample, timestep, encoder_hidden_states): + samples = torch.cat([sample] * 2) + unet_out = self.unet.forward( + samples, timestep, encoder_hidden_states, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = 
unet_out.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + unet_model = UnetModel( + hf_model_name, + hf_auth_token, + ) + results = unet_model.forward(sample, timestep, encoder_hidden_states) + np_torch_output = results.detach().cpu().numpy() + return np_torch_output + + +if __name__ == "__main__": + args = parser.parse_args() + sample = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 + ) + timestep = torch.zeros(1, dtype=torch.float32) + if args.hf_model_name == "CompVis/stable-diffusion-v1-4": + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + + turbine_output = run_unet( + args.device, + sample, + timestep, + encoder_hidden_states, + args.vmfb_path, + args.hf_model_name, + args.hf_auth_token, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + torch_output = run_torch_unet( + args.hf_model_name, + args.hf_auth_token, + sample, + timestep, + encoder_hidden_states, + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + err = utils.largest_error(torch_output, turbine_output) + print("Largest Error: ", err) + assert err < 9e-5 + + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output = None diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py index b86d88ca5..7bc9247b9 100644 --- a/python/turbine_models/custom_models/sd_inference/vae.py +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -36,10 +36,8 @@ "--height", type=int, default=512, help="Height of Stable Diffusion" ) parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") -parser.add_argument("--run_vmfb", action="store_true") parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") parser.add_argument("--external_weight_file", type=str, default="") -parser.add_argument("--vmfb_path", type=str, default="") parser.add_argument( "--external_weights", type=str, @@ -110,88 +108,27 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) -def run_vae_vmfb_comparison(vae_model, args): - config = ireert.Config(args.device) - - if args.external_weight_file: - index = ireert.ParameterIndex() - index.load(args.external_weight_file) - - safe_name = utils.create_safe_name(args.hf_model_name, "-vae") - if args.vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - sys.exit("no vmfb_path provided, required for run_vmfb") - - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device), - ] - if args.external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) - - ctx = ireert.SystemContext( - vm_modules=vm_modules, - config=config, - ) - inp = 
torch.rand( - args.batch_size, - 4, - args.height // 8, - args.width // 8, - dtype=torch.float32, - ) - device_inputs = [ireert.asdevicearray(config.device, inp)] - - # Turbine output - ModuleCompiled = ctx.modules.compiled_vae - turbine_output = ModuleCompiled["main"](*device_inputs) - print( - "TURBINE OUTPUT:", - turbine_output.to_host(), - turbine_output.to_host().shape, - turbine_output.to_host().dtype, - ) - - # Torch output - torch_output = vae_model.forward(inp) - torch_output = torch_output.detach().cpu().numpy() - print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) - - err = utils.largest_error(torch_output, turbine_output) - print("LARGEST ERROR:", err) - assert err < 9e-5 - - if __name__ == "__main__": args = parser.parse_args() vae_model = VaeModel( args.hf_model_name, args.hf_auth_token, ) - if args.run_vmfb: - run_vae_vmfb_comparison(vae_model, args) - else: - mod_str = export_vae_model( - vae_model, - args.hf_model_name, - args.batch_size, - args.height, - args.width, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_file, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - ) - safe_name = utils.create_safe_name(args.hf_model_name, "-vae") - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") + mod_str = export_vae_model( + vae_model, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = utils.create_safe_name(args.hf_model_name, "-vae") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/vae_runner.py b/python/turbine_models/custom_models/sd_inference/vae_runner.py new file mode 100644 index 000000000..e058cfa8b --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/vae_runner.py @@ -0,0 +1,120 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from iree import runtime as ireert +import torch + +parser = argparse.ArgumentParser() + +# TODO move common runner flags to generic flag file +parser.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) +parser.add_argument( + "--external_weight_path", + type=str, + default="", + help="path to external weight parameters if model compiled without them", +) +parser.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging face auth token, required for some models", +) +parser.add_argument( + "--device", + type=str, + default="local-task", + help="local-sync, local-task, cuda, vulkan, rocm", +) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") + + +def run_vae( + device, example_input, vmfb_path, hf_model_name, hf_auth_token, external_weight_path +): + runner = 
vmfbRunner(device, vmfb_path, external_weight_path) + + inputs = [ireert.asdevicearray(runner.config.device, example_input)] + results = runner.ctx.modules.compiled_vae["main"](*inputs) + return results + + +def run_torch_vae(hf_model_name, hf_auth_token, example_input): + from diffusers import AutoencoderKL + + class VaeModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + token=hf_auth_token, + ) + + def forward(self, inp): + with torch.no_grad(): + x = self.vae.decode(inp, return_dict=False)[0] + return x + + vae_model = VaeModel( + hf_model_name, + hf_auth_token, + ) + + results = vae_model.forward(example_input) + np_torch_output = results.detach().cpu().numpy() + return np_torch_output + + +if __name__ == "__main__": + args = parser.parse_args() + example_input = torch.rand( + batch_size, 4, height // 8, width // 8, dtype=torch.float32 + ) + print("generating turbine output:") + turbine_results = run_vae( + args.device, + example_input, + args.vmfb_path, + args.hf_model_name, + args.hf_auth_token, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_results.to_host(), + turbine_results.to_host().shape, + turbine_results.to_host().dtype, + ) + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + torch_output = run_torch_vae( + args.hf_model_name, args.hf_auth_token, example_input + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + err = utils.largest_error(torch_output, turbine_results) + print("Largest Error: ", err) + assert err < 9e-5 + + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_results = None diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 0f3810b0c..fcfb983f5 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -8,7 +8,6 @@ from torch.utils import _pytree as pytree from shark_turbine.aot import * from iree.compiler.ir import Context -from iree import runtime as ireert from turbine_models.custom_models import remap_gguf import safetensors @@ -23,11 +22,6 @@ "--hf_auth_token", type=str, help="The Hugging Face auth token, required" ) parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") -parser.add_argument( - "--test", - action="store_true", - help="run stateless tests instead of exporting", -) parser.add_argument( "--hf_model_name", type=str, @@ -59,10 +53,6 @@ ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") -prompt = """[INST] <> -Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? 
[/INST] -""" - # TODO (Dan): replace this with a file once I figure out paths on windows exe json_schema = """ [1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": 
[{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}] @@ -282,11 +272,8 @@ def forward(token0: torch.Tensor, *state0_flat): return module_str, tokenizer -# if you're looking for run_vmfb_comparison, it's now in python/turbine_models/tests/vmfb_comparison.py - if __name__ == "__main__": args = parser.parse_args() - mod_str, _ = export_transformer_model( args.hf_model_name, args.hf_auth_token, diff --git a/python/turbine_models/model_builder.py b/python/turbine_models/model_builder.py index 5376287dc..22139ca64 100644 --- a/python/turbine_models/model_builder.py +++ b/python/turbine_models/model_builder.py @@ -40,11 +40,11 @@ def build_model(self) -> None: """ # TODO: check cloud storage for existing ir self.model = self.auto_model.from_pretrained( - self.hf_id, use_auth_token=self.hf_auth_token, config=self.auto_config + self.hf_id, token=self.hf_auth_token, config=self.auto_config ) if self.auto_tokenizer is not None: self.tokenizer = self.auto_tokenizer.from_pretrained( - self.hf_id, use_auth_token=self.hf_auth_token + self.hf_id, token=self.hf_auth_token ) else: self.tokenizer = None diff --git a/python/turbine_models/model_runner.py b/python/turbine_models/model_runner.py new file mode 100644 index 000000000..74dd3dc9a --- /dev/null +++ b/python/turbine_models/model_runner.py @@ -0,0 +1,30 @@ +import argparse +import sys +from iree import runtime as ireert + + +class vmfbRunner: + 
def __init__(self, device, vmfb_path, external_weight_path=None): + self.config = ireert.Config(device) + + # TODO: enable multiple vmfb's + mod = ireert.VmModule.mmap(self.config.vm_instance, vmfb_path) + vm_modules = [ + mod, + ireert.create_hal_module(self.config.vm_instance, self.config.device), + ] + + # TODO: Enable multiple weight files + if external_weight_path: + index = ireert.ParameterIndex() + index.load(external_weight_path) + # TODO: extend scope + param_module = ireert.create_io_parameters_module( + self.config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + self.ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=self.config, + ) diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index e01027fc5..b8dca64f5 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -6,7 +6,17 @@ import argparse import logging -from turbine_models.custom_models.sd_inference import clip, unet, vae +from turbine_models.custom_models.sd_inference import ( + clip, + clip_runner, + unet, + unet_runner, + vae, + vae_runner, +) +from transformers import CLIPTextModel +from turbine_models.custom_models.sd_inference import utils +import torch import unittest import os @@ -19,12 +29,14 @@ "width": 512, "run_vmfb": True, "compile_to": None, - "external_weight_file": "", + "external_weight_path": "", "vmfb_path": "", "external_weights": None, "device": "local-task", "iree_target_triple": "", "vulkan_max_allocation": "4294967296", + "prompt": "a photograph of an astronaut riding a horse", + "in_channels": 4, } @@ -54,9 +66,21 @@ def testExportClipModel(self): "cpu", ) self.assertEqual(cm.exception.code, None) - arguments["external_weight_file"] = "stable_diffusion_v1_4_clip.safetensors" - namespace = argparse.Namespace(**arguments) - clip.run_clip_vmfb_comparison(namespace) + arguments["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb" + turbine = clip_runner.run_clip( + arguments["device"], + arguments["prompt"], + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = clip_runner.run_torch_clip( + arguments["hf_model_name"], arguments["hf_auth_token"], arguments["prompt"] + ) + err = utils.largest_error(torch_output, turbine[0]) + assert err < 9e-5 os.remove("stable_diffusion_v1_4_clip.safetensors") os.remove("stable_diffusion_v1_4_clip.vmfb") @@ -76,9 +100,37 @@ def testExportUnetModel(self): "cpu", ) self.assertEqual(cm.exception.code, None) - arguments["external_weight_file"] = "stable_diffusion_v1_4_unet.safetensors" - namespace = argparse.Namespace(**arguments) - unet.run_unet_vmfb_comparison(unet_model, namespace) + arguments["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb" + sample = torch.rand( + arguments["batch_size"], + arguments["in_channels"], + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + timestep = torch.zeros(1, dtype=torch.float32) + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + + turbine = unet_runner.run_unet( + arguments["device"], + sample, + timestep, + encoder_hidden_states, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = unet_runner.run_torch_unet( + 
arguments["hf_model_name"], + arguments["hf_auth_token"], + sample, + timestep, + encoder_hidden_states, + ) + err = utils.largest_error(torch_output, turbine) + assert err < 9e-5 os.remove("stable_diffusion_v1_4_unet.safetensors") os.remove("stable_diffusion_v1_4_unet.vmfb") @@ -98,9 +150,28 @@ def testExportVaeModel(self): "cpu", ) self.assertEqual(cm.exception.code, None) - arguments["external_weight_file"] = "stable_diffusion_v1_4_vae.safetensors" - namespace = argparse.Namespace(**arguments) - vae.run_vae_vmfb_comparison(vae_model, namespace) + arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" + example_input = torch.rand( + arguments["batch_size"], + 4, + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], arguments["hf_auth_token"], example_input + ) + err = utils.largest_error(torch_output, turbine) + assert err < 9e-5 os.remove("stable_diffusion_v1_4_vae.safetensors") os.remove("stable_diffusion_v1_4_vae.vmfb") diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index c99cb7c23..081c31b71 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -30,7 +30,6 @@ import safetensors from tqdm import tqdm -from .vmfb_comparison import get_turbine_vmfb_string def test_vmfb_comparison(): @@ -77,31 +76,25 @@ def test_vmfb_comparison(): with open(torch_str_cache_path, "r") as f: torch_str = f.read() else: - from .vmfb_comparison import get_torch_string - - torch_str = get_torch_string( - prompt=DEFAULT_PROMPT, - hf_auth_token=None, - hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - tokens_to_compare=50, - precision=precision, - quantization=quantization, + from turbine_models.custom_models import llm_runner + + torch_str = llm_runner.run_torch_llm( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", None, DEFAULT_PROMPT ) with open(torch_str_cache_path, "w") as f: f.write(torch_str) - turbine_str = get_turbine_vmfb_string( - prompt=DEFAULT_PROMPT, - hf_auth_token=None, - hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - vmfb_path="Llama_2_7b_chat_hf_function_calling_v2.vmfb", - external_weight_file=f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", - tokens_to_compare=50, - device="llvm-cpu", - ) + from turbine_models.custom_models import llm_runner - torch_str = torch_str[: len(turbine_str)] + turbine_str = llm_runner.run_llm( + "local-task", + DEFAULT_PROMPT, + "Llama_2_7b_chat_hf_function_calling_v2.vmfb", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + None, + f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", + ) import difflib @@ -113,5 +106,4 @@ def test_vmfb_comparison(): tofile="turbine_str", lineterm="", ) - assert torch_str == turbine_str, "".join(diff) diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py deleted file mode 100644 index 112bb89f5..000000000 --- a/python/turbine_models/tests/vmfb_comparison.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import sys -import re - -from typing import Tuple - 
-os.environ["TORCH_LOGS"] = "dynamic" -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -from torch.utils import _pytree as pytree -from shark_turbine.aot import * -from iree.compiler.ir import Context -from iree import runtime as ireert - -from turbine_models.custom_models import remap_gguf -import safetensors - -from tqdm import tqdm -from typing import Literal - - -def torch_token_generator( - prompt, - hf_model_name: str, - hf_auth_token: str, - break_on_eos=False, - precision="f32", - quantization="unquantized", -): - if precision == "f16": - torch_dtype = torch.float16 - elif precision == "f32": - torch_dtype = torch.float32 - else: - raise ValueError("Invalid dtype, f16 or f32 supported") - - if ( - quantization is not None - and quantization.lower() != "none" - and quantization.lower() != "unquantized" - ): - raise NotImplementedError("Quantization not supported for torch") - - tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, - use_fast=False, - use_auth_token=hf_auth_token, - ) - model = AutoModelForCausalLM.from_pretrained( - hf_model_name, torch_dtype=torch_dtype, use_auth_token=hf_auth_token - ) - - initial_input = tokenizer(prompt, return_tensors="pt") - input_ids = initial_input.input_ids - past_key_values = None - - while True: - model_results = model.forward(input_ids, past_key_values=past_key_values) - logits = model_results.logits - next_token_id = torch.argmax(logits[:, -1, :], dim=1) - past_key_values = model_results.past_key_values - - yield next_token_id - input_ids = next_token_id.unsqueeze(0) # Prepare for the next iteration - - if next_token_id.item() == tokenizer.eos_token_id and break_on_eos: - break - - -def turbine_token_generator( - prompt: str, - hf_model_name: str, - vmfb_path: str = None, - external_weight_file: str = None, - hf_auth_token: str = None, - break_on_eos: bool = False, - device: Literal["llvm-cpu", "cuda", "vulcan", "rocm"] = "llvm-cpu", -) -> torch.Tensor: - """ - A generator function for turbine model inference. - - :param prompt: The input prompt for the model. - :param hf_model_name: The name of the Hugging Face model. - :param vmfb_path: Path to the .vmfb model file. - :param external_weight_file: Path to the external weight file (optional). - :param hf_auth_token: Hugging Face authorization token (optional). - :param break_on_eos: Whether to break the loop on end-of-sentence token. - :return: Yields a tensor representing the generated token. 
- """ - - # Create the config for the IREE runtime environment - config = ireert.Config("local-task" if device == "llvm-cpu" else device) - - # Load the external weight file if provided - if external_weight_file: - index = ireert.ParameterIndex() - index.load(external_weight_file) - - # Ensure model name is in a safe format - safe_name = hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - - # Load the .vmfb model file - if vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") - - # Prepare the modules for the IREE runtime context - vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)] - - # Include parameter module if external weight file is used - if external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) - - # Create the system context with the given configuration and modules - ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) - - # Initialize the tokenizer - tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, use_fast=False, use_auth_token=hf_auth_token - ) - - # Convert the prompt to input tensor - initial_input = tokenizer(prompt, return_tensors="pt") - example_input_id = initial_input.input_ids - device_inputs = [ireert.asdevicearray(config.device, example_input_id)] - - # Get the compiled module - ModuleCompiled = ctx.modules.state_update - results = ModuleCompiled["run_initialize"](*device_inputs) - - def format_out(results): - # Convert the output to a PyTorch tensor - return torch.tensor(results.to_host()[0][0]) - - # Token generation loop - while True: - next_token_tensor = format_out(results) - yield next_token_tensor.item() # Yield the scalar value of the tensor - - # Run the next step of the model - results = ModuleCompiled["run_forward"](results) - - # Check for the end-of-sentence token - if next_token_tensor.item() == tokenizer.eos_token_id and break_on_eos: - break - - -def get_torch_string( - prompt, - hf_auth_token, - hf_model_name, - precision, - quantization, - tokens_to_compare=50, -): - print("Using prompt:") - print(prompt) - print("To generate torch reference string...") - torch_gen = torch_token_generator( - prompt=prompt, - hf_auth_token=hf_auth_token, - hf_model_name=hf_model_name, - break_on_eos=True, - precision=precision, - quantization=quantization, - ) - tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, use_fast=False, use_auth_token=hf_auth_token - ) - - print( - "Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while." 
- ) - # read until stopiteration - torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) - torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) - - return torch_str - - -def get_turbine_vmfb_string( - prompt, - hf_auth_token, - hf_model_name, - vmfb_path, - external_weight_file, - device, - tokens_to_compare=50, -): - # Initialize generators with the prompt and specific arguments - # check if torch string cache exists - # cache is at python/turbine_models/tests/vmfb_comparison_cached_torch_output.txt - - # Decode and print the outputs - tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, use_fast=False, use_auth_token=hf_auth_token - ) - - # Run turbine until an equal number of tokens has been generated - print( - "Generating Turbine tokens... The pipeline needs to be initialized first so the first few tokens may take a while." - ) - turbine_gen = turbine_token_generator( - prompt=prompt, - hf_model_name=hf_model_name, - vmfb_path=vmfb_path, - external_weight_file=external_weight_file, - hf_auth_token=hf_auth_token, - break_on_eos=False, - device=device, - ) - turbine_tokens = [] - for _ in tqdm(range(tokens_to_compare), desc="Generating Turbine tokens"): - token = next(turbine_gen) - turbine_tokens.append(token) - del turbine_gen - - turbine_str = tokenizer.decode(torch.tensor(turbine_tokens).numpy()) - return turbine_str
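
For reference, a minimal sketch of how the new shared vmfbRunner from model_runner.py is consumed by the runner scripts above, here driving a compiled CLIP module; the .vmfb and .safetensors paths are placeholders for whatever clip.py produced:

from turbine_models.model_runner import vmfbRunner
from iree import runtime as ireert
import torch

# Construct the runner once; it mmaps the vmfb, optionally registers external
# weights as an io_parameters module, and builds the SystemContext.
runner = vmfbRunner(
    device="local-task",                                    # local-sync, cuda, vulkan, rocm also accepted
    vmfb_path="stable_diffusion_v1_4_clip.vmfb",            # placeholder path
    external_weight_path="stable_diffusion_v1_4_clip.safetensors",  # omit if weights were inlined
)

# Inputs are staged onto the runner's device before calling the module entry point.
input_ids = torch.zeros(1, 77, dtype=torch.int64)           # matches AbstractTensor(1, 77) in clip.py
device_inputs = [ireert.asdevicearray(runner.config.device, input_ids)]
results = runner.ctx.modules.compiled_clip["main"](*device_inputs)
print(results[0].to_host().shape)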
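
The runner scripts and the updated sd_test.py share the same validation pattern: run the compiled module, run the eager Hugging Face model, and require utils.largest_error to stay below 9e-5. A condensed sketch for the VAE, with shapes following the defaults in vae_runner.py and placeholder artifact paths:

import torch
from turbine_models.custom_models.sd_inference import utils, vae_runner

# 512x512 defaults -> latent input of shape (1, 4, 64, 64).
example_input = torch.rand(1, 4, 512 // 8, 512 // 8, dtype=torch.float32)

turbine_out = vae_runner.run_vae(
    "local-task",
    example_input,
    "stable_diffusion_v1_4_vae.vmfb",            # placeholder path
    "CompVis/stable-diffusion-v1-4",
    None,                                        # hf_auth_token
    "stable_diffusion_v1_4_vae.safetensors",     # placeholder path
)
torch_out = vae_runner.run_torch_vae(
    "CompVis/stable-diffusion-v1-4", None, example_input
)

err = utils.largest_error(torch_out, turbine_out)
print("Largest Error: ", err)
assert err < 9e-5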
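
llm_runner.py drives the state_update module exported by stateless_llama.py through two entry points: run_initialize consumes the prompt tokens and run_forward advances generation one token at a time, carrying the cache state inside the module. A condensed sketch of that loop (paths are placeholders; gated models additionally need token=hf_auth_token when loading the tokenizer; 2 is the Llama EOS id):

from turbine_models.model_runner import vmfbRunner
from transformers import AutoTokenizer
from iree import runtime as ireert
import torch

runner = vmfbRunner(
    "local-task",
    "Llama_2_7b_chat_hf.vmfb",                   # placeholder path
    "Llama_2_7b_chat_hf.safetensors",            # placeholder path
)
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", use_fast=False
)

input_ids = tokenizer("hi what are you?", return_tensors="pt").input_ids
results = runner.ctx.modules.state_update["run_initialize"](
    ireert.asdevicearray(runner.config.device, input_ids)
)

tokens = []
while True:
    token = torch.tensor(results.to_host()[0][0])    # newest token id
    tokens.append(token)
    if token == 2:                                   # Llama EOS, stop decoding
        break
    results = runner.ctx.modules.state_update["run_forward"](results)

print(tokenizer.decode(tokens))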