Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPT-4V as evaluator #276

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ process:
radius: 2 # radius of blur kernel
- image_captioning_from_gpt4v_mapper: # generate samples whose texts are generated based on gpt-4-visison and the image
mode: 'description' # mode of text generated from images, can be one of ['resoning', 'description', 'conversation', 'custom']
api_key: '' # the API key to authenticate the request
max_token: 500 # the maximum number of tokens to generate. Default is 500.
api_key: null # the API key to authenticate the request
max_tokens: 500 # the maximum number of tokens to generate. Default is 500.
temperature: 1.0 # controls the randomness of the output (range from 0 to 1). Default is 0.
system_prompt: '' # a string prompt used to set the context of a conversation and provide global guidance or rules for the gpt4-vision so that it can generate responses in the expected way. If `mode` set to `custom`, the parameter will be used
user_prompt: '' # a string prompt to guide the generation of gpt4-vision for each samples. It's "" in default, which means no prompt provided
Expand Down
89 changes: 12 additions & 77 deletions data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy

import requests
from jsonargparse.typing import ClosedUnitInterval
from loguru import logger

Expand All @@ -9,6 +8,7 @@
load_image_byte,
remove_non_special_tokens,
remove_special_tokens)
from data_juicer.utils.model_utils import call_gpt_vision_api

from ..base_op import OPERATORS, Mapper
from ..op_fusion import LOADED_IMAGES
Expand All @@ -23,75 +23,6 @@
}


def call_gpt_vision_api(api_key,
system_prompt,
user_prompt,
base64_image,
max_tokens=500,
temperature=1.0,
model='gpt-4-vision-preview'):
api_url = 'https://api.openai.com/v1/chat/completions'
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model':
model,
'messages': [{
'role': 'system',
'content': system_prompt
}, {
'role':
'user',
'content': [{
'type': 'text',
'text': user_prompt
}, {
'type': 'image_url',
'image_url': {
'url': f'data:image/jpeg;base64,{base64_image}',
'detail': 'low'
}
}]
}],
'max_tokens':
max_tokens,
'temperature':
temperature
}
try:
response = requests.post(api_url, headers=headers, json=data)
response.raise_for_status()
result = response.json()

if 'choices' in result and result['choices']:
return result['choices'][0]['text']
else:
logger.warning('No results returned from the API, return None.')
return None

except requests.exceptions.HTTPError as errh:
if errh.response.status_code == 401:
logger.warning('Invalid API key provided.')
elif errh.response.status_code == 429:
logger.warning(
'API request limit has been reached. Please try again later.')
else:
logger.warning(f'HTTP error occurred: {errh}')
except requests.exceptions.ConnectionError:
logger.warning('Network error occurred. Please check your connection.')
except requests.exceptions.Timeout:
logger.warning('The request timed out. Please try again later.')
except requests.exceptions.RequestException as err:
logger.warningt(f'An error occurred: {err}')
except Exception as e:
logger.warning(f'An unexpected error occurred: {e}')

logger.warning('API request failed, return None.')
return None


@OPERATORS.register_module('image_captioning_from_gpt4v_mapper')
@LOADED_IMAGES.register_module('image_captioning_from_gpt4v_mapper')
class ImageCaptioningFromGPT4VMapper(Mapper):
Expand All @@ -100,8 +31,8 @@ class ImageCaptioningFromGPT4VMapper(Mapper):

