
Commit

Merge pull request #2 from runninglsy/main
add support for Ovis1.5 models
zwcolin authored Aug 1, 2024
2 parents 4fc91a5 + 1bf46dd commit d9118f3
Showing 2 changed files with 48 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/generate_lib/ovis.py
@@ -0,0 +1,44 @@
# Adapted from https://huggingface.co/AIDC-AI/Ovis1.5-Llama3-8B
# Supports the Ovis1.5 model series.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM
from tqdm import tqdm

def generate_response(model_path, queries):
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 torch_dtype=torch.bfloat16,
                                                 multimodal_max_length=8192,
                                                 trust_remote_code=True).cuda()
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()
    conversation_formatter = model.get_conversation_formatter()

    for k in tqdm(queries):
        query = queries[k]['question']
        image = queries[k]['figure_path']
        image = Image.open(image).convert('RGB')
        # Prepend the <image> placeholder expected by the Ovis conversation format
        query = f'<image>\n{query}'
        prompt, input_ids = conversation_formatter.format_query(query)
        input_ids = torch.unsqueeze(input_ids, dim=0).to(device=model.device)
        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(device=model.device)
        pixel_values = [visual_tokenizer.preprocess_image(image).to(
            dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

        with torch.inference_mode():
            # Greedy decoding: sampling is disabled and its knobs explicitly set to None
            gen_kwargs = dict(
                max_new_tokens=1024,
                do_sample=False,
                top_p=None,
                top_k=None,
                temperature=None,
                repetition_penalty=None,
                eos_token_id=model.generation_config.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id,
                use_cache=True
            )
            output_ids = model.generate(input_ids, pixel_values=pixel_values,
                                        attention_mask=attention_mask, **gen_kwargs)[0]
            response = text_tokenizer.decode(output_ids, skip_special_tokens=True)

        queries[k]['response'] = response
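
For reference, a minimal usage sketch of this entry point, inferred from the loop above: queries is a dict keyed by example id, where each entry carries a 'question' and a 'figure_path', and responses are written back in place. The example id, question, and figure path below are illustrative, not from the repository, and the sketch assumes src/ is on the import path.

# Hypothetical driver; the key and paths are made up for illustration.
from generate_lib.ovis import generate_response

queries = {
    'example_0': {
        'question': 'What is the peak value on the y-axis?',
        'figure_path': 'figures/example_0.png',
    },
}
generate_response('AIDC-AI/Ovis1.5-Llama3-8B', queries)
print(queries['example_0']['response'])  # filled in by the generation loop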
4 changes: 4 additions & 0 deletions src/generate_lib/utils.py
@@ -131,6 +131,10 @@ def get_generate_fn(model_path):
    # vila
    elif model_name in ['VILA1.5-40b']:
        from .vila15 import generate_response
    # ovis
    elif model_name in ['Ovis1.5-Llama3-8B',
                        'Ovis1.5-Gemma2-9B']:
        from .ovis import generate_response
    else:
        raise ValueError(f"Model {model_name} not supported")
    return generate_response
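
A brief note on the dispatch this hunk extends: get_generate_fn matches on model_name, which is presumably derived from the final component of model_path, since the names above are bare model ids. A sketch of how a caller might use it, under that assumption:

# Illustrative only; assumes model_name is the basename of model_path.
from generate_lib.utils import get_generate_fn

model_path = 'AIDC-AI/Ovis1.5-Gemma2-9B'
generate_fn = get_generate_fn(model_path)  # would resolve to generate_lib.ovis.generate_response
generate_fn(model_path, queries)           # same queries dict as in the sketch above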
