From cd8c9b3e21b3029b59bf7779d37f465d05eb1c0c Mon Sep 17 00:00:00 2001
From: kiddothe2b
Date: Mon, 29 Jul 2024 10:59:20 +0200
Subject: [PATCH] Initial updates

Enrich the question-generation prompt with the debate date and title, use
the translated speech text when available, keep the Llama tokenizer for
chat templating in debug mode, and drop the ineffective `--debug false`
flags from the run scripts.
---
 augment_data/generate_questions.py | 43 +++++++++++++++++++++++-------
 finetune_llms/finetune_llms.py     |  7 +++--
 run_scripts/finetune_llms.sh       |  3 +--
 run_scripts/generate_questions.sh  |  3 +--
 4 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/augment_data/generate_questions.py b/augment_data/generate_questions.py
index 46b05b2..cae2c75 100644
--- a/augment_data/generate_questions.py
+++ b/augment_data/generate_questions.py
@@ -8,13 +8,15 @@
 from typing import List
 import argparse
 import json
+import re
 
 from nltk.tokenize import sent_tokenize, word_tokenize
 
 SYSTEM_PROMPT = 'You are a helpful AI assistant with expertise on EU politics.'
-INSTRUCTION = 'This is a statement by an MEP of the {} political group in the European Parliament: "{}".\n\n What do you think was the question asked? The question should start as "What is your opinion on..." and should not be longer than 16 words.'
+INSTRUCTION = 'This is a public statement by an MEP of the {} political group in the European Parliament on {} for the debate titled "{}":\n\n"{}"\n\nWhat do you think was the question asked? The question should start as "What is your opinion on..." and should not be longer than 16 words.'
 ASSISTANT_START = 'The question asked was likely: "What is your opinion on '
 
+
 def truncate_text(text, max_length):
     ''' Truncate text to the maximum length '''
     sentences = sent_tokenize(text)
@@ -29,6 +31,20 @@ def truncate_text(text, max_length):
             break
     return truncated_text
 
+
+def date_iso_in_text(text):
+    ''' Extract date in ISO format and generate a full-text date '''
+    date = text.split('-')
+    year = date[0]
+    month = date[1]
+    day = date[2]
+    # Numerical month to text month
+    month = 'January' if month == '01' else 'February' if month == '02' else 'March' if month == '03' else 'April' if month == '04' else 'May' if month == '05' else 'June' if month == '06' else 'July' if month == '07' else 'August' if month == '08' else 'September' if month == '09' else 'October' if month == '10' else 'November' if month == '11' else 'December'
+    # Generate full text date
+    full_text_date = f'{day}th of {month} {year}'
+    return full_text_date
+
+
 def main():
     ''' set default hyperparams in default_hyperparams.py '''
     parser = argparse.ArgumentParser()
@@ -36,8 +52,8 @@
     # Required arguments
     parser.add_argument('--model_name', default='meta-llama/Meta-Llama-3.1-8B-Instruct', help='Model name in HF Hub')
     parser.add_argument('--max_length', default=64, type=int, help='Maximum length of the generated text')
-    parser.add_argument('--parties', default=['PPE', 'S&D', 'GUE/NGL'], type=List, help='List of party names to consider when filtering')
-    parser.add_argument('--debug', default=True, type=bool, help='Whether to use debug mode')
+    parser.add_argument('--parties', default=['S&D'], type=List, help='List of party names to consider when filtering')
+    parser.add_argument('--debug', default=False, type=bool, help='Whether to use debug mode')
     config = parser.parse_args()
 
     # Load eu-elections dataset
@@ -46,11 +62,14 @@
     if config.debug:
         print('Debugging mode activated')
         config.model_name = 'gpt2'
+        tokenizer_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
         config.quant = False
         config.max_length = 8
+    else:
+        tokenizer_name = config.model_name
 
     # Load tokenizer and model
-    tokenizer = AutoTokenizer.from_pretrained(config.model_name, token=True)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=True)
 
     # Compute free memory for each GPU
     if torch.cuda.is_available():
@@ -91,22 +110,26 @@
     for example in tqdm.tqdm(dataset):
         if example['speaker_party'] not in config.parties:
             continue
+        text = example['text'] if example['translated_text'] is None else example['translated_text']
         try:
             # Truncate the text to the maximum length
-            if len(example['text'].split(' ')) < 100:
+            if len(text.split(' ')) < 100:
                 continue
-            elif len(example['text'].split(' ')) > 256:
-                truncated_text = truncate_text(example['text'], 256)
+            elif len(text.split(' ')) > 256:
+                truncated_text = truncate_text(text, 256)
             else:
-                truncated_text = example['text']
+                truncated_text = text
 
             # Print the instruction
+            example['debate_title'] = re.split(r'(\(debate\)|Video of)', example['debate_title'])[0].strip()
+            example['debate_title'] = re.split(r'\(', example['debate_title'], maxsplit=1)[0].strip()
+            example['full_date'] = date_iso_in_text(example['date'])
             annotation_request = tokenizer.apply_chat_template(
                 conversation=[{"role": "system", "content": SYSTEM_PROMPT},
-                              {"role": "user", "content": INSTRUCTION.format(example['speaker_party'], truncated_text.strip())}],
+                              {"role": "user", "content": INSTRUCTION.format(example['speaker_party'], example['full_date'], example['debate_title'], truncated_text.strip())}],
                 tokenize=False, add_generation_prompt=False)
             annotation_request += ASSISTANT_START
-            print('INSTRUCTION:\n', annotation_request.split('user<|end_header_id|>\n\n ')[1].split('<|eot_id|><|start_header_id|>assistant<|end_header_id|>')[0].strip())
+            print('INSTRUCTION:\n', annotation_request.split('user<|end_header_id|>\n\n')[1].split('<|eot_id|><|start_header_id|>assistant<|end_header_id|>')[0].strip())
             # Get the response from the chatbot
             responses = pipeline(
                 annotation_request,
diff --git a/finetune_llms/finetune_llms.py b/finetune_llms/finetune_llms.py
index cbe45b0..486acee 100644
--- a/finetune_llms/finetune_llms.py
+++ b/finetune_llms/finetune_llms.py
@@ -54,7 +54,7 @@
     parser.add_argument('--seed', default=42, type=int)
     parser.add_argument('--output_extension', default='sd-2014', help='Output extension for output directory')
     parser.add_argument('--pseudo_qa', default=True, type=bool, help='Whether to turn the text into a pseudo question')
-    parser.add_argument('--debug', default=True, type=bool, help='Whether to infer summaries')
+    parser.add_argument('--debug', default=False, type=bool, help='Whether to use debug mode')
     param_config = parser.parse_args()
 
     # Setup logging
@@ -73,8 +73,11 @@
     if param_config.debug:
         print('Debugging mode activated')
         param_config.model_name = 'gpt2'
+        tokenizer_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
         param_config.quant = False
         param_config.max_length = 8
+    else:
+        tokenizer_name = param_config.model_name
 
     # Fix parties' list
     param_config.party_names = param_config.party_names.split(',') if param_config.party_names is not None else None
@@ -112,7 +115,7 @@
     model.config.use_cache = False
     model.config.pretraining_tp = 1
 
-    tokenizer = AutoTokenizer.from_pretrained(param_config.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
 
diff --git a/run_scripts/finetune_llms.sh b/run_scripts/finetune_llms.sh
index a1b0159..fb45fb1 100644
--- a/run_scripts/finetune_llms.sh
+++ b/run_scripts/finetune_llms.sh
@@ -17,5 +17,4 @@
 MODEL_PATH='meta-llama/Meta-Llama-3.1-8B-Instruct'
 export PYTHONPATH=.
 python ./finetune_llms/finetune_llms.py \
-    --model_name ${MODEL_PATH} \
-    --debug false
+    --model_name ${MODEL_PATH}
\ No newline at end of file
diff --git a/run_scripts/generate_questions.sh b/run_scripts/generate_questions.sh
index 1b5a583..3385274 100644
--- a/run_scripts/generate_questions.sh
+++ b/run_scripts/generate_questions.sh
@@ -18,5 +18,4 @@
 export PYTHONPATH=.
 python ./augment_data/generate_questions.py \
     --model_name ${MODEL_PATH} \
-    --max_length 64 \
-    --debug false
+    --max_length 64
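
Note: date_iso_in_text spells out the month mapping by hand; below is a
minimal standard-library sketch that produces the same strings, assuming
dates always arrive as zero-padded YYYY-MM-DD (the hard-coded 'th' suffix
mirrors the patched helper and is not a true ordinal, e.g. it yields
'01th' rather than '1st'):

    from datetime import datetime

    def date_iso_in_text(text):
        ''' Same output as the patched helper, via datetime '''
        parsed = datetime.strptime(text, '%Y-%m-%d')
        # %d is the zero-padded day, %B the full month name, %Y the year
        return parsed.strftime('%dth of %B %Y')

    print(date_iso_in_text('2024-07-29'))  # -> 29th of July 2024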
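Note: the two re.split calls in generate_questions.py drop everything from
'(debate)' or 'Video of' onwards and then any remaining parenthetical from
the debate title; a self-contained check of that behaviour on a made-up
title string:

    import re

    title = 'Situation in Venezuela (debate) (Video of the speeches)'
    # Keep only the text before '(debate)' or 'Video of'
    title = re.split(r'(\(debate\)|Video of)', title)[0].strip()
    # Drop anything from the first remaining '(' onwards
    title = re.split(r'\(', title, maxsplit=1)[0].strip()
    print(title)  # -> Situation in Venezuela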