From dfb85cfb0f4376fe1e714b452a8cc4702b354a96 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E6=98=9F=E6=BC=A9?=
Date: Mon, 3 Jun 2024 17:31:07 +0800
Subject: [PATCH] solve #4

---
 utils/retrieval/flashcard.py | 107 +++++++++++++++++++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 utils/retrieval/flashcard.py

diff --git a/utils/retrieval/flashcard.py b/utils/retrieval/flashcard.py
new file mode 100644
index 0000000..cb8d9cc
--- /dev/null
+++ b/utils/retrieval/flashcard.py
@@ -0,0 +1,107 @@
+# flashcard
+from typing import Optional, Any
+import transformers
+import torch
+from peft import PeftModel
+from transformers.utils import is_accelerate_available, is_bitsandbytes_available
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    GenerationConfig,
+    pipeline,
+)
+import re
+import utils.globalvar
+import datasets
+
+
+def formatting_prompts_func(ipt):
+    # Wrap the raw question in the instruction-tuning prompt template.
+    text = f"### Instruction: Answer the question truthfully.\n### Input: {ipt}\n### Output: "
+    return text
+
+
+### Query Generation ###############################################
+def llama2_pipeline(prompt):
+    base_model = "meta-llama/Llama-2-7b-hf"
+    peft_model = "veggiebird/llama-2-7b-medical-flashcards-8bit"
+
+    # load the base model, the LoRA adapter, and the tokenizer only once;
+    # they are cached in utils.globalvar across calls
+    if utils.globalvar.bio_model is None:
+        utils.globalvar.bio_model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            use_safetensors=True,
+            torch_dtype=torch.float16,
+            load_in_8bit=True,
+        )
+
+        utils.globalvar.bio_model = PeftModel.from_pretrained(utils.globalvar.bio_model, peft_model)
+
+        utils.globalvar.bio_tokenizer = AutoTokenizer.from_pretrained(base_model)
+
+    print("Model loaded...")
+    generator = transformers.pipeline(
+        "text-generation",
+        model=utils.globalvar.bio_model,
+        tokenizer=utils.globalvar.bio_tokenizer,
+        torch_dtype=torch.float16,
+        device_map="auto",
+    )
+
+    sequences = generator(
+        prompt,
+        do_sample=False,
+        top_k=10,
+        num_return_sequences=1,
+        eos_token_id=utils.globalvar.bio_tokenizer.eos_token_id,
+        max_length=256,
+    )
+
+    return sequences[0]["generated_text"].strip()
+
+###############################################
+
+
+### Query Knowl. ###############################################
+def extract_responses(content):
+    # Pull the text between "### Output:" and the next "###" marker.
+    pattern = r"### Output:(.+?)###"
+    matches = re.findall(pattern, content, re.DOTALL)
+    return [match.strip() for match in matches]
+
+
+def generate_flashcard_query(input):
+    prompt = formatting_prompts_func(input)
+    query = llama2_pipeline(prompt)
+    processed_query = extract_responses(query)
+    return query, processed_query
+
+
+def execute_flashcard_query(query):
+    model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large")  # SimCSE sentence encoder
+    tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+    dataset = datasets.load_dataset('veggiebird/medical-flashcards')
+    dataset = dataset["train"]
+    dataset.add_faiss_index(column='embeddings')  # build a FAISS index over the precomputed embeddings
+
+    query_inputs = tokenizer(query, padding=True, truncation=True, return_tensors="pt")
+    query_embedding = model(**query_inputs, output_hidden_states=True, return_dict=True).pooler_output.detach().numpy()
+    scores, retrieved_examples = dataset.get_nearest_examples("embeddings", query_embedding, k=1)
+    pre_knowl = retrieved_examples["output"][0].strip()
+    try:
+        # keep at most the first three sentences of the retrieved answer
+        knowl = ' '.join(re.split(r'(?<=[.:;])\s', pre_knowl)[:3])
+    except Exception:
+        knowl = pre_knowl
+    return knowl
+
+###############################################
+
+
+def retrieve_flashcard_knowledge(input, data_point):
+    knowl = ""
+    print("Generate query...")
+    query, processed_query = generate_flashcard_query(input)
+    if len(processed_query) != 0:
+        print("Query:", processed_query[0])
+        print("Retrieve knowledge...")
+        knowl = execute_flashcard_query(processed_query[0])
+        print(knowl)
+    return knowl
\ No newline at end of file