def __init__(self,
mode: str = 'description',
api_key: str = '',
max_token: int = 500,
api_key: str = None,
max_tokens: int = 500,
temperature: ClosedUnitInterval = 1.0,
system_prompt: str = '',
user_prompt: str = '',
Expand All @@ -116,7 +47,7 @@ def __init__(self,
:param mode: mode of text generated from images, can be one of
['resoning', 'description', 'conversation', 'custom']
:param api_key: the API key to authenticate the request.
:param max_token: the maximum number of tokens to generate.
:param max_tokens: the maximum number of tokens to generate.
Default is 500.
:param temperature: controls the randomness of the output (range
from 0 to 1). Default is 0.
Expand Down Expand Up @@ -162,7 +93,7 @@ def __init__(self,

self.mode = mode
self.api_key = api_key
self.max_token = max_token
self.max_tokens = max_tokens
self.temperature = temperature
self.user_prompt = user_prompt
self.user_prompt_key = user_prompt_key
Expand Down Expand Up @@ -219,10 +150,14 @@ def _process_single_sample(self, sample):
generated_text_single_chunk = []
for image_key in loaded_image_keys[offset:offset + img_count]:
image = images[image_key]
res = call_gpt_vision_api(self.api_key, self.system_prompt,
res = call_gpt_vision_api(self.system_prompt,
prompt_texts,
image_byte_to_base64(image),
self.max_token, self.temperature)
image_byte_to_base64(
image, 'image/jpeg'),
api_key=self.api_key,
max_tokens=self.max_tokens,
temperature=self.temperature,
**self.extra_args)
generated_text_single_chunk.append(res)
if self.any_or_all == 'all' and not all(
generated_text_single_chunk):
Expand Down
48 changes: 44 additions & 4 deletions data_juicer/utils/mm_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import base64
import datetime
import imghdr
import io
import mimetypes
import os
import re
import shutil
Expand Down Expand Up @@ -109,13 +112,27 @@ def load_image_byte(path):
return image_data


def image_path_to_base64(image_path):
def image_path_to_base64(image_path, mime_type=None):
with open(image_path, 'rb') as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
image_byte = image_file.read()
return image_byte_to_base64(image_byte, mime_type)


def image_byte_to_base64(image_byte):
return base64.b64encode(image_byte).decode('utf-8')
def image_byte_to_base64(image_byte, mime_type=None):
image_base64 = base64.b64encode(image_byte).decode('utf-8')

if mime_type is None:
return image_base64

if mime_type == 'auto':
# guess correct mime_type
image_type = imghdr.what(None, h=image_byte)
mime_type = mimetypes.types_map.get(f'.{image_type}')
if mime_type is None:
mime_type = 'image/jpeg'

# https://www.rfc-editor.org/rfc/rfc2397
return f'data:{mime_type};base64,{image_base64}'


def pil_to_opencv(pil_image):
Expand Down Expand Up @@ -175,6 +192,29 @@ def load_video(path):
return container


def video_path_to_base64(video_path, frame_num=None, fps=None, mime_type=None):
duration = get_video_duration(video_path)

if fps is not None:
valid_frame_num = round(duration * fps)
if frame_num is not None:
valid_frame_num = min(valid_frame_num, frame_num)
else:
valid_frame_num = frame_num or 0

frames_base64 = []
if frame_num > 0:
frames = extract_video_frames_uniformly(video_path, valid_frame_num)
for frame in frames:
img = frame.to_image()
buf = io.BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
frame_base64 = image_byte_to_base64(buf.read(), mime_type)
frames_base64.append(frame_base64)
return frames_base64


def get_video_duration(input_video: Union[str, av.container.InputContainer],
video_stream_index=0):
"""
Expand Down
112 changes: 104 additions & 8 deletions data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import fnmatch
import os
from functools import partial
from typing import Optional, Union
from typing import Any, List, Optional, Union

import multiprocess as mp
import requests
import wget
from loguru import logger

Expand All @@ -14,8 +15,8 @@
MODEL_ZOO = {}

# Default cached models links for downloading
MODEL_LINKS = 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' \
'data_juicer/models/'
MODEL_LINKS = ('https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/'
'data_juicer/models/')

# Backup cached models links for downloading
BACKUP_MODEL_LINKS = {
Expand Down Expand Up @@ -174,9 +175,9 @@ def prepare_nltk_model(lang, name_pattern='punkt.{}.pickle'):
'pt': 'portuguese',
'es': 'spanish'
}
assert lang in nltk_to_punkt.keys(
), 'lang must be one of the following: {}'.format(
list(nltk_to_punkt.keys()))
assert (lang in nltk_to_punkt.keys()
), 'lang must be one of the following: {}'.format(
list(nltk_to_punkt.keys()))
model_name = name_pattern.format(nltk_to_punkt[lang])

logger.info('Loading nltk punkt split model...')
Expand Down Expand Up @@ -319,6 +320,7 @@ def __init__(self, config: Blip2Config) -> None:
self.post_init()

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if return_model:
Expand Down Expand Up @@ -418,10 +420,11 @@ def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.5.0'):
# decompress the compressed model if it's not decompressed
def decompress_model(compressed_model_path):
decompressed_model_path = compressed_model_path.replace('.zip', '')
if os.path.exists(decompressed_model_path) \
and os.path.isdir(decompressed_model_path):
if os.path.exists(decompressed_model_path) and os.path.isdir(
decompressed_model_path):
return decompressed_model_path
import zipfile

with zipfile.ZipFile(compressed_model_path) as zf:
zf.extractall(DJMC)
return decompressed_model_path
Expand Down Expand Up @@ -506,6 +509,7 @@ def prepare_recognizeAnything_model(
:param input_size: the input size of the model.
"""
from ram.models import ram_plus

logger.info('Loading recognizeAnything model...')
try:
model = ram_plus(pretrained=check_model(pretrained_model_name_or_path),
Expand Down Expand Up @@ -573,3 +577,95 @@ def get_model(model_key=None, rank=None):
rank = 0 if rank is None else rank
move_to_cuda(MODEL_ZOO[model_key], rank)
return MODEL_ZOO[model_key]


def call_gpt_vision_api(
system_prompt: str = '',
user_prompt: str = '',
images: Union[str, List[str], None] = None,
*,
api_key: str = None,
model: str = 'gpt-4-vision-preview',
max_tokens: int = 500,
temperature: float = 0.0,
**kwargs: Any,
):
images = [images] if isinstance(images, str) else (images or [])

api_url = 'https://api.openai.com/v1/chat/completions'

if api_key is None:
api_key = os.getenv('OPENAI_API_KEY')
if api_key is None:
logger.error(
'The api_key must be set either by passing it to the function '
'call or by setting the OPENAI_API_KEY environment variable')
return ''

headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}',
}
image_payload = [{
'type': 'image_url',
'image_url': {
'url': url,
'detail': 'low'
}
} for url in images]
data = {
'model':
model,
'messages': [
{
'role': 'system',
'content': system_prompt
},
{
'role':
'user',
'content': [
{
'type': 'text',
'text': user_prompt
},
*image_payload,
],
},
],
'max_tokens':
max_tokens,
'temperature':
temperature,
**kwargs,
}

try:
response = requests.post(api_url, headers=headers, json=data)
response.raise_for_status()
result = response.json()

if 'choices' in result and result['choices']:
return result['choices'][0]['message']['content']
else:
logger.warning('No results returned from the API.')
return ''
except requests.exceptions.HTTPError as errh:
if errh.response.status_code == 401:
logger.warning('Invalid API key provided.')
elif errh.response.status_code == 429:
logger.warning(
'API request limit has been reached. Please try again later.')
else:
logger.warning(f'HTTP error occurred: {errh}')
except requests.exceptions.ConnectionError:
logger.warning('Network error occurred. Please check your connection.')
except requests.exceptions.Timeout:
logger.warning('The request timed out. Please try again later.')
except requests.exceptions.RequestException as err:
logger.warningt(f'An error occurred: {err}')
except Exception as e:
logger.warning(f'An unexpected error occurred: {e}')

logger.warning('API request failed.')
return ''
Empty file added tools/mm_eval/__init__.py
Empty file.
Loading
Loading