diff --git a/dalm/datasets/qa_gen/question_answer_generation.py b/dalm/datasets/qa_gen/question_answer_generation.py
index fbce902..c9d360c 100644
--- a/dalm/datasets/qa_gen/question_answer_generation.py
+++ b/dalm/datasets/qa_gen/question_answer_generation.py
@@ -114,11 +114,11 @@ def split_dataset(
 
 
 def generate_qa_from_dataset(
-    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int
+    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, load_in_8bit: bool = True
 ) -> DatasetDict:
     logger.info(f"Generating question answer pairs with batch size: {batch_size}")
     tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
-    model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=True)
+    model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=load_in_8bit)
     # shuffle data
     dataset.shuffle(seed=42)
     # select a subset
diff --git a/dalm/models/rag_e2e_base_model.py b/dalm/models/rag_e2e_base_model.py
index d5d4176..0d446f2 100644
--- a/dalm/models/rag_e2e_base_model.py
+++ b/dalm/models/rag_e2e_base_model.py
@@ -31,7 +31,7 @@ def __init__(
     ) -> None:
         super(AutoModelForRagE2E, self).__init__()
 
-        # Retriver initialization
+        # Retriever initialization
        self.retriever_model = AutoModel.from_pretrained(
            retriever_name,
            quantization_config=AutoModelForRagE2E.__get_bnb_config()
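
A minimal usage sketch of the new `load_in_8bit` flag follows. Only the `generate_qa_from_dataset` signature and module path come from the diff above; the data file, column names, and sizes are assumptions for illustration.

```python
# Hypothetical usage sketch: only the generate_qa_from_dataset signature is
# taken from the diff above; the data file, column names, and sizes below are
# illustrative assumptions.
from datasets import load_dataset

from dalm.datasets.qa_gen.question_answer_generation import generate_qa_from_dataset

# Load an arbitrary passage dataset (assumed CSV with "passage"/"title" columns).
passages = load_dataset("csv", data_files="passages.csv", split="train")

# load_in_8bit=False loads the QA model in full precision instead of 8-bit,
# e.g. on machines where bitsandbytes 8-bit loading is unavailable.
qa_splits = generate_qa_from_dataset(
    passages,
    passage_column_name="passage",
    title_column_name="title",
    sample_size=1000,
    batch_size=16,
    load_in_8bit=False,
)
```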