# train.py (forked from modal-labs/llm-finetuning)
from modal import gpu, Mount
from common import stub, N_GPUS, GPU_MEM, BASE_MODELS, VOLUME_CONFIG
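# common.py is not shown on this page. For orientation only, the names imported
# above could be defined roughly along these lines; the model alias, volume
# names, and constant values below are illustrative assumptions, not the real
# common.py:
#
#   from modal import Stub, Volume
#
#   N_GPUS = 2    # GPUs per training container (assumption)
#   GPU_MEM = 80  # GiB of memory per A100 (assumption)
#   BASE_MODELS = {"chat7": "meta-llama/Llama-2-7b-chat-hf"}  # alias -> HF repo id
#
#   stub = Stub("llama-finetuning")
#   stub.pretrained_volume = Volume.persisted("example-pretrained-vol")
#   stub.results_volume = Volume.persisted("example-results-vol")
#   VOLUME_CONFIG = {
#       "/pretrained": stub.pretrained_volume,
#       "/results": stub.results_volume,
#   }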
@stub.function(
    volumes=VOLUME_CONFIG,
    memory=1024 * 100,
    timeout=3600 * 4,
)
def download(model_name: str):
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    try:
        # With local_files_only=True, snapshot_download raises FileNotFoundError
        # when the model is not already cached, so this doubles as an existence check.
        snapshot_download(model_name, local_files_only=True)
        print(f"Volume contains {model_name}.")
    except FileNotFoundError:
        print(f"Downloading {model_name} (no progress bar) ...")
        snapshot_download(model_name)
        move_cache()  # Migrate any old-style transformers cache entries to the current layout.

        print("Committing /pretrained directory (no progress bar) ...")
        stub.pretrained_volume.commit()
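# The download step can also be triggered on its own, without starting training.
# The command below is a sketch that assumes Modal's `modal run <file>::<function>`
# syntax and an example model id; neither is shown in this file:
#
#   modal run train.py::download --model-name meta-llama/Llama-2-7b-chat-hf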
def library_entrypoint(config):
    # Thin wrapper so each worker process spawned by elastic_launch runs the
    # llama_recipes fine-tuning entrypoint with the same config dict.
    from llama_recipes.finetuning import main

    main(**config)
@stub.function(
    volumes=VOLUME_CONFIG,
    mounts=[
        # Mount the local ./datasets directory so the custom_dataset.file path
        # passed from the local entrypoint resolves under /root in the container.
        Mount.from_local_dir("./datasets", remote_path="/root"),
    ],
    gpu=gpu.A100(count=N_GPUS, memory=GPU_MEM),
    timeout=3600 * 12,
)
def train(train_kwargs):
    from torch.distributed.run import elastic_launch, parse_args, config_from_args

    # Single-node, multi-process launch. The trailing "" fills the training-script
    # positional argument, which goes unused because the entrypoint is handed to
    # elastic_launch directly.
    torch_args = parse_args(["--nnodes", "1", "--nproc_per_node", str(N_GPUS), ""])
    print(f"{torch_args=}\n{train_kwargs=}")

    elastic_launch(
        config=config_from_args(torch_args)[0],
        entrypoint=library_entrypoint,
    )(train_kwargs)

    print("Committing results volume (no progress bar) ...")
    stub.results_volume.commit()
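# Outside Modal, this single-node multi-GPU launch is what torchrun would do.
# A rough equivalent of the elastic_launch call above (flags abbreviated; this
# command is not taken from the repo) would be:
#
#   torchrun --nnodes 1 --nproc_per_node $N_GPUS \
#       -m llama_recipes.finetuning --enable_fsdp --use_peft --peft_method lora ...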
@stub.local_entrypoint()  # Runs locally to kick off the remote training job.
def main(
    dataset: str,
    base: str = "chat7",
    run_id: str = "",
    num_epochs: int = 10,
    batch_size: int = 16,
):
    print("Welcome to Modal Llama fine-tuning.")
    model_name = BASE_MODELS[base]
    print(f"Syncing base model {model_name} to volume.")
    download.remote(model_name)

    if not run_id:
        import secrets

        run_id = f"{base}-{secrets.token_hex(3)}"
    elif not run_id.startswith(base):
        run_id = f"{base}-{run_id}"

    print(f"Beginning run {run_id=}.")
    train.remote(
        {
            "model_name": BASE_MODELS[base],
            "output_dir": f"/results/{run_id}",
            "batch_size_training": batch_size,
            "lr": 3e-4,
            "num_epochs": num_epochs,
            "val_batch_size": 1,
            # --- Dataset options ---
            "dataset": "custom_dataset",
            "custom_dataset.file": dataset,
            # --- FSDP options ---
            "enable_fsdp": True,
            "low_cpu_fsdp": True,  # Optimization for FSDP model loading (RAM won't scale with num GPUs)
            "fsdp_config.use_fast_kernels": True,  # Only works when FSDP is on
            "fsdp_config.fsdp_activation_checkpointing": True,  # Activation checkpointing for FSDP
            "pure_bf16": True,
            # --- Required for 70B ---
            "fsdp_config.fsdp_cpu_offload": True,
            "fsdp_peft_cpu_offload_for_save": True,  # Experimental
            # --- PEFT options ---
            "use_peft": True,
            "peft_method": "lora",
            "lora_config.r": 8,
            "lora_config.lora_alpha": 16,
        }
    )

    print(f"Training completed {run_id=}.")
    print(
        f"Test: `modal run inference.py --base {base} --run-id {run_id} --prompt '...'`."
    )
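# Example kickoff from a local shell. The dataset filename below is hypothetical;
# it would presumably live under ./datasets, which the mount above places at /root:
#
#   modal run train.py --dataset my_dataset.py --base chat7 --num-epochs 3 --batch-size 8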