Skip to content

Commit

Permalink
Initial updates
Browse files Browse the repository at this point in the history
  • Loading branch information
iliaschalkidis committed Jul 26, 2024
1 parent 8092efa commit 3a6ea21
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions augment_data/generate_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import json
from nltk.tokenize import sent_tokenize, word_tokenize

# Hand-written chat prompt template with two positional `{}` slots: the first
# sits in the system turn, the second in the user turn. The special tokens are
# spelled out literally, and the template ends inside the assistant turn with
# a pre-filled opening: 'The question asked was likely: "What is your opinion on '.
# NOTE(review): the token layout looks like the Llama 3 instruct format --
# confirm against the model/tokenizer actually loaded.
PROMPT_PATTERN = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe question asked was likely: "What is your opinion on '

# Fill PROMPT_PATTERN with the system message and the user instruction.
# The instruction text itself contains two `{}` placeholders; str.format does
# not re-interpret substituted text, so they survive and the resulting string
# can be formatted a second time (political group, then statement text).
# NOTE(review): SYSTEM_PROMPT is assigned again a few lines below as a plain
# string. In this diff view this looks like the pre-change definition; if both
# assignments are live in the file, this one is overwritten.
SYSTEM_PROMPT = PROMPT_PATTERN.format('You are a helpful AI assistant with expertise on EU politics.',
'This is a statement by an MEP of the {} political group in the European Parliament: "{}".\n\n What do you think was the question asked? The question should start as "What is your opinion on..." and should not be longer than 16 words.')

# System-role message for the chat conversation.
SYSTEM_PROMPT = 'You are a helpful AI assistant with expertise on EU politics.'
# User-role message template with two positional `{}` slots: the MEP's
# political group, then the quoted statement text.
INSTRUCTION = 'This is a statement by an MEP of the {} political group in the European Parliament: "{}".\n\n What do you think was the question asked? The question should start as "What is your opinion on..." and should not be longer than 16 words.'
# Pre-filled opening of the assistant's reply, appended to the rendered prompt
# in main(); presumably so that generation continues the question right after
# 'What is your opinion on ' -- confirm against how the output is parsed.
ASSISTANT_START = 'The question asked was likely: "What is your opinion on '

def truncate_text(text, max_length):
''' Truncate text to the maximum length '''
Expand Down Expand Up @@ -102,7 +101,11 @@ def main():
truncated_text = example['text']

# Print the instruction
annotation_request = SYSTEM_PROMPT.format(example['speaker_party'], truncated_text.strip())
annotation_request = tokenizer.apply_chat_template(
conversation=[{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": INSTRUCTION.format(example['speaker_party'], truncated_text.strip())}],
tokenize=False, add_generation_prompt=False)
annotation_request += ASSISTANT_START
print('INSTRUCTION:\n', annotation_request.split('user<|end_header_id|>\n\n ')[1].split('<|eot_id|><|start_header_id|>assistant<|end_header_id|>')[0].strip())
# Get the response from the chatbot
responses = pipeline(
Expand Down

0 comments on commit 3a6ea21

Please sign in to comment.