Skip to content

Commit

Permalink
Don't disable text chunking when GPT4vision is enabled (#1355)
Browse files Browse the repository at this point in the history
* Dont disable chunking when using vision, graceful degrade

* Adding test
  • Loading branch information
pamelafox authored Mar 13, 2024
1 parent 87f2b9d commit d896376
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 33 deletions.
17 changes: 12 additions & 5 deletions app/backend/core/imageshelper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import base64
import logging
import math
import os
import re
from io import BytesIO
from typing import Optional

from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob.aio import ContainerClient
from PIL import Image
from typing_extensions import Literal, Required, TypedDict
Expand All @@ -22,12 +24,17 @@ class ImageURL(TypedDict, total=False):

async def download_blob_as_base64(blob_container_client: ContainerClient, file_path: str) -> Optional[str]:
base_name, _ = os.path.splitext(file_path)
blob = await blob_container_client.get_blob_client(base_name + ".png").download_blob()

if not blob.properties:
image_filename = base_name + ".png"
try:
blob = await blob_container_client.get_blob_client(image_filename).download_blob()
if not blob.properties:
logging.warning(f"No blob exists for {image_filename}")
return None
img = base64.b64encode(await blob.readall()).decode("utf-8")
return f"data:image/png;base64,{img}"
except ResourceNotFoundError:
logging.warning(f"No blob exists for {image_filename}")
return None
img = base64.b64encode(await blob.readall()).decode("utf-8")
return f"data:image/png;base64,{img}"


async def fetch_image(blob_container_client: ContainerClient, result: Document) -> Optional[ImageURL]:
Expand Down
4 changes: 2 additions & 2 deletions scripts/prepdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ def setup_list_file_strategy(
if datalake_filesystem is None or datalake_path is None:
raise ValueError("DataLake file system and path are required when using Azure Data Lake Gen2")
adls_gen2_creds: Union[AsyncTokenCredential, str] = azure_credential if datalake_key is None else datalake_key
logger.info(f"Using Data Lake Gen2 Storage Account {datalake_storage_account}")
logger.info("Using Data Lake Gen2 Storage Account: %s", datalake_storage_account)
list_file_strategy = ADLSGen2ListFileStrategy(
data_lake_storage_account=datalake_storage_account,
data_lake_filesystem=datalake_filesystem,
data_lake_path=datalake_path,
credential=adls_gen2_creds,
)
elif local_files:
logger.info(f"Using local files in {local_files}")
logger.info("Using local files: %s", local_files)
list_file_strategy = LocalListFileStrategy(path_pattern=local_files)
else:
raise ValueError("Either local_files or datalake_storage_account must be provided.")
Expand Down
15 changes: 9 additions & 6 deletions scripts/prepdocslib/blobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ async def upload_blob(self, file: File) -> Optional[List[str]]:
# Re-open and upload the original file
with open(file.content.name, "rb") as reopened_file:
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
logger.info(f"\tUploading blob for whole file -> {blob_name}")
logger.info("Uploading blob for whole file -> %s", blob_name)
await container_client.upload_blob(blob_name, reopened_file, overwrite=True)

if self.store_page_images and os.path.splitext(file.content.name)[1].lower() == ".pdf":
return await self.upload_pdf_blob_images(service_client, container_client, file)
if self.store_page_images:
if os.path.splitext(file.content.name)[1].lower() == ".pdf":
return await self.upload_pdf_blob_images(service_client, container_client, file)
else:
logger.info("File %s is not a PDF, skipping image upload", file.content.name)

return None

Expand All @@ -84,11 +87,11 @@ async def upload_pdf_blob_images(
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 20)
except OSError:
logger.info("\tUnable to find arial.ttf or FreeMono.ttf, using default font")
logger.info("Unable to find arial.ttf or FreeMono.ttf, using default font")

for i in range(page_count):
blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i)
logger.info(f"\tConverting page {i} to image and uploading -> {blob_name}")
logger.info("Converting page %s to image and uploading -> %s", i, blob_name)

doc = fitz.open(file.content.name)
page = doc.load_page(i)
Expand Down Expand Up @@ -154,7 +157,7 @@ async def remove_blob(self, path: Optional[str] = None):
)
) or (path is not None and blob_path == os.path.basename(path)):
continue
logger.info(f"\tRemoving blob {blob_path}")
logger.info("Removing blob %s", blob_path)
await container_client.delete_blob(blob_path)

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion scripts/prepdocslib/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ async def create_embedding_batch(self, texts: List[str]) -> List[List[float]]:
with attempt:
emb_response = await client.embeddings.create(model=self.open_ai_model_name, input=batch.texts)
embeddings.extend([data.embedding for data in emb_response.data])
logger.info(f"Batch Completed. Batch size {len(batch.texts)} Token count {batch.token_length}")
logger.info(
"Computed embeddings in batch. Batch size: %d, Token count: %d",
len(batch.texts),
batch.token_length,
)

return embeddings

