diff --git a/src/generate_lib/ovis.py b/src/generate_lib/ovis.py
new file mode 100644
index 0000000..3b5ef82
--- /dev/null
+++ b/src/generate_lib/ovis.py
@@ -0,0 +1,44 @@
+# Adapted from https://huggingface.co/AIDC-AI/Ovis1.5-Llama3-8B
+# Supports the Ovis 1.5 model series (Llama3-8B and Gemma2-9B variants)
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM
+from tqdm import tqdm
+
+def generate_response(model_path, queries):
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 torch_dtype=torch.bfloat16,
+                                                 multimodal_max_length=8192,
+                                                 trust_remote_code=True).cuda()
+    text_tokenizer = model.get_text_tokenizer()
+    visual_tokenizer = model.get_visual_tokenizer()
+    conversation_formatter = model.get_conversation_formatter()
+
+    for k in tqdm(queries):
+        query = queries[k]['question']
+        image = queries[k]['figure_path']
+        image = Image.open(image).convert('RGB')
+        query = f'<image>\n{query}'
+        prompt, input_ids = conversation_formatter.format_query(query)
+        input_ids = torch.unsqueeze(input_ids, dim=0).to(device=model.device)
+        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(device=model.device)
+        pixel_values = [visual_tokenizer.preprocess_image(image).to(
+            dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
+
+        with torch.inference_mode():
+            gen_kwargs = dict(
+                max_new_tokens=1024,
+                do_sample=False,
+                top_p=None,
+                top_k=None,
+                temperature=None,
+                repetition_penalty=None,
+                eos_token_id=model.generation_config.eos_token_id,
+                pad_token_id=text_tokenizer.pad_token_id,
+                use_cache=True
+            )
+            output_ids = model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0]
+            response = text_tokenizer.decode(output_ids, skip_special_tokens=True)
+
+        queries[k]['response'] = response
diff --git a/src/generate_lib/utils.py b/src/generate_lib/utils.py
index baaa398..db4107a 100644
--- a/src/generate_lib/utils.py
+++ b/src/generate_lib/utils.py
@@ -131,6 +131,10 @@ def get_generate_fn(model_path):
     # vila
     elif model_name in ['VILA1.5-40b']:
         from .vila15 import generate_response
+    # ovis
+    elif model_name in ['Ovis1.5-Llama3-8B',
+                        'Ovis1.5-Gemma2-9B']:
+        from .ovis import generate_response
     else:
         raise ValueError(f"Model {model_name} not supported")
    return generate_response
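
Below is a minimal usage sketch (not part of the patch) showing the queries dict shape that generate_response reads and mutates in place, inferred from the keys used in ovis.py ('question', 'figure_path') and the 'response' key it writes back. The question text, figure path, and the assumption that src/ is on PYTHONPATH are illustrative, not taken from the repo.

    # Hypothetical usage, assuming src/ is on PYTHONPATH and a CUDA GPU is available.
    from generate_lib.ovis import generate_response

    queries = {
        'q0': {
            'question': 'What is the highest value shown on the y-axis?',  # example question
            'figure_path': 'figures/sample_chart.png',                     # hypothetical path
        },
    }

    # Runs Ovis inference over every entry; each entry gains a 'response' key.
    generate_response('AIDC-AI/Ovis1.5-Llama3-8B', queries)
    print(queries['q0']['response'])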