fix: add explicit UTF-8 encoding for XML string conversion #524

Merged · 1 commit · Jan 23, 2025
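
The diff below switches the two `etree.tostring()` calls used for figure nodes to an explicit `encoding="utf-8"`. As a minimal standalone sketch of the behaviour involved (assuming only that lxml is installed): without an explicit encoding, `etree.tostring()` serialises to ASCII bytes and escapes non-ASCII text as numeric character references, whereas `encoding="utf-8"` keeps the characters as UTF-8 bytes, so `.decode("utf-8")` restores them.

```python
from lxml import etree

node = etree.Element("description")
node.text = "图表说明"  # non-ASCII figure description

# Default serialisation: ASCII bytes, non-ASCII text escaped as
# numeric character references such as &#22270;.
print(etree.tostring(node))

# Explicit UTF-8: the characters survive as UTF-8 bytes, so decoding
# round-trips the original text (lxml may also prepend an XML
# declaration when an encoding is passed).
print(etree.tostring(node, encoding="utf-8").decode("utf-8"))
```

With the previous call the decoded chunk text was still valid, but any non-ASCII figure description ended up as `&#...;` escapes rather than readable text.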
source/lambda/job/dep/llm_bot_dep/splitter_utils.py (19 additions, 36 deletions)
@@ -6,10 +6,8 @@

import boto3
from langchain.docstore.document import Document
-from langchain.text_splitter import (
-    TextSplitter,
-)
-from llm_bot_dep.constant import SplittingType, FigureNode
+from langchain.text_splitter import TextSplitter
+from llm_bot_dep.constant import FigureNode, SplittingType
from llm_bot_dep.storage_utils import save_content_to_s3
from lxml import etree

@@ -22,9 +20,7 @@ def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any: # avoid importing
try:
import spacy
except ImportError:
-raise ImportError(
-    "Spacy is not installed, please install it with `pip install spacy`."
-)
+raise ImportError("Spacy is not installed, please install it with `pip install spacy`.")
if pipeline == "sentencizer":
from spacy.lang.en import English

@@ -38,19 +34,15 @@ def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any: # avoid importing
class NLTKTextSplitter(TextSplitter):
"""Splitting text using NLTK package."""

-def __init__(
-    self, separator: str = "\n\n", language: str = "english", **kwargs: Any
-) -> None:
+def __init__(self, separator: str = "\n\n", language: str = "english", **kwargs: Any) -> None:
"""Initialize the NLTK splitter."""
super().__init__(**kwargs)
try:
from nltk.tokenize import sent_tokenize

self._tokenizer = sent_tokenize
except ImportError:
-raise ImportError(
-    "NLTK is not installed, please install it with `pip install nltk`."
-)
+raise ImportError("NLTK is not installed, please install it with `pip install nltk`.")
self._separator = separator
self._language = language

@@ -138,11 +130,7 @@ def find_child(headers: dict, header_id: str):
level = headers[header_id]["level"]

for id, header in headers.items():
-if (
-    header["level"] == level + 1
-    and id not in children
-    and header["parent"] == header_id
-):
+if header["level"] == level + 1 and id not in children and header["parent"] == header_id:
children.append(id)

return children
@@ -244,10 +232,7 @@ def _set_chunk_id(
else:
# Move one step to get the next chunk_id
same_heading_dict[current_heading] += 1
-if (
-    len(id_index_dict[current_heading])
-    > same_heading_dict[current_heading]
-):
+if len(id_index_dict[current_heading]) > same_heading_dict[current_heading]:
metadata["chunk_id"] = id_index_dict[current_heading][
same_heading_dict[current_heading]
]
@@ -275,14 +260,12 @@ def _get_current_heading_list(self, current_heading, current_heading_level_map):
return ""

return joint_title_list

def split_text(self, text: Document) -> List[Document]:
if self.res_bucket is not None:
save_content_to_s3(s3, text, self.res_bucket, SplittingType.BEFORE.value)
else:
-logger.warning(
-    "No resource bucket is defined, skip saving content into S3 bucket"
-)
+logger.warning("No resource bucket is defined, skip saving content into S3 bucket")

lines = text.page_content.strip().split("\n")
chunks = []
@@ -313,7 +296,7 @@ def split_text(self, text: Document) -> List[Document]:
current_heading, current_heading_level_map
)
current_heading = current_heading.replace("#", "").strip()

try:
self._set_chunk_id(
id_index_dict, current_heading, metadata, same_heading_dict
@@ -325,9 +308,7 @@
id_prefix = str(uuid.uuid4())[:8]
metadata["chunk_id"] = f"$0-{id_prefix}"
if metadata["chunk_id"] in heading_hierarchy:
metadata["heading_hierarchy"] = heading_hierarchy[
metadata["chunk_id"]
]
metadata["heading_hierarchy"] = heading_hierarchy[metadata["chunk_id"]]
page_content = "\n".join(current_chunk_content)
metadata["complete_heading"] = current_heading_list
if have_figure:
@@ -359,10 +340,14 @@ def split_text(self, text: Document) -> List[Document]:
figure_description = xml_node.find(FigureNode.DESCRIPTION.value)
figure_value = xml_node.find(FigureNode.VALUE.value)
figure_s3_link = xml_node.findtext(FigureNode.LINK.value)
-chunk_figure_content = etree.tostring(figure_description).decode("utf-8")
+chunk_figure_content = etree.tostring(figure_description, encoding="utf-8").decode(
+    "utf-8"
+)
if figure_value is not None:
chunk_figure_content += "\n" + etree.tostring(figure_value).decode("utf-8")

chunk_figure_content += "\n" + etree.tostring(
figure_value, encoding="utf-8"
).decode("utf-8")

figure_item = {}
figure_item["content_type"] = figure_type
figure_item["figure_path"] = figure_s3_link
@@ -385,9 +370,7 @@
)
current_heading = current_heading.replace("#", "").strip()
try:
-self._set_chunk_id(
-    id_index_dict, current_heading, metadata, same_heading_dict
-)
+self._set_chunk_id(id_index_dict, current_heading, metadata, same_heading_dict)
except KeyError:
logger.info(f"No standard heading found")
id_prefix = str(uuid.uuid4())[:8]
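
For context, a hedged, self-contained approximation of the code path the last two hunks touch: parse a figure node, re-serialise its description and value children with the explicit encoding, and join them into the chunk text. The literal tag names ("type", "description", "value", "link") and the sample payload are placeholders standing in for the FigureNode enum values, which this diff does not show.

```python
from lxml import etree

# Placeholder figure payload; the real tag names come from the FigureNode
# enum, which this diff does not show.
raw = (
    "<figure>"
    "<type>chart</type>"
    "<description>销售趋势图</description>"
    "<value>Q1: 120, Q2: 135</value>"
    "<link>s3://example-bucket/figures/fig-1.png</link>"
    "</figure>"
)
xml_node = etree.fromstring(raw)

figure_description = xml_node.find("description")
figure_value = xml_node.find("value")
figure_s3_link = xml_node.findtext("link")

# Mirrors the patched lines: explicit UTF-8 on both tostring() calls.
chunk_figure_content = etree.tostring(figure_description, encoding="utf-8").decode("utf-8")
if figure_value is not None:
    chunk_figure_content += "\n" + etree.tostring(
        figure_value, encoding="utf-8"
    ).decode("utf-8")

figure_item = {"content_type": xml_node.findtext("type"), "figure_path": figure_s3_link}
print(chunk_figure_content)
print(figure_item)
```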