Expand All @@ -111,6 +115,7 @@ async def create_embedding_single(self, text: str) -> List[float]:
):
with attempt:
emb_response = await client.embeddings.create(model=self.open_ai_model_name, input=text)
logger.info("Computed embedding for text section. Character count: %d", len(text))

return emb_response.data[0].embedding

Expand Down
15 changes: 10 additions & 5 deletions scripts/prepdocslib/filestrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@


async def parse_file(
file: File, file_processors: dict[str, FileProcessor], category: Optional[str] = None
file: File,
file_processors: dict[str, FileProcessor],
category: Optional[str] = None,
image_embeddings: Optional[ImageEmbeddings] = None,
) -> List[Section]:
key = file.file_extension()
processor = file_processors.get(key)
if processor is None:
logger.info(f"Skipping '{file.filename()}', no parser found.")
logger.info("Skipping '%s', no parser found.", file.filename())
return []
logger.info(f"Parsing '{file.filename()}'")
logger.info("Ingesting '%s'", file.filename())
pages = [page async for page in processor.parser.parse(content=file.content)]
logger.info(f"Splitting '{file.filename()}' into sections")
logger.info("Splitting '%s' into sections", file.filename())
if image_embeddings:
logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.")
sections = [
Section(split_page, content=file, category=category) for split_page in processor.splitter.split_pages(pages)
]
Expand Down Expand Up @@ -76,7 +81,7 @@ async def run(self):
files = self.list_file_strategy.list()
async for file in files:
try:
sections = await parse_file(file, self.file_processors, self.category)
sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
if sections:
blob_sas_uris = await self.blob_manager.upload_blob(file)
blob_image_embeddings: Optional[List[List[float]]] = None
Expand Down
2 changes: 1 addition & 1 deletion scripts/prepdocslib/htmlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
Returns:
Page: The parsed html Page.
"""
logger.info(f"\tExtracting text from '{content.name}' using local HTML parser (BeautifulSoup)")
logger.info("Extracting text from '%s' using local HTML parser (BeautifulSoup)", content.name)

data = content.read()
soup = BeautifulSoup(data, "html.parser")
Expand Down
2 changes: 1 addition & 1 deletion scripts/prepdocslib/listfilestrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def check_md5(self, path: str) -> bool:
stored_hash = md5_f.read()

if stored_hash and stored_hash.strip() == existing_hash.strip():
logger.info(f"Skipping {path}, no changes detected.")
logger.info("Skipping %s, no changes detected.", path)
return True

# Write the hash
Expand Down
4 changes: 2 additions & 2 deletions scripts/prepdocslib/pdfparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class LocalPdfParser(Parser):
"""

async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
logger.info(f"\tExtracting text from '{content.name}' using local PDF parser (pypdf)")
logger.info("Extracting text from '%s' using local PDF parser (pypdf)", content.name)

reader = PdfReader(content)
pages = reader.pages
Expand All @@ -46,7 +46,7 @@ def __init__(
self.credential = credential

async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
logger.info(f"Extracting text from '{content.name}' using Azure Document Intelligence")
logger.info("Extracting text from '%s' using Azure Document Intelligence", content.name)

async with DocumentIntelligenceClient(
endpoint=self.endpoint, credential=self.credential
Expand Down
12 changes: 7 additions & 5 deletions scripts/prepdocslib/searchmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.search_images = search_images

async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] = None):
logger.info(f"Ensuring search index {self.search_info.index_name} exists")
logger.info("Ensuring search index %s exists", self.search_info.index_name)

async with self.search_info.create_search_index_client() as search_index_client:
fields = [
Expand Down Expand Up @@ -175,10 +175,10 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
),
)
if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]:
logger.info(f"Creating {self.search_info.index_name} search index")
logger.info("Creating %s search index", self.search_info.index_name)
await search_index_client.create_index(index)
else:
logger.info(f"Search index {self.search_info.index_name} already exists")
logger.info("Search index %s already exists", self.search_info.index_name)

async def update_content(self, sections: List[Section], image_embeddings: Optional[List[List[float]]] = None):
MAX_BATCH_SIZE = 1000
Expand Down Expand Up @@ -220,7 +220,9 @@ async def update_content(self, sections: List[Section], image_embeddings: Option
await search_client.upload_documents(documents)

async def remove_content(self, path: Optional[str] = None, only_oid: Optional[str] = None):
logger.info(f"Removing sections from '{path or '<all>'}' from search index '{self.search_info.index_name}'")
logger.info(
"Removing sections from '{%s or '<all>'}' from search index '%s'", path, self.search_info.index_name
)
async with self.search_info.create_search_client() as search_client:
while True:
filter = None if path is None else f"sourcefile eq '{os.path.basename(path)}'"
Expand All @@ -233,6 +235,6 @@ async def remove_content(self, path: Optional[str] = None, only_oid: Optional[st
if not only_oid or document["oids"] == [only_oid]:
documents_to_remove.append({"id": document["id"]})
removed_docs = await search_client.delete_documents(documents_to_remove)
logger.info(f"\tRemoved {len(removed_docs)} sections from index")
logger.info("Removed %d sections from index", len(removed_docs))
# It can take a few seconds for search results to reflect changes, so wait a bit
await asyncio.sleep(2)
5 changes: 0 additions & 5 deletions scripts/prepdocslib/textsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,6 @@ def split_page_by_max_tokens(self, page_num: int, text: str) -> Generator[SplitP
yield from self.split_page_by_max_tokens(page_num, second_half)

def split_pages(self, pages: List[Page]) -> Generator[SplitPage, None, None]:
# Chunking is disabled when using GPT4V. To be updated in the future.
if self.has_image_embeddings:
for i, page in enumerate(pages):
yield SplitPage(page_num=i, text=page.text)

def find_page(offset):
num_pages = len(pages)
for i in range(num_pages - 1):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_blob_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,40 @@ async def mock_upload_blob(self, name, *args, **kwargs):
await blob_manager.upload_blob(f)


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
async def test_upload_blob_no_image(monkeypatch, mock_env, caplog):
blob_manager = BlobManager(
endpoint=f"https://{os.environ['AZURE_STORAGE_ACCOUNT']}.blob.core.windows.net",
credential=MockAzureCredential(),
container=os.environ["AZURE_STORAGE_CONTAINER"],
account=os.environ["AZURE_STORAGE_ACCOUNT"],
resourceGroup=os.environ["AZURE_STORAGE_RESOURCE_GROUP"],
subscriptionId=os.environ["AZURE_SUBSCRIPTION_ID"],
store_page_images=True,
)

with NamedTemporaryFile(suffix=".xlsx") as temp_file:
f = File(temp_file.file)
filename = os.path.basename(f.content.name)

# Set up mocks used by upload_blob
async def mock_exists(*args, **kwargs):
return True

monkeypatch.setattr("azure.storage.blob.aio.ContainerClient.exists", mock_exists)

async def mock_upload_blob(self, name, *args, **kwargs):
assert name == filename
return True

monkeypatch.setattr("azure.storage.blob.aio.ContainerClient.upload_blob", mock_upload_blob)

with caplog.at_level("INFO"):
await blob_manager.upload_blob(f)
assert "skipping image upload" in caplog.text


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
async def test_dont_remove_if_no_container(monkeypatch, mock_env, blob_manager):
Expand Down
101 changes: 101 additions & 0 deletions tests/test_fetch_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os

import aiohttp
import pytest
from azure.core.exceptions import ResourceNotFoundError
from azure.core.pipeline.transport import (
AioHttpTransportResponse,
AsyncHttpTransport,
HttpRequest,
)
from azure.storage.blob.aio import BlobServiceClient

from approaches.approach import Document
from core.imageshelper import fetch_image

from .mocks import MockAzureCredential


@pytest.mark.asyncio
async def test_content_file(monkeypatch, mock_env, mock_acs_search):
class MockAiohttpClientResponse404(aiohttp.ClientResponse):
def __init__(self, url, body_bytes, headers=None):
self._body = body_bytes
self._headers = headers
self._cache = {}
self.status = 404
self.reason = "Not Found"
self._url = url

class MockAiohttpClientResponse(aiohttp.ClientResponse):
def __init__(self, url, body_bytes, headers=None):
self._body = body_bytes
self._headers = headers
self._cache = {}
self.status = 200
self.reason = "OK"
self._url = url

class MockTransport(AsyncHttpTransport):
async def send(self, request: HttpRequest, **kwargs) -> AioHttpTransportResponse:
if request.url.endswith("notfound.png"):
raise ResourceNotFoundError(MockAiohttpClientResponse404(request.url, b""))
else:
return AioHttpTransportResponse(
request,
MockAiohttpClientResponse(
request.url,
b"test content",
{
"Content-Type": "application/octet-stream",
"Content-Range": "bytes 0-27/28",
"Content-Length": "28",
},
),
)

async def __aenter__(self):
return self

async def __aexit__(self, *args):
pass

async def open(self):
pass

async def close(self):
pass

# Then we can plug this into any SDK via kwargs:
blob_client = BlobServiceClient(
f"https://{os.environ['AZURE_STORAGE_ACCOUNT']}.blob.core.windows.net",
credential=MockAzureCredential(),
transport=MockTransport(),
retry_total=0, # Necessary to avoid unnecessary network requests during tests
)
blob_container_client = blob_client.get_container_client(os.environ["AZURE_STORAGE_CONTAINER"])

test_document = Document(
id="test",
content="test content",
embedding=[1, 2, 3],
image_embedding=[4, 5, 6],
oids=[],
groups=[],
captions=[],
category="",
sourcefile="test.pdf",
sourcepage="test.pdf#page2",
)
image_url = await fetch_image(blob_container_client, test_document)
assert image_url is not None
assert image_url["url"] == "data:image/png;base64,dGVzdCBjb250ZW50"
assert image_url["detail"] == "auto"

test_document.sourcepage = "notfound.pdf"
image_url = await fetch_image(blob_container_client, test_document)
assert image_url is None

test_document.sourcepage = ""
image_url = await fetch_image(blob_container_client, test_document)
assert image_url is None

0 comments on commit d896376

Please sign in to comment.