Skip to content

Commit

Permalink
Initial updates
Browse files Browse the repository at this point in the history
  • Loading branch information
iliaschalkidis committed Jul 29, 2024
1 parent 3a6ea21 commit cd8c9b3
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 16 deletions.
43 changes: 33 additions & 10 deletions augment_data/generate_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from typing import List
import argparse
import json
import re
from nltk.tokenize import sent_tokenize, word_tokenize


SYSTEM_PROMPT = 'You are a helpful AI assistant with expertise on EU politics.'
INSTRUCTION = 'This is a statement by an MEP of the {} political group in the European Parliament: "{}".\n\n What do you think was the question asked? The question should start as "What is your opinion on..." and should not be longer than 16 words.'
INSTRUCTION = 'This is a public statement by an MEP of the {} political group in the European Parliament on {} for the debate titled "{}":\n\n"{}"\n\nWhat do you think was the question asked? The question should start as "What is your opinion on..." and should not be longer than 16 words.'
ASSISTANT_START = 'The question asked was likely: "What is your opinion on '


def truncate_text(text, max_length):
''' Truncate text to the maximum length '''
sentences = sent_tokenize(text)
Expand All @@ -29,15 +31,29 @@ def truncate_text(text, max_length):
break
return truncated_text


def date_iso_in_text(text):
    '''Convert an ISO-format date string to a human-readable full-text date.

    Args:
        text: Date in ISO format 'YYYY-MM-DD', e.g. '2024-07-29'.

    Returns:
        Full-text date such as '29th of July 2024'.

    Raises:
        ValueError/IndexError: if the string is not a valid 'YYYY-MM-DD' date.
    '''
    year, month, day = text.split('-')
    # Map the numerical month ('01'-'12') to its English name via a list
    # lookup instead of a 12-branch conditional chain; an out-of-range month
    # now raises instead of silently becoming 'December'.
    month_names = ['January', 'February', 'March', 'April', 'May', 'June',
                   'July', 'August', 'September', 'October', 'November',
                   'December']
    month_name = month_names[int(month) - 1]
    # Choose the correct English ordinal suffix (1st, 2nd, 3rd, 4th, ...,
    # 11th-13th, 21st, 22nd, ...). The previous code appended 'th' to the raw
    # zero-padded day, producing strings like '01th' and '22th'.
    day_num = int(day)
    if 11 <= day_num % 100 <= 13:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(day_num % 10, 'th')
    return f'{day_num}{suffix} of {month_name} {year}'


def main():
''' set default hyperparams in default_hyperparams.py '''
parser = argparse.ArgumentParser()

# Required arguments
parser.add_argument('--model_name', default='meta-llama/Meta-Llama-3.1-8B-Instruct', help='Model name in HF Hub')
parser.add_argument('--max_length', default=64, type=int, help='Maximum length of the generated text')
parser.add_argument('--parties', default=['PPE', 'S&D', 'GUE/NGL'], type=List, help='List of party names to consider when filtering')
parser.add_argument('--debug', default=True, type=bool, help='Whether to use debug mode')
parser.add_argument('--parties', default=['S&D'], type=List, help='List of party names to consider when filtering')
parser.add_argument('--debug', default=False, type=bool, help='Whether to use debug mode')
config = parser.parse_args()

# Load eu-elections dataset
Expand All @@ -46,11 +62,14 @@ def main():
if config.debug:
print('Debugging mode activated')
config.model_name = 'gpt2'
tokenizer_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
config.quant = False
config.max_length = 8
else:
tokenizer_name = config.model

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(config.model_name, token=True)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=True)

# Compute free memory for each GPU
if torch.cuda.is_available():
Expand Down Expand Up @@ -91,22 +110,26 @@ def main():
for example in tqdm.tqdm(dataset):
if example['speaker_party'] not in config.parties:
continue
text = example['text'] if example['translated_text'] is None else example['translated_text']
try:
# Truncate the text to the maximum length
if len(example['text'].split(' ')) < 100:
if len(text.split(' ')) < 100:
continue
elif len(example['text'].split(' ')) > 256:
truncated_text = truncate_text(example['text'], 256)
elif len(text.split(' ')) > 256:
truncated_text = truncate_text(text, 256)
else:
truncated_text = example['text']
truncated_text = text

# Print the instruction
example['debate_title'] = re.split('(\(debate\)|Video of)', example['debate_title'])[0].strip()
example['debate_title'] = re.split('\(', example['debate_title'], maxsplit=1)[0].strip()
example['full_date'] = date_iso_in_text(example['date'])
annotation_request = tokenizer.apply_chat_template(
conversation=[{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": INSTRUCTION.format(example['speaker_party'], truncated_text.strip())}],
{"role": "user", "content": INSTRUCTION.format(example['speaker_party'], example['full_date'], example['debate_title'], truncated_text.strip())}],
tokenize=False, add_generation_prompt=False)
annotation_request += ASSISTANT_START
print('INSTRUCTION:\n', annotation_request.split('user<|end_header_id|>\n\n ')[1].split('<|eot_id|><|start_header_id|>assistant<|end_header_id|>')[0].strip())
print('INSTRUCTION:\n', annotation_request.split('user<|end_header_id|>\n\n')[1].split('<|eot_id|><|start_header_id|>assistant<|end_header_id|>')[0].strip())
# Get the response from the chatbot
responses = pipeline(
annotation_request,
Expand Down
7 changes: 5 additions & 2 deletions finetune_llms/finetune_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def main():
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--output_extension', default='sd-2014', help='Output extension for output directory')
parser.add_argument('--pseudo_qa', default=True, type=bool, help='Whether to turn the text into a pseudo question')
parser.add_argument('--debug', default=True, type=bool, help='Whether to infer summaries')
parser.add_argument('--debug', default=False, type=bool, help='Whether to use debug mode')
param_config = parser.parse_args()

# Setup logging
Expand All @@ -73,8 +73,11 @@ def main():
if param_config.debug:
print('Debugging mode activated')
param_config.model_name = 'gpt2'
tokenizer_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
param_config.quant = False
param_config.max_length = 8
else:
tokenizer_name = param_config.model

# Fix parties' list
param_config.party_names = param_config.party_names.split(',') if param_config.party_names is not None else None
Expand Down Expand Up @@ -112,7 +115,7 @@ def main():
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(param_config.model_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Expand Down
3 changes: 1 addition & 2 deletions run_scripts/finetune_llms.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ MODEL_PATH='meta-llama/Meta-Llama-3.1-8B-Instruct'
export PYTHONPATH=.

python ./finetune_llms/finetune_llms.py \
--model_name ${MODEL_PATH} \
--debug false
--model_name ${MODEL_PATH}
3 changes: 1 addition & 2 deletions run_scripts/generate_questions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@ export PYTHONPATH=.

python ./augment_data/generate_questions.py \
--model_name ${MODEL_PATH} \
--max_length 64 \
--debug false
--max_length 64

0 comments on commit cd8c9b3

Please sign in to comment.