From ec595a55484d0c6c90ab906d67efc3f509100bf5 Mon Sep 17 00:00:00 2001 From: Dev P5 Date: Thu, 6 Jun 2024 13:22:22 +0000 Subject: [PATCH 1/2] Make load_in_8bit optional --- dalm/datasets/qa_gen/question_answer_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dalm/datasets/qa_gen/question_answer_generation.py b/dalm/datasets/qa_gen/question_answer_generation.py index fbce902..c9d360c 100644 --- a/dalm/datasets/qa_gen/question_answer_generation.py +++ b/dalm/datasets/qa_gen/question_answer_generation.py @@ -114,11 +114,11 @@ def split_dataset( def generate_qa_from_dataset( - dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int + dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, load_in_8bit: bool = True ) -> DatasetDict: logger.info(f"Generating question answer pairs with batch size: {batch_size}") tokenizer = AutoTokenizer.from_pretrained(QA_MODEL) - model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=True) + model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=load_in_8bit) # shuffle data dataset.shuffle(seed=42) # select a subset From ca94701f19f62c7e85b1b2993a73597e8d5cc22b Mon Sep 17 00:00:00 2001 From: Dev P5 Date: Thu, 6 Jun 2024 13:22:32 +0000 Subject: [PATCH 2/2] Fix typo --- dalm/models/rag_e2e_base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dalm/models/rag_e2e_base_model.py b/dalm/models/rag_e2e_base_model.py index d5d4176..0d446f2 100644 --- a/dalm/models/rag_e2e_base_model.py +++ b/dalm/models/rag_e2e_base_model.py @@ -31,7 +31,7 @@ def __init__( ) -> None: super(AutoModelForRagE2E, self).__init__() - # Retriver initialization + # Retriever initialization self.retriever_model = AutoModel.from_pretrained( retriever_name, quantization_config=AutoModelForRagE2E.__get_bnb_config()