From 96457ea6c400a21f4894196618eab8a836db5b15 Mon Sep 17 00:00:00 2001 From: Puneet Saraswat <61435908+saraswatpuneet@users.noreply.github.com> Date: Sun, 10 Dec 2023 18:29:00 +0530 Subject: [PATCH] add email ingestor for better processing (#182) * add email ingestor for better processing * fix * fixes * handle emails and attachments * handle emails and attachments * email only returns tokens --- querent/collectors/email/email_collector.py | 9 +- querent/core/base_engine.py | 15 ++ .../ingestors/email}/__init__.py | 0 querent/ingestors/email/email_ingestor.py | 213 ++++++++++++++++++ querent/ingestors/email/email_reader.py | 87 +++++++ querent/ingestors/ingestor_manager.py | 4 +- querent/ingestors/pdfs/pdf_ingestor_v1.py | 25 +- querent/ingestors/texts/text_ingestor.py | 7 +- tests/collectors/__init__.py | 0 .../test_aws_collector.py | 0 .../test_azure_collector.py | 0 .../test_code_ingestor.py | 0 .../test_drive_collector.py | 0 .../test_dropbox_collector.py | 0 .../test_email_collector.py | 7 +- .../test_gcs_collector.py | 0 .../test_github_collector.py | 0 .../test_jira_collector.py | 0 .../test_local_collector.py | 0 .../test_slack_collector.py | 0 tests/ingestors/test_email_ingestor.py | 58 +++++ tests/ingestors/test_generic_ingestor.py | 5 +- tests/llm_tests/mock_llm_test.py | 4 + 23 files changed, 419 insertions(+), 15 deletions(-) rename {tests/collector_tests => querent/ingestors/email}/__init__.py (100%) create mode 100644 querent/ingestors/email/email_ingestor.py create mode 100644 querent/ingestors/email/email_reader.py create mode 100644 tests/collectors/__init__.py rename tests/{collector_tests => collectors}/test_aws_collector.py (100%) rename tests/{collector_tests => collectors}/test_azure_collector.py (100%) rename tests/{collector_tests => collectors}/test_code_ingestor.py (100%) rename tests/{collector_tests => collectors}/test_drive_collector.py (100%) rename tests/{collector_tests => collectors}/test_dropbox_collector.py (100%) rename tests/{collector_tests => collectors}/test_email_collector.py (89%) rename tests/{collector_tests => collectors}/test_gcs_collector.py (100%) rename tests/{collector_tests => collectors}/test_github_collector.py (100%) rename tests/{collector_tests => collectors}/test_jira_collector.py (100%) rename tests/{collector_tests => collectors}/test_local_collector.py (100%) rename tests/{collector_tests => collectors}/test_slack_collector.py (100%) create mode 100644 tests/ingestors/test_email_ingestor.py diff --git a/querent/collectors/email/email_collector.py b/querent/collectors/email/email_collector.py index 2fc6c7cc..a33c8c70 100644 --- a/querent/collectors/email/email_collector.py +++ b/querent/collectors/email/email_collector.py @@ -8,7 +8,10 @@ from querent.common import common_errors from querent.common.types.collected_bytes import CollectedBytes from querent.common.uri import Uri -from querent.config.collector.collector_config import CollectorBackend, EmailCollectorConfig +from querent.config.collector.collector_config import ( + CollectorBackend, + EmailCollectorConfig, +) from querent.logging.logger import setup_logger @@ -53,11 +56,11 @@ async def poll(self) -> AsyncGenerator[CollectedBytes, None]: message = response_part[1] yield CollectedBytes( data=message, - file=f"{self.config.imap_folder}/{i}.email", + file=f"{self.config.imap_username}:{self.config.imap_folder}/{i}.email", ) yield CollectedBytes( data=None, - file=f"{self.config.imap_folder}/{i}.email", + file=f"{self.config.imap_username}:{self.config.imap_folder}/{i}.email", eof=True, ) except imaplib.IMAP4.error as e: diff --git a/querent/core/base_engine.py b/querent/core/base_engine.py index df62a65a..8003487b 100644 --- a/querent/core/base_engine.py +++ b/querent/core/base_engine.py @@ -2,6 +2,7 @@ import asyncio from querent.callback.event_callback_dispatcher import EventCallbackDispatcher from querent.callback.event_callback_interface import EventCallbackInterface +from querent.common.types.ingested_images import IngestedImages from querent.common.types.ingested_messages import IngestedMessages from querent.common.types.ingested_tokens import IngestedTokens from querent.common.types.ingested_code import IngestedCode @@ -113,6 +114,18 @@ async def process_code(self, data: IngestedCode): """ raise NotImplementedError + @abstractmethod + async def process_images(self, data: IngestedImages): + """ + Process images asynchronously. + Args: + data (IngestedImages): The input data to process. + Returns: + EventState: The state of the event is set with the event type and the timestamp + of the event and set using `self.set_state(event_state)`. + """ + raise NotImplementedError + @abstractmethod def validate(self) -> bool: """ @@ -200,6 +213,8 @@ async def _inner_worker(): await self.process_tokens(data) elif isinstance(data, IngestedCode): await self.process_code(data) + elif isinstance(data, IngestedImages): + await self.process_images(data) else: raise Exception( f"Invalid data type {type(data)} for {self.__class__.__name__}. Supported type: {IngestedTokens, IngestedMessages}" diff --git a/tests/collector_tests/__init__.py b/querent/ingestors/email/__init__.py similarity index 100% rename from tests/collector_tests/__init__.py rename to querent/ingestors/email/__init__.py diff --git a/querent/ingestors/email/email_ingestor.py b/querent/ingestors/email/email_ingestor.py new file mode 100644 index 00000000..e3163b24 --- /dev/null +++ b/querent/ingestors/email/email_ingestor.py @@ -0,0 +1,213 @@ +from typing import List, AsyncGenerator +import uuid +from PIL import Image +import pybase64 +import pypdf +from querent.common import common_errors +from querent.common.types.collected_bytes import CollectedBytes +from querent.common.types.ingested_images import IngestedImages +from querent.ingestors.base_ingestor import BaseIngestor +from querent.ingestors.email.email_reader import EmailReader +from querent.ingestors.ingestor_factory import IngestorFactory +from querent.logging.logger import setup_logger +from querent.processors.async_processor import AsyncProcessor +from querent.config.ingestor.ingestor_config import IngestorBackend +from querent.common.types.ingested_tokens import IngestedTokens +import email +import pytesseract +from io import BytesIO +import io + + +class EmailIngestorFactory(IngestorFactory): + SUPPORTED_EXTENSIONS = {"email", "eml"} + + async def supports(self, file_extension: str) -> bool: + return file_extension.lower() in self.SUPPORTED_EXTENSIONS + + async def create( + self, file_extension: str, processors: List[AsyncProcessor] + ) -> BaseIngestor: + if not await self.supports(file_extension): + return None + return EmailIngestor(processors) + + +class EmailIngestor(BaseIngestor): + def __init__(self, processors: List[AsyncProcessor]): + super().__init__(IngestorBackend.Email) + self.processors = processors + self.logger = setup_logger(__name__, "EmailIngestor") + self.email_reader = EmailReader() + + async def ingest( + self, poll_function: AsyncGenerator[CollectedBytes, None] + ) -> AsyncGenerator[IngestedTokens, None]: + collected_bytes = b"" + current_file = None + try: + async for chunk_bytes in poll_function: + if chunk_bytes.is_error() or chunk_bytes.is_eof(): + continue + + if current_file is None: + current_file = chunk_bytes.file + elif current_file != chunk_bytes.file: + email = await self.extract_and_process_email( + CollectedBytes(file=current_file, data=collected_bytes) + ) + yield IngestedTokens( + file=current_file, + data=email, # Wrap line in a list + error=None, + ) + yield IngestedTokens( + file=current_file, + data=None, + error=None, + ) + collected_bytes = b"" + current_file = chunk_bytes.file + collected_bytes += chunk_bytes.data + except Exception as e: + yield IngestedTokens(file=current_file, data=None, error=f"Exception: {e}") + finally: + if current_file is not None: + email = await self.extract_and_process_email( + CollectedBytes(file=current_file, data=collected_bytes) + ) + yield IngestedTokens( + file=current_file, + data=email, # Wrap line in a list + error=None, + ) + yield IngestedTokens( + file=current_file, + data=None, + error=None, + ) + + async def extract_and_process_email( + self, collected_bytes: CollectedBytes + ) -> List[str]: + text = await self.extract_text_from_email(collected_bytes) + processed_text = await self.process_data(text) + return processed_text + + async def extract_text_from_email(self, collected_bytes: CollectedBytes) -> str: + text = "" + try: + msg = email.message_from_bytes(collected_bytes.data) + email_msg = {} + ( + email_msg["From"], + email_msg["To"], + email_msg["Date"], + email_msg["Subject"], + ) = self.email_reader.obtain_header(msg) + if msg.is_multipart(): + for part in msg.walk(): + content_type = part.get_content_type() + try: + body = part.get_payload(decode=True) + except Exception as e: + continue + if body is None: + continue + text += await self.handle_sub_part(part) + text += "\n" + else: + content_type = msg.get_content_type() + body = msg.get_payload(decode=True).decode() + if content_type == "text/plain": + text = self.email_reader.clean_email_body(body) + except Exception as e: + self.logger.error(f"Error extracting text from email: {e}") + return text + + async def handle_sub_part(self, sub_part): + sub_content_type = sub_part.get_content_type() + sub_content_disposition = str(sub_part.get("Content-Disposition")) + if ( + sub_content_type == "text/plain" + and "attachment" not in sub_content_disposition + ): + return self.email_reader.clean_email_body(sub_part.get_payload()) + elif "attachment" in sub_content_disposition: + # Handle attachment as needed + return await self.handle_attachment(sub_part) + elif sub_content_type == "text/html": + # Handle HTML content as needed + # You can choose to ignore or process it + return "" + else: + # Handle other content types as needed + return "" + + async def handle_attachment(self, attachment_part) -> str: + # Get attachment data and type + attachment_data = attachment_part.get_payload(decode=True) + attachment_type = attachment_part.get_content_type() + + # Check attachment type and handle accordingly + if attachment_type.startswith("image/"): + return await self.handle_image_attachment(attachment_data) + elif attachment_type == "application/pdf": + return await self.handle_pdf_attachment(attachment_data) + else: + # Handle other attachment types as needed + self.logger.warning(f"Unsupported attachment type: {attachment_type}") + + async def handle_image_attachment(self, image_data: bytes) -> str: + try: + ocr_text = await self.get_ocr_from_image(image_data) + return ocr_text + except Exception as e: + self.logger.error(f"Error handling image attachment: {e}") + + async def handle_pdf_attachment(self, pdf_data) -> str: + try: + pdf_text = await self.extract_and_process_pdf(pdf_data) + except Exception as e: + self.logger.error(f"Error handling PDF attachment: {e}") + return pdf_text + + async def get_ocr_from_image(self, image): + """Implement this to return ocr text of the image""" + image = Image.open(io.BytesIO(image)) + text = pytesseract.image_to_string(image) + return str(text).encode("utf-8").decode("unicode_escape") + + async def extract_and_process_pdf(self, pdf_data: bytes) -> str: + pdf_text = "" + try: + path = BytesIO(pdf_data) + loader = pypdf.PdfReader(path) + + for _, page in enumerate(loader.pages): + text = page.extract_text() + pdf_text += text + "\n" + pdf_text += await self.extract_images_and_ocr(page) + + except TypeError as exc: + self.logger.error(f"Exception while extracting email {exc}") + except Exception as exc: + self.logger.error(f"Exception while extracting email {exc}") + return pdf_text + + async def extract_images_and_ocr(self, page) -> str: + ocr_text = "" + try: + for image_path in page.images: + ocr_text += await self.get_ocr_from_image(image_path) + except Exception as e: + self.logger.error(f"Error extracting images and OCR: {e}") + return ocr_text + + async def process_data(self, text: str) -> List[str]: + if self.processors is None or len(self.processors) == 0: + return [text] + processed_data = text + for processor in self.processors: + processed_data = await processor.process_text(processed_data) + return processed_data diff --git a/querent/ingestors/email/email_reader.py b/querent/ingestors/email/email_reader.py new file mode 100644 index 00000000..42bcf4c1 --- /dev/null +++ b/querent/ingestors/email/email_reader.py @@ -0,0 +1,87 @@ +import os +import re +from email.header import decode_header + +from bs4 import BeautifulSoup + + +class EmailReader: + def clean_email_body(self, email_body): + """ + Function to clean the email body. + + Args: + email_body (str): The email body to be cleaned. + + Returns: + str: The cleaned email body. + """ + if email_body is None: + email_body = "" + email_body = BeautifulSoup(email_body, "html.parser") + email_body = email_body.get_text() + email_body = "".join(email_body.splitlines()) + email_body = " ".join(email_body.split()) + email_body = email_body.encode("ascii", "ignore") + email_body = email_body.decode("utf-8", "ignore") + email_body = re.sub(r"http\S+", "", email_body) + return email_body + + def clean(self, text): + """ + Function to clean the text. + + Args: + text (str): The text to be cleaned. + + Returns: + str: The cleaned text. + """ + return "".join(c if c.isalnum() else "_" for c in text) + + def obtain_header(self, msg): + """ + Function to obtain the header of the email. + + Args: + msg (email.message.Message): The email message. + + Returns: + str: The From field of the email. + """ + if msg["Subject"] is not None: + Subject, encoding = decode_header(msg["Subject"])[0] + else: + Subject = "" + encoding = "" + if isinstance(Subject, bytes): + try: + if encoding is not None: + Subject = Subject.decode(encoding) + else: + Subject = "" + except [LookupError] as err: + pass + From = msg["From"] + To = msg["To"] + Date = msg["Date"] + return From, To, Date, Subject + + def download_attachment(self, part, subject): + """ + Function to download the attachment from the email. + + Args: + part (email.message.Message): The email message. + subject (str): The subject of the email. + + Returns: + None + """ + filename = part.get_filename() + if filename: + folder_name = self.clean(subject) + if not os.path.isdir(folder_name): + os.mkdir(folder_name) + filepath = os.path.join(folder_name, filename) + open(filepath, "wb").write(part.get_payload(decode=True)) diff --git a/querent/ingestors/ingestor_manager.py b/querent/ingestors/ingestor_manager.py index a86ef154..0d4dcdc9 100644 --- a/querent/ingestors/ingestor_manager.py +++ b/querent/ingestors/ingestor_manager.py @@ -8,6 +8,7 @@ from querent.common.types.ingested_tokens import IngestedTokens from querent.config.ingestor.ingestor_config import IngestorBackend from querent.ingestors.base_ingestor import BaseIngestor +from querent.ingestors.email.email_ingestor import EmailIngestorFactory from querent.ingestors.ingestor_factory import IngestorFactory, UnsupportedIngestor from querent.ingestors.pdfs.pdf_ingestor_v1 import PdfIngestorFactory from querent.ingestors.texts.text_ingestor import TextIngestorFactory @@ -103,7 +104,8 @@ def __init__( IngestorBackend.HTML.value: HtmlIngestorFactory(), IngestorBackend.MP4.value: VideoIngestorFactory(), IngestorBackend.GITHUB.value: GithubIngestorFactory(), - IngestorBackend.Slack.value: TextIngestorFactory(), + IngestorBackend.Slack.value: TextIngestorFactory(is_token_stream=True), + IngestorBackend.Email.value: EmailIngestorFactory(), # Add more mappings as needed } self.file_caches = LRUCache(maxsize=cache_size) diff --git a/querent/ingestors/pdfs/pdf_ingestor_v1.py b/querent/ingestors/pdfs/pdf_ingestor_v1.py index a4a361fa..04e05c57 100644 --- a/querent/ingestors/pdfs/pdf_ingestor_v1.py +++ b/querent/ingestors/pdfs/pdf_ingestor_v1.py @@ -6,6 +6,7 @@ from querent.config.ingestor.ingestor_config import IngestorBackend from querent.ingestors.base_ingestor import BaseIngestor from querent.ingestors.ingestor_factory import IngestorFactory +from querent.logging.logger import setup_logger from querent.processors.async_processor import AsyncProcessor from querent.common import common_errors import uuid @@ -36,6 +37,7 @@ class PdfIngestor(BaseIngestor): def __init__(self, processors: List[AsyncProcessor]): super().__init__(IngestorBackend.PDF) self.processors = processors + self.logger = setup_logger(__name__, "PdfIngestor") async def ingest( self, poll_function: AsyncGenerator[CollectedBytes, None] @@ -78,7 +80,11 @@ async def ingest( yield IngestedTokens(file=current_file, data=None, error=None) except Exception as exc: - yield None + yield IngestedTokens( + file=current_file, + data=None, + error=f"Exception: {exc}", + ) async def extract_and_process_pdf( self, collected_bytes: CollectedBytes @@ -109,13 +115,13 @@ async def extract_and_process_pdf( yield image_result except TypeError as exc: - print("Exception while extracting ", exc) + self.logger.error(f"Exception while extracting pdf {exc}") raise common_errors.TypeError( f"Getting type error on this file {collected_bytes.file}" ) from exc except Exception as exc: - print("Exception ", exc) + self.logger.error(f"Exception while extracting pdf {exc}") raise common_errors.UnknownError( f"Getting unknown error while handling this file: {collected_bytes.file} error - {exc}" ) from exc @@ -134,7 +140,16 @@ async def extract_images_and_ocr(self, page, page_num, text, data, file_path): ocr_text=ocr, ) except Exception as e: - print("Exception in pdf extracter ", e) + self.logger.error(f"Error extracting images and OCR: {e}") + yield IngestedImages( + file=file_path, + image=pybase64.b64encode(data), + image_name=uuid.uuid4(), + page_num=page_num, + text=text, + coordinates=None, + ocr_text=None, + ) async def get_ocr_from_image(self, image): """Implement this to return ocr text of the image""" @@ -143,6 +158,8 @@ async def get_ocr_from_image(self, image): return str(text).encode("utf-8").decode("unicode_escape") async def process_data(self, text: str) -> List[str]: + if self.processors == None or len(self.processors) == 0: + return [text] processed_data = text for processor in self.processors: processed_data = await processor.process_text(processed_data) diff --git a/querent/ingestors/texts/text_ingestor.py b/querent/ingestors/texts/text_ingestor.py index a9788d5c..d985b637 100644 --- a/querent/ingestors/texts/text_ingestor.py +++ b/querent/ingestors/texts/text_ingestor.py @@ -9,7 +9,10 @@ class TextIngestorFactory(IngestorFactory): - SUPPORTED_EXTENSIONS = {"txt", ""} + SUPPORTED_EXTENSIONS = {"txt", "slack", ""} + + def __init__(self, is_token_stream=False): + self.is_token_stream = is_token_stream async def supports(self, file_extension: str) -> bool: return file_extension.lower() in self.SUPPORTED_EXTENSIONS @@ -19,7 +22,7 @@ async def create( ) -> BaseIngestor: if not await self.supports(file_extension): return None - return TextIngestor(processors, file_extension == "") + return TextIngestor(processors, self.is_token_stream) class TextIngestor(BaseIngestor): diff --git a/tests/collectors/__init__.py b/tests/collectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/collector_tests/test_aws_collector.py b/tests/collectors/test_aws_collector.py similarity index 100% rename from tests/collector_tests/test_aws_collector.py rename to tests/collectors/test_aws_collector.py diff --git a/tests/collector_tests/test_azure_collector.py b/tests/collectors/test_azure_collector.py similarity index 100% rename from tests/collector_tests/test_azure_collector.py rename to tests/collectors/test_azure_collector.py diff --git a/tests/collector_tests/test_code_ingestor.py b/tests/collectors/test_code_ingestor.py similarity index 100% rename from tests/collector_tests/test_code_ingestor.py rename to tests/collectors/test_code_ingestor.py diff --git a/tests/collector_tests/test_drive_collector.py b/tests/collectors/test_drive_collector.py similarity index 100% rename from tests/collector_tests/test_drive_collector.py rename to tests/collectors/test_drive_collector.py diff --git a/tests/collector_tests/test_dropbox_collector.py b/tests/collectors/test_dropbox_collector.py similarity index 100% rename from tests/collector_tests/test_dropbox_collector.py rename to tests/collectors/test_dropbox_collector.py diff --git a/tests/collector_tests/test_email_collector.py b/tests/collectors/test_email_collector.py similarity index 89% rename from tests/collector_tests/test_email_collector.py rename to tests/collectors/test_email_collector.py index 2811f91c..ecc3beb2 100644 --- a/tests/collector_tests/test_email_collector.py +++ b/tests/collectors/test_email_collector.py @@ -2,7 +2,10 @@ import pytest import os from querent.collectors.collector_resolver import CollectorResolver -from querent.config.collector.collector_config import CollectorBackend, EmailCollectorConfig +from querent.config.collector.collector_config import ( + CollectorBackend, + EmailCollectorConfig, +) from querent.common.uri import Uri import uuid from dotenv import load_dotenv @@ -41,7 +44,7 @@ async def poll_and_print(): if chunk is not None: counter += 1 - assert counter == 1 + assert counter == 2 await poll_and_print() diff --git a/tests/collector_tests/test_gcs_collector.py b/tests/collectors/test_gcs_collector.py similarity index 100% rename from tests/collector_tests/test_gcs_collector.py rename to tests/collectors/test_gcs_collector.py diff --git a/tests/collector_tests/test_github_collector.py b/tests/collectors/test_github_collector.py similarity index 100% rename from tests/collector_tests/test_github_collector.py rename to tests/collectors/test_github_collector.py diff --git a/tests/collector_tests/test_jira_collector.py b/tests/collectors/test_jira_collector.py similarity index 100% rename from tests/collector_tests/test_jira_collector.py rename to tests/collectors/test_jira_collector.py diff --git a/tests/collector_tests/test_local_collector.py b/tests/collectors/test_local_collector.py similarity index 100% rename from tests/collector_tests/test_local_collector.py rename to tests/collectors/test_local_collector.py diff --git a/tests/collector_tests/test_slack_collector.py b/tests/collectors/test_slack_collector.py similarity index 100% rename from tests/collector_tests/test_slack_collector.py rename to tests/collectors/test_slack_collector.py diff --git a/tests/ingestors/test_email_ingestor.py b/tests/ingestors/test_email_ingestor.py new file mode 100644 index 00000000..bc1f394a --- /dev/null +++ b/tests/ingestors/test_email_ingestor.py @@ -0,0 +1,58 @@ +import asyncio +import pytest +import os +from querent.collectors.collector_resolver import CollectorResolver +from querent.config.collector.collector_config import ( + CollectorBackend, + EmailCollectorConfig, +) +from querent.common.uri import Uri +import uuid +from dotenv import load_dotenv + +from querent.ingestors.ingestor_manager import IngestorFactoryManager + +load_dotenv() + + +@pytest.fixture +def email_config(): + return EmailCollectorConfig( + backend=CollectorBackend.Email, + id=str(uuid.uuid4()), + imap_server="imap.gmail.com", # "imap.gmail.com + imap_port=993, + imap_username="puneet@querent.xyz", + imap_password=os.getenv("IMAP_PASSWORD"), + imap_folder="[Gmail]/Drafts", + imap_certfile=None, + imap_keyfile=None, + ) + + +@pytest.mark.asyncio +async def test_email_ingestor(email_config): + uri = Uri("email://") + resolver = CollectorResolver() + collector = resolver.resolve(uri, email_config) + assert collector is not None + await collector.connect() + # Set up the ingestor + ingestor_factory_manager = IngestorFactoryManager() + ingestor_factory = await ingestor_factory_manager.get_factory("email") + ingestor = await ingestor_factory.create("email", []) + ingested_call = ingestor.ingest(collector.poll()) + + async def poll_and_print(): + counter = 0 + async for ingested in ingested_call: + assert ingested is not None + if ingested is not "" or ingested is not None: + counter += 1 + assert counter == 4 + + await poll_and_print() # Notice the use of await here + + +if __name__ == "__main__": + asyncio.run(test_email_ingestor()) diff --git a/tests/ingestors/test_generic_ingestor.py b/tests/ingestors/test_generic_ingestor.py index d7c5b02d..35d74bb5 100644 --- a/tests/ingestors/test_generic_ingestor.py +++ b/tests/ingestors/test_generic_ingestor.py @@ -36,12 +36,11 @@ async def test_collect_and_ingest_generic_bytes(): # Set up the ingestor ingestor_factory_manager = IngestorFactoryManager() - ingestor_factory = await ingestor_factory_manager.get_factory("") - ingestor = await ingestor_factory.create("", []) + ingestor_factory = await ingestor_factory_manager.get_factory("slack") + ingestor = await ingestor_factory.create("slack", []) # Collect and ingest the PDF ingested_call = ingestor.ingest(collector.poll()) - counter = 0 async def poll_and_print(): counter = 0 diff --git a/tests/llm_tests/mock_llm_test.py b/tests/llm_tests/mock_llm_test.py index 6aa04492..da043653 100644 --- a/tests/llm_tests/mock_llm_test.py +++ b/tests/llm_tests/mock_llm_test.py @@ -1,6 +1,7 @@ import pytest from querent.callback.event_callback_interface import EventCallbackInterface from querent.common.types.ingested_code import IngestedCode +from querent.common.types.ingested_images import IngestedImages from querent.common.types.ingested_messages import IngestedMessages from querent.common.types.ingested_tokens import IngestedTokens from querent.common.types.querent_event import EventState, EventType @@ -41,6 +42,9 @@ async def process_code(self, data: IngestedCode): def process_messages(self, data: IngestedMessages): return super().process_messages(data) + def process_images(self, data: IngestedImages): + return super().process_images(data) + def validate(self): return True