diff --git a/ingest.py b/ingest.py index 5819e0f8..7bd92e15 100644 --- a/ingest.py +++ b/ingest.py @@ -38,7 +38,7 @@ def load_single_document(file_path: str) -> Document: return loader.load()[0] except Exception as ex: file_log('%s loading error: \n%s' % (file_path, ex)) - return None + return None def load_document_batch(filepaths): logging.info("Loading document batch") @@ -56,15 +56,17 @@ def load_document_batch(filepaths): return (data_list, filepaths) -def load_documents(source_dir: str) -> list[Document]: +def load_documents(source_dir: str, extensions: str) -> list[Document]: + exts = extensions.split(',') + # Loads all documents from the source documents directory, including nested folders paths = [] for root, _, files in os.walk(source_dir): for file_name in files: - print('Importing: ' + file_name) file_extension = os.path.splitext(file_name)[1] source_file_path = os.path.join(root, file_name) - if file_extension in DOCUMENT_MAP.keys(): + if file_extension in exts and file_extension in DOCUMENT_MAP.keys(): + print('Importing: ' + file_name) paths.append(source_file_path) # Have at least one worker and at most INGEST_THREADS workers @@ -93,7 +95,7 @@ def load_documents(source_dir: str) -> list[Document]: docs.extend(contents) except Exception as ex: file_log('Exception: %s' % (ex)) - + return docs @@ -139,10 +141,27 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Doc ), help="Device to run on. (Default is cuda)", ) -def main(device_type): + +@click.option( + "--ingest_self", + default=False, + type=bool, + help="Ingest the current directory.", +) + +@click.option( + "--extensions", + default=','.join(DOCUMENT_MAP.keys()), + type=str, + help="List of extensions to ingest, e.g.: .md,.py,.pdf (Default is all supported extensions)", +) + +def main(device_type, ingest_self, extensions): + source_dir = SOURCE_DIRECTORY if not ingest_self else os.getcwd() + # Load documents and split in chunks - logging.info(f"Loading documents from {SOURCE_DIRECTORY}") - documents = load_documents(SOURCE_DIRECTORY) + logging.info(f"Loading documents from {source_dir}") + documents = load_documents(source_dir, extensions) text_documents, python_documents = split_documents(documents) text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) python_splitter = RecursiveCharacterTextSplitter.from_language( @@ -171,7 +190,6 @@ def main(device_type): persist_directory=PERSIST_DIRECTORY, client_settings=CHROMA_SETTINGS, ) - if __name__ == "__main__":