
Commit

Fix bitsandbytes H100 issue (#93)
* Make load_in_8bit optional

* Fix typo

---------

Co-authored-by: Dev P5 <[email protected]>
tleyden and Dev P5 authored Jun 6, 2024
1 parent 699aedd commit 4431d51
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions dalm/datasets/qa_gen/question_answer_generation.py
@@ -114,11 +114,11 @@ def split_dataset(


 def generate_qa_from_dataset(
-    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int
+    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, load_in_8bit: bool = True
 ) -> DatasetDict:
     logger.info(f"Generating question answer pairs with batch size: {batch_size}")
     tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
-    model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=True)
+    model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL, device_map="auto", load_in_8bit=load_in_8bit)
     # shuffle data
     dataset.shuffle(seed=42)
     # select a subset
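
For context, a minimal caller-side sketch of the new parameter (the input file and column names below are assumptions for illustration, not taken from the repository): on H100 GPUs, where 8-bit loading via bitsandbytes was failing, callers can now pass load_in_8bit=False, while the default of True preserves the previous behaviour.

# Hypothetical usage sketch; data file and column names are assumptions.
from datasets import load_dataset

from dalm.datasets.qa_gen.question_answer_generation import generate_qa_from_dataset

passages = load_dataset("csv", data_files="passages.csv", split="train")  # assumed input file

qa_splits = generate_qa_from_dataset(
    passages,
    passage_column_name="passage",  # assumed column name
    title_column_name="title",      # assumed column name
    sample_size=1000,
    batch_size=16,
    load_in_8bit=False,  # new flag: skip bitsandbytes 8-bit quantization on H100
)
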
2 changes: 1 addition & 1 deletion dalm/models/rag_e2e_base_model.py
@@ -31,7 +31,7 @@ def __init__(
     ) -> None:
         super(AutoModelForRagE2E, self).__init__()

-        # Retriver initialization
+        # Retriever initialization
         self.retriever_model = AutoModel.from_pretrained(
             retriever_name,
             quantization_config=AutoModelForRagE2E.__get_bnb_config()
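
The retriever above is loaded with a bitsandbytes quantization config built by the private helper __get_bnb_config(), whose body is not shown in this diff. As an illustration only, a typical config of that kind built with the transformers API might look like the following; the values are assumptions, not the project's actual settings.

# Illustrative only: assumed values, not __get_bnb_config()'s real settings.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bits
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
)
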
