Skip to content

Commit

Permalink
🐛 Bug: Fix the bug where PNG images cannot be recognized.
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Oct 2, 2024
1 parent c1e2f2d commit 4683124
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="modelmerge",
version="0.11.48",
version="0.11.49",
description="modelmerge is a multi-large language model API aggregator.",
long_description=Path.open(Path("README.md"), encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand Down
25 changes: 19 additions & 6 deletions src/ModelMerge/utils/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,19 @@ def cut_message(message: str, max_tokens: int, model_name: str):
encode_text = encoding.encode(message)
return message, len(encode_text)

import imghdr
def encode_image(image_path):
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is chosen from the file's leading signature bytes, not
    from its name, so the returned prefix reflects the real image format.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``"data:<mime>;base64,<payload>"`` where
        ``<mime>`` is ``image/png`` or ``image/jpeg``.

    Raises:
        ValueError: If the content is neither PNG nor JPEG.
        OSError: If the file cannot be opened or read.
    """
    # Signatures are checked by hand instead of through ``imghdr``: that
    # module is deprecated (removed in Python 3.13), and on older versions it
    # only recognises JPEGs carrying a JFIF/Exif header. Every JPEG stream
    # starts with the SOI marker FF D8 followed by another FF marker byte.
    signatures = (
        (b"\x89PNG\r\n\x1a\n", "image/png"),
        (b"\xff\xd8\xff", "image/jpeg"),
    )

    with open(image_path, "rb") as image_file:
        file_content = image_file.read()

    for magic, mime_type in signatures:
        if file_content.startswith(magic):
            base64_encoded = base64.b64encode(file_content).decode("utf-8")
            return f"data:{mime_type};base64,{base64_encoded}"

    # Report the leading bytes so the caller can see what was actually read.
    raise ValueError(f"不支持的图片格式: {file_content[:8]!r}")

def get_doc_from_url(url):
filename = urllib.parse.unquote(url.split("/")[-1])
Expand All @@ -42,13 +52,16 @@ def get_encode_image(image_url):
filename = get_doc_from_url(image_url)
image_path = os.getcwd() + "/" + filename
base64_image = encode_image(image_path)
prompt = f"data:image/jpeg;base64,{base64_image}"
os.remove(image_path)
return prompt
return base64_image

def get_image_message(image_url, message, engine = None):
if image_url:
base64_image = get_encode_image(image_url)
colon_index = base64_image.index(":")
semicolon_index = base64_image.index(";")
image_type = base64_image[colon_index + 1:semicolon_index]

if "gpt-4" in engine \
or (os.environ.get('claude_api_key', None) is None and "claude-3" in engine) \
or (os.environ.get('GOOGLE_AI_API_KEY', None) is None and "gemini" in engine) \
Expand All @@ -67,7 +80,7 @@ def get_image_message(image_url, message, engine = None):
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"media_type": image_type,
"data": base64_image.split(",")[1],
}
}
Expand All @@ -80,7 +93,7 @@ def get_image_message(image_url, message, engine = None):
message.append(
{
"inlineData": {
"mimeType": "image/jpeg",
"mimeType": image_type,
"data": base64_image.split(",")[1],
}
}
Expand Down

0 comments on commit 4683124

Please sign in to comment.