Skip to content

Commit

Permalink
Rewriting script
Browse files Browse the repository at this point in the history
  • Loading branch information
iliaschalkidis committed Aug 16, 2024
1 parent abf3d33 commit b97d3ea
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
7 changes: 5 additions & 2 deletions augment_data/rewrite_speeches.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def main():
# Required arguments
parser.add_argument('--model_name', default='meta-llama/Meta-Llama-3.1-8B-Instruct', help='Model name in HF Hub')
parser.add_argument('--max_length', default=256, type=int, help='Maximum length of the generated text')
parser.add_argument('--start_idx', default=8980, type=int, help='Index of the first speech in the dataset')
parser.add_argument('--debug', default=False, type=bool, help='Whether to use debug mode')
config = parser.parse_args()

Expand Down Expand Up @@ -123,8 +124,10 @@ def main():

# Iterate over the examples in the dataset and save the responses
examples = 0
with open(os.path.join(DATA_DIR, 'eu_parliaments_extended_rewritten.json'), 'w') as f:
for example in tqdm.tqdm(dataset):
with open(os.path.join(DATA_DIR, f'eu_parliaments_extended_rewritten_{config.start_idx}.json'), 'w') as f:
for idx, example in tqdm.tqdm(enumerate(dataset)):
if idx <= config.start_idx:
continue
text = example['text'] if example['translated_text'] is None else example['translated_text']
if example['speaker_party'] not in party_dict.keys():
continue
Expand Down
7 changes: 4 additions & 3 deletions run_scripts/rewrite_speeches.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash
#SBATCH --cpus-per-task=8 --mem=16000M
#SBATCH -p gpu --gres=gpu:a100:1
#SBATCH --output=/home/rwg642/eu-politics-llms-chronos/rewrite_speeches.txt
#SBATCH --time=24:00:00
#SBATCH --output=/home/rwg642/eu-politics-llms-chronos/rewrite_speeches_8980.txt
#SBATCH --time=150:00:00

. /etc/profile.d/modules.sh
eval "$(conda shell.bash hook)"
Expand All @@ -18,4 +18,5 @@ export PYTHONPATH=.

python ./augment_data/rewrite_speeches.py \
--model_name ${MODEL_PATH} \
--max_length 256
--max_length 256 \
--start_idx 8980

0 comments on commit b97d3ea

Please sign in to comment.