Merge pull request #5 from yukw777/videomae-baseline
Add the training script for the VideoMAE baselines
yukw777 authored Jan 25, 2024
2 parents ba421c7 + 97db704 commit 30c6b1f
Showing 2 changed files with 290 additions and 0 deletions.
149 changes: 149 additions & 0 deletions scripts/baselines/videomae/videomae_train.py
@@ -0,0 +1,149 @@
from dataclasses import dataclass
from typing import Any

import torch
import transformers
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    Permute,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchmetrics.functional.classification import multiclass_f1_score
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

from eilev.data.frame import FrameDataset


@dataclass
class ModelArguments:
    model_name_or_path: str
    num_frames: int
    verb: bool


@dataclass
class DataArguments:
    train_frames_dir: str
    val_frames_dir: str
    train_annotation_file: str | None = None
    val_annotation_file: str | None = None


def train() -> None:
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, transformers.TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Keep all dataset columns so the custom collator receives the raw video and
    # label fields, and reload the best checkpoint when training finishes.
    training_args.remove_unused_columns = False
    training_args.load_best_model_at_end = True

    processor = transformers.VideoMAEImageProcessor.from_pretrained(
        model_args.model_name_or_path
    )
    if "shortest_edge" in processor.size:
        height = width = processor.size["shortest_edge"]
    else:
        height = processor.size["height"]
        width = processor.size["width"]
    train_data = FrameDataset(
        data_args.train_frames_dir,
        annotation_file=data_args.train_annotation_file,
        transform=ApplyTransformToKey(
            "video",
            transform=Compose(
                [
                    UniformTemporalSubsample(model_args.num_frames),
                    Lambda(lambda x: x * processor.rescale_factor),
                    Normalize(processor.image_mean, processor.image_std),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop((height, width)),
                    RandomHorizontalFlip(),
                    Permute((1, 0, 2, 3)),
                ]
            ),
        ),
    )
    val_data = FrameDataset(
        data_args.val_frames_dir,
        annotation_file=data_args.val_annotation_file,
        transform=ApplyTransformToKey(
            "video",
            transform=Compose(
                [
                    UniformTemporalSubsample(model_args.num_frames),
                    # We can't use VideoMAEImageProcessor here because it doesn't
                    # play nicely with Tensors, e.g., it builds a tensor from a
                    # list of numpy.ndarrays, which is extremely slow.
                    Lambda(lambda x: x * processor.rescale_factor),
                    Normalize(processor.image_mean, processor.image_std),
                    Resize((height, width), antialias=True),
                    Permute((1, 0, 2, 3)),
                ]
            ),
        ),
    )

    # We can't reuse train_data and val_data here since their transforms would
    # fail once return_frames is set to False; build label-only datasets instead.
    tmp_train_data = FrameDataset(
        data_args.train_frames_dir,
        annotation_file=data_args.train_annotation_file,
        return_frames=False,
    )
    tmp_val_data = FrameDataset(
        data_args.val_frames_dir,
        annotation_file=data_args.val_annotation_file,
        return_frames=False,
    )
    label_key = "structured_verb" if model_args.verb else "structured_noun"
    labels = sorted({item[label_key] for item in iter(tmp_train_data + tmp_val_data)})
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for label, i in label2id.items()}

    model = transformers.VideoMAEForVideoClassification.from_pretrained(
        model_args.model_name_or_path,
        low_cpu_mem_usage=True,
        ignore_mismatched_sizes=True,
        label2id=label2id,
        id2label=id2label,
        num_frames=model_args.num_frames,
    )

    def compute_metrics(eval_pred):
        return {
            "f1": multiclass_f1_score(
                torch.tensor(eval_pred.predictions).argmax(dim=1),
                torch.tensor(eval_pred.label_ids),
                len(labels),
            ).item()
        }

    def collate_fn(examples: list[dict[str, Any]]):
        pixel_values = torch.stack([example["video"] for example in examples])
        labels = torch.tensor([label2id[example[label_key]] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
    )
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    model.save_pretrained(training_args.output_dir)
    processor.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    train()
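
For reference, here is a minimal sketch (not part of this commit) of the tensor layout the training script relies on: the pytorchvideo transforms operate on clips shaped (C, T, H, W), Permute((1, 0, 2, 3)) reorders each clip to (T, C, H, W), and the torch.stack in collate_fn then yields the (batch, frames, channels, height, width) pixel_values that VideoMAEForVideoClassification expects. The frame count and resolution below are illustrative, not taken from the scripts.

import torch

num_frames, height, width = 8, 224, 224  # illustrative values
clip = torch.rand(3, num_frames, height, width)  # (C, T, H, W): pytorchvideo's clip layout
clip = clip.permute(1, 0, 2, 3)  # (T, C, H, W), what Permute((1, 0, 2, 3)) produces
batch = torch.stack([clip, clip])  # what collate_fn's torch.stack does for two examples
assert batch.shape == (2, num_frames, 3, height, width)  # (B, T, C, H, W) pixel_values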
141 changes: 141 additions & 0 deletions slurm-scripts/train/submit_videomae_train.py
@@ -0,0 +1,141 @@
import argparse
import base64
import os
import subprocess

parser = argparse.ArgumentParser()
parser.add_argument("--account", required=True)
parser.add_argument("--partition", required=True)
parser.add_argument("--model", required=True)
parser.add_argument("--verb", action="store_true")
parser.add_argument("--num_gpus", required=True, type=int)
parser.add_argument("--mem_per_gpu", required=True)
parser.add_argument("--time", required=True)
parser.add_argument("--train_frames_dir", required=True)
parser.add_argument("--train_annotation_file")
parser.add_argument("--val_frames_dir", required=True)
parser.add_argument("--val_annotation_file")
parser.add_argument("--output_dir", required=True)
parser.add_argument("--dataloader_num_workers", type=int, required=True)
parser.add_argument("--train_batch_size", type=int, required=True)
parser.add_argument("--per_device_train_batch_size", type=int, required=True)
parser.add_argument("--per_device_eval_batch_size", type=int, required=True)
parser.add_argument("--num_train_epochs", type=int, default=5)
parser.add_argument("--email")
parser.add_argument("--transformers_cache")
parser.add_argument("--wandb_project", required=True)
parser.add_argument("--resume_from_checkpoint", default=None)
parser.add_argument("--deepspeed_stage_2", action="store_true")
parser.add_argument("--dry-run", action="store_true")
args = parser.parse_args()

email = ""
if args.email is not None:
    email = f"#SBATCH --mail-user={args.email}\n#SBATCH --mail-type=BEGIN,END"
transformers_cache = ""
if args.transformers_cache is not None:
    transformers_cache = f"export TRANSFORMERS_CACHE={args.transformers_cache}"
resume_from_checkpoint = ""
if args.resume_from_checkpoint is not None:
    resume_from_checkpoint = f"--resume_from_checkpoint {args.resume_from_checkpoint}"

deepspeed = ""
if args.deepspeed_stage_2:
    # Encode the DeepSpeed config as URL-safe base64 so the multi-line JSON can
    # be passed as a single command-line argument.
    encoded_config = base64.urlsafe_b64encode(
        b"""{
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}"""
    ).decode()
    deepspeed = f"--deepspeed {encoded_config}"

# Pick gradient accumulation steps so that per-device batch size x number of
# GPUs x accumulation steps matches the requested global train batch size.
gradient_accumulation_steps = (
    args.train_batch_size // args.per_device_train_batch_size // args.num_gpus
)


train_annotation_file = ""
if args.train_annotation_file is not None:
    train_annotation_file = f"--train_annotation_file {args.train_annotation_file}"

val_annotation_file = ""
if args.val_annotation_file is not None:
    val_annotation_file = f"--val_annotation_file {args.val_annotation_file}"

multi_gpu = f"""RDZV_ID=$RANDOM
MASTER_NODE=$(srun --nodes=1 --ntasks=1 hostname)
srun --cpus-per-task {args.dataloader_num_workers} poetry run torchrun --nnodes={args.num_gpus} --nproc_per_node=1 --rdzv-id=$RDZV_ID --rdzv-backend=c10d --rdzv-endpoint=$MASTER_NODE \\
../../scripts/baselines/videomae/videomae_train.py \\""" # noqa: E501

single_gpu = "poetry run python ../../scripts/baselines/videomae/videomae_train.py \\"

job_name = "train-" + args.model.split("/")[1] + ("-verb" if args.verb else "-noun")
output_dir = os.path.join(args.output_dir, job_name)

script = rf"""#!/bin/bash
#SBATCH --partition={args.partition}
#SBATCH --time={args.time}
#SBATCH --job-name={job_name}
{email}
#SBATCH --account={args.account}
#SBATCH --ntasks={args.num_gpus}
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task={args.dataloader_num_workers}
#SBATCH --mem-per-gpu={args.mem_per_gpu}
#SBATCH --output=%x-%j.log
module load python/3.10.4 cuda
{transformers_cache}
export WANDB_PROJECT={args.wandb_project}
{single_gpu if args.num_gpus < 2 else multi_gpu}
--model_name_or_path {args.model} \
--num_frames 8 \
--train_frames_dir {args.train_frames_dir} \
{train_annotation_file} \
--val_frames_dir {args.val_frames_dir} \
{val_annotation_file} \
--verb {args.verb} \
--output_dir {output_dir} \
--num_train_epochs {args.num_train_epochs} \
--learning_rate 5e-5 \
--warmup_ratio 0.1 \
--per_device_train_batch_size {args.per_device_train_batch_size} \
--gradient_accumulation_steps {gradient_accumulation_steps} \
--ddp_find_unused_parameters False \
--per_device_eval_batch_size {args.per_device_eval_batch_size} \
--dataloader_num_workers {args.dataloader_num_workers} \
--bf16 True \
{deepspeed} \
--evaluation_strategy epoch \
--save_strategy epoch \
--save_total_limit 3 \
--logging_steps 10 \
--metric_for_best_model f1 \
--report_to wandb \
--run_name {job_name} \
{resume_from_checkpoint}
""" # noqa: E501
print(script)
if not args.dry_run:
    subprocess.run(["sbatch"], input=script, text=True)
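
As a sanity check on the batch-size bookkeeping above, a small sketch with illustrative values (not taken from the commit):

train_batch_size = 64            # --train_batch_size, the requested global batch
per_device_train_batch_size = 8  # --per_device_train_batch_size
num_gpus = 2                     # --num_gpus

gradient_accumulation_steps = train_batch_size // per_device_train_batch_size // num_gpus
# 64 // 8 // 2 == 4, and 8 * 2 * 4 recovers the global batch of 64
assert per_device_train_batch_size * num_gpus * gradient_accumulation_steps == train_batch_size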

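Similarly, a minimal sketch showing that the value passed to --deepspeed is ordinary JSON wrapped in URL-safe base64, so it round-trips cleanly; how the training side decodes and consumes it is up to the rest of the codebase and is not shown here:

import base64
import json

encoded = base64.urlsafe_b64encode(b'{"zero_optimization": {"stage": 2}}').decode()
decoded = json.loads(base64.urlsafe_b64decode(encoded))
assert decoded["zero_optimization"]["stage"] == 2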