Skip to content

Commit

Permalink
fix: fix duplicate images
Browse files: browse the repository at this point in the history
  • Loading branch information
IcyKallen committed Jan 27, 2025
1 parent ce80cde commit e8e0802
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
Binary file not shown.
63 changes: 54 additions & 9 deletions source/lambda/job/dep/llm_bot_dep/figure_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,14 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
{
"role": "user",
"content": [
{"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_encoded}},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": base64_encoded,
},
},
{"type": "text", "text": prompt},
],
},
Expand Down Expand Up @@ -132,7 +139,10 @@ def get_description(self, img, context, tag):
def get_mermaid(self, img, classification):
with open(MERMAID_TEMPLATE_PATH) as f:
mermaid_prompt = f.read()
prompt = mermaid_prompt.format(diagram_type=classification, diagram_example=self.mermaid_prompt[classification])
prompt = mermaid_prompt.format(
diagram_type=classification,
diagram_example=self.mermaid_prompt[classification],
)
output = self.invoke_llm(img, prompt, prefix="<description>", stop="</mermaid>")
return output

Expand Down Expand Up @@ -192,7 +202,12 @@ def encode_image_to_base64(image_path: str) -> str:


def upload_image_to_s3(
image_data: Union[str, bytes], bucket: str, file_name: str, splitting_type: str, idx: int, is_bytes: bool = False
image_data: Union[str, bytes],
bucket: str,
file_name: str,
splitting_type: str,
idx: int,
is_bytes: bool = False,
):
"""Upload image to S3 from either a file path or binary data.
Expand Down Expand Up @@ -237,7 +252,13 @@ def download_image_from_url(img_url: str) -> str:


def process_single_image(
img_path: str, context: str, image_tag: str, bucket_name: str, file_name: str, idx: int
img_path: str,
context: str,
image_tag: str,
bucket_name: str,
file_name: str,
idx: int,
s3_link: str = None,
) -> str:
"""Process a single image and return its understanding text.
Expand Down Expand Up @@ -268,10 +289,12 @@ def process_single_image(
understanding = figure_llm.figure_understand(image_base64, context, image_tag, s3_link=f"{idx:05d}.jpg")

# Update S3 link
updated_s3_link = upload_image_to_s3(img_path, bucket_name, file_name, "image", idx)
understanding = understanding.replace(f"<link>{idx:05d}.jpg</link>", f"<link>{updated_s3_link}</link>")
if not s3_link:
s3_link = upload_image_to_s3(img_path, bucket_name, file_name, "image", idx)

understanding = understanding.replace(f"<link>{idx:05d}.jpg</link>", f"<link>{s3_link}</link>")

return understanding
return understanding, s3_link


def process_markdown_images_with_llm(content: str, bucket_name: str, file_name: str) -> str:
Expand All @@ -296,6 +319,8 @@ def process_markdown_images_with_llm(content: str, bucket_name: str, file_name:
image_pattern = r"!\[([^\]]*)\]\(([^)]+)\)"
last_end = 0
result = ""
# Add mapping to track image paths and their S3 objects
image_s3_mapping = {}

for idx, match in enumerate(re.finditer(image_pattern, content), 1):
start, end = match.span()
Expand All @@ -309,20 +334,40 @@ def process_markdown_images_with_llm(content: str, bucket_name: str, file_name:
# Handle URL images
if img_path.startswith(("http://", "https://")):
try:
img_path = download_image_from_url(img_path)
local_img_path = download_image_from_url(img_path)
except Exception as e:
logger.error(f"Error downloading image from URL {img_path}: {e}")
result += match.group(1)
last_end = end
continue
else:
logger.error(f"Image path {img_path} is not a URL")
result += match.group(1)
last_end = end
continue

# Get context
context_start = max(0, start - 200)
context_end = min(len(content), end + 200)
context = f"{content[context_start:start]}\n<image>\n{image_tag}\n</image>\n{content[end:context_end]}"

# Check if image was already processed
s3_link = image_s3_mapping.get(img_path)

# Process the image
understanding = process_single_image(img_path, context, image_tag, bucket_name, file_name, idx)
understanding, updated_s3_link = process_single_image(
local_img_path,
context,
image_tag,
bucket_name,
file_name,
idx,
s3_link,
)

# If this is a new image path, store its S3 object name
if not s3_link and understanding:
image_s3_mapping[img_path] = updated_s3_link

if understanding:
result += f"\n\n{understanding}\n\n"
Expand Down

0 comments on commit e8e0802

Please sign in to comment.