
Commit f9b6d2c

Enable fine tuning on HPU
1 parent 0338b35 commit f9b6d2c

7 files changed: +158 −18 lines

src/instructlab/training/accelerator.py

Lines changed: 16 additions & 3 deletions
@@ -3,7 +3,12 @@
 from typing import Callable, Optional
 
 # Third Party
-from accelerate import Accelerator as TransformersAccel
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
+else:
+    from accelerate import Accelerator as TransformersAccel
+
 from torch.utils.data import DataLoader
 from transformers import get_scheduler
 import torch
@@ -124,7 +129,11 @@ def get_fsdp_config(self):
         from functools import partial
 
         # Third Party
-        from accelerate.utils import FullyShardedDataParallelPlugin
+        if is_torch_hpu_available():
+            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
+        else:
+            from accelerate.utils import FullyShardedDataParallelPlugin
+
         from peft.utils.other import fsdp_auto_wrap_policy
         from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -152,14 +161,18 @@ def get_fsdp_config(self):
         prefetch_policy = (
             BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
         )
-        fsdp_plugin = FullyShardedDataParallelPlugin(
+        fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)(
             auto_wrap_policy=wrap_policy,
             limit_all_gathers=True,
             backward_prefetch=prefetch_policy,
             sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
             cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
         )
 
+        if is_torch_hpu_available():
+            fsdp_plugin.use_orig_params = True
+            fsdp_plugin.sync_module_states = True
+
         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
         if self.model.lora_config is not None:

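Note: the one-line conditional instantiation in get_fsdp_config above is hard to read in diff form. Below is a minimal sketch of the same selection logic written out explicitly; build_fsdp_plugin and its parameters are hypothetical stand-ins for the method-local names in the diff (wrap_policy, prefetch_policy, and the strategy/offload settings) and are not part of this commit.

from instructlab.training.hpu_utils import is_torch_hpu_available

if is_torch_hpu_available():
    from optimum.habana.accelerate.utils import (
        GaudiFullyShardedDataParallelPlugin as _PluginCls,
    )
else:
    from accelerate.utils import FullyShardedDataParallelPlugin as _PluginCls


def build_fsdp_plugin(wrap_policy, prefetch_policy, sharding_strategy, cpu_offload):
    # Select the plugin class once, then instantiate it: this is what the
    # single-line conditional-call expression in the diff does.
    plugin = _PluginCls(
        auto_wrap_policy=wrap_policy,
        limit_all_gathers=True,
        backward_prefetch=prefetch_policy,
        sharding_strategy=sharding_strategy,
        cpu_offload=cpu_offload,
    )
    # Gaudi-only flags from the diff; the surrounding code may still turn
    # use_orig_params back off when LoRA is active.
    if is_torch_hpu_available():
        plugin.use_orig_params = True
        plugin.sync_module_states = True
    return plugin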
src/instructlab/training/hpu_utils.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import torch
+from functools import lru_cache
+
+
+@lru_cache(maxsize=None)
+def is_torch_hpu_available() -> bool:
+    try:
+        import habana_frameworks.torch.core  # noqa: F401
+    except ImportError:
+        return False
+    return True
+
+
+def simple_bucket(length):
+    """
+    This bucket algorithm relies only on the given number instead of slicing
+    the known (min, max) range, for several reasons:
+        1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
+           (min, max) sequence length of each rank will be much smaller than the
+           (min, max) sequence length of the dataset. Bucketing on the
+           (min, max) sequence length of the dataset is not practical.
+        2) The (min, max) sequence length of a given rank is unknown until
+           finishing 1 epoch, since the packing is done on the fly.
+        3) Due to the shuffling, the (min, max) sequence length of a given rank
+           may vary between ranks. Once the (min, max) sequence length of a
+           given rank changes, the bucketing also needs adjustment.
+
+    This bucket algorithm is based on the most significant set bit of the input number.
+    It first checks which bit is the most significant set bit, assuming it is bit "S",
+    and then slices the range [2 ** S, 2 ** (S+1)] into buckets of the same size.
+    By default the range is divided into 16 buckets, so the bucket size will be
+    2 ** (S - 4).
+    For example, 0b10001 will be padded to 0b10010.
+    This approach limits the overhead of bucketing (at most 1/16 of the input
+    number) and also prevents recompilation due to a too-small bucket size.
+    """
+    l = length
+    msb = 0
+    while l > 0:
+        msb += 1
+        l = l // 2
+
+    align = (1 << (msb - 4)) if msb >= 4 else 1
+
+    return (length + align - 1) // align * align
+
+
+def bucket(length):
+    return simple_bucket(length)

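A quick worked example of simple_bucket/bucket as defined above; the values are computed by hand from the code, and the snippet itself is illustrative rather than part of the commit.

from instructlab.training.hpu_utils import bucket

# 17 == 0b10001 has 5 significant bits, so the bucket size is 2 ** (5 - 4) == 2
# and 17 rounds up to 18 (0b10010), matching the docstring example.
assert bucket(17) == 18

# 1000 has 10 significant bits, so the bucket size is 2 ** (10 - 4) == 64
# and 1000 rounds up to 1024.
assert bucket(1000) == 1024

# Lengths below 16 are left untouched (the bucket size is 1).
assert bucket(7) == 7

Because the bucket size is 2 ** (S - 4), the padding overhead is bounded by roughly 1/16 of the sequence length, which is the trade-off the docstring describes.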
src/instructlab/training/main_ds.py

Lines changed: 45 additions & 12 deletions
@@ -33,6 +33,14 @@
     UserWarning,
 )
 
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -139,10 +147,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
@@ -179,8 +196,14 @@ def train(
             elapsed_time = time.time() - start
             overall_throughput = args.samples_per_gpu * world_size / elapsed_time
             current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-            cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-            cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+            if is_torch_hpu_available():
+                mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                malloc_retries = 0
+            else:
+                mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
             global_grad_norm = (
                 model.get_global_grad_norm()
                 if hasattr(model, "get_global_grad_norm")
@@ -202,8 +225,8 @@ def train(
                 "rank": torch.distributed.get_rank(),
                 "overall_throughput": overall_throughput,
                 "lr": current_lr,
-                "cuda_mem_allocated": cuda_mem_allocated,
-                "cuda_malloc_retries": cuda_malloc_retries,
+                ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                 "num_loss_counted_tokens": int(num_loss_counted_tokens),
                 "num_tokens_rank0": int(total_length),
                 "batch_size": int(micro_batch_size),
@@ -236,7 +259,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -314,17 +340,24 @@ def main(args):
     args.model_type = model_conf.model_type
 
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])
 
     timeout = _get_collective_timeout()
-    if timeout is not None:
-        torch.distributed.init_process_group(timeout=timeout)
-    else:
-        torch.distributed.init_process_group()
+    backend = "hccl" if is_torch_hpu_available() else None
+    torch.distributed.init_process_group(backend=backend, timeout=timeout)
 
     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()
 

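The train()/main() changes above repeat the same device-selection idea in several places (batch tensors, the HCCL backend, the all-reduce probe tensor). A small standalone sketch of that pattern follows; the helper name get_device is hypothetical and not introduced by this commit.

import torch

from instructlab.training.hpu_utils import is_torch_hpu_available


def get_device(local_rank: int) -> torch.device:
    # On Gaudi the device string is simply "hpu"; on NVIDIA GPUs the tensor is
    # placed on the CUDA device matching the process's local rank, mirroring
    # the `.to('hpu' if is_torch_hpu_available() else local_rank)` calls above.
    if is_torch_hpu_available():
        return torch.device("hpu")
    return torch.device("cuda", local_rank)


# Usage in the spirit of the diff: move every tensor in a batch to that device.
# batch = {k: v.to(get_device(local_rank)) for k, v in batch.items()}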
src/instructlab/training/model.py

Lines changed: 26 additions & 1 deletion
@@ -34,6 +34,8 @@
 import torch
 
 # First Party
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
 from instructlab.training.config import (  # Adjust this import if needed
     DistributedBackend,
     Optimizer,
@@ -78,6 +80,14 @@ def __init__(
 
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
+
+        if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
+            torch._dynamo.config.cache_size_limit = 10 * 1000
+            torch._dynamo.config.accumulated_cache_size_limit = 20 * 1000
+            self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
+            for layer in self.model.model.layers:
+                layer.compile(backend="hpu_backend", dynamic=False)
+
         self.reconcile_tokenizer()
         if self.lora_config:
             self.model = self.prepare_peft_model(
@@ -264,7 +274,11 @@ def _is_causal_lm_model(self) -> bool:
             bool: True if the model is a causal language model, False otherwise.
         """
         # Third Party
-        return "ForCausalLM" in self.model.__class__.__name__
+        if not is_torch_hpu_available():
+            class_name = self.model.__class__.__name__
+        else:
+            class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__
+        return "ForCausalLM" in class_name
 
     def reconcile_tokenizer(self):
         if len(self.tokenizer) > self.model.config.vocab_size:
@@ -320,6 +334,17 @@ def reconcile_tokenizer(self):
         ):
             self.model.config.eos_token_id = self.tokenizer.eos_token_id
 
+        if is_torch_hpu_available():
+            model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model
+            class_name = model.__class__.__name__
+
+            replace_no_split_modules = {
+                'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
+            }
+
+            if class_name in replace_no_split_modules:
+                model._no_split_modules = replace_no_split_modules[class_name]
+
         if not self._is_causal_lm_model():
             raise ValueError(
                 f"Model must be a causal language model, got {type(self.model)}"

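When HPU_ENABLE_TORCH_COMPILE is set, torch.compile wraps the model in an OptimizedModule, which is why _is_causal_lm_model and reconcile_tokenizer above look through _orig_mod. A minimal sketch of that unwrapping; the helper name unwrap_compiled is hypothetical and the commit itself does not define it.

import torch


def unwrap_compiled(model: torch.nn.Module) -> torch.nn.Module:
    # torch.compile returns a torch._dynamo OptimizedModule that keeps the
    # original module on `_orig_mod`; anything else is returned unchanged.
    if model.__class__.__name__ == "OptimizedModule":
        return model._orig_mod
    return model


# Usage matching the diff:
#   class_name = unwrap_compiled(self.model).__class__.__name__
#   is_causal = "ForCausalLM" in class_name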
src/instructlab/training/multipack_sampler.py

Lines changed: 8 additions & 1 deletion
@@ -34,6 +34,8 @@
 import torch
 import torch.distributed as dist
 
+from instructlab.training.hpu_utils import is_torch_hpu_available, bucket
+
 
 def find_max_pack_len_with_padding(
     dataset,
@@ -68,9 +70,14 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu):
 
         The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches.
         """
+        lengths = dataset.get_lengths()
+        if is_torch_hpu_available():
+            bucket_v = np.vectorize(bucket)
+            lengths = bucket_v(lengths)
+
         sampler = MultipackDistributedBatchSampler(
             batch_max_length=num_tokens_per_gpu,
-            lengths=dataset.get_lengths(),
+            lengths=lengths,
             num_replicas=torch.distributed.get_world_size(),
             rank=torch.distributed.get_rank(),
             seed=seed,

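Bucketing the per-sample lengths before they reach MultipackDistributedBatchSampler means padded batches take only a small set of shapes, which limits Gaudi graph recompilation. A tiny illustration of the np.vectorize(bucket) step, using made-up sample lengths:

import numpy as np

from instructlab.training.hpu_utils import bucket

lengths = np.array([137, 512, 1000])      # hypothetical per-sample token counts
bucketed = np.vectorize(bucket)(lengths)  # -> array([ 144,  512, 1024])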
src/instructlab/training/token_dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
1414
from instructlab.training.utils import log_rank_0, make_collate_fn
1515

16+
from instructlab.training.hpu_utils import is_torch_hpu_available, bucket
1617

1718
class TokenDataset(Dataset):
1819
def __init__(self, data_path):
@@ -109,6 +110,10 @@ def setup_dataloader(
109110

110111
lengths = dataset.get_lengths()
111112
if sampler == "multipack":
113+
if is_torch_hpu_available():
114+
bucket_v = np.vectorize(bucket)
115+
lengths = bucket_v(lengths)
116+
112117
sampler = MultipackDistributedBatchSampler(
113118
batch_max_length=packing_max_batch_len,
114119
lengths=lengths,

src/instructlab/training/utils.py

Lines changed: 9 additions & 1 deletion
@@ -51,6 +51,7 @@
     TrainingArgs,
 )
 from instructlab.training.model import Model
+from instructlab.training.hpu_utils import is_torch_hpu_available, bucket
 
 logger = logging.getLogger("instructlab.training")
 
@@ -275,6 +276,9 @@ def pad_collate_fn(batch):
     lens = np.array([len(item["input_ids"]) for item in batch])
     max_len = max(lens)
 
+    if is_torch_hpu_available():
+        max_len = bucket(max_len)
+
     input_ids = torch.stack(
         [
             F.pad(
@@ -386,6 +390,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **(_deprecated_arguments if is_torch_hpu_available() else {}),
     )
 
     return_dict = isinstance(output, dict)
@@ -794,7 +799,10 @@ def set_random_seed(seed):
     random.seed(seed)
    np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if is_torch_hpu_available():
+        torch.hpu.manual_seed_all(seed)
+    else:
+        torch.cuda.manual_seed_all(seed)
 
 
 def save_checkpoint(

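In pad_collate_fn, bucketing max_len means each batch is right-padded to a bucketed length rather than to the exact longest sample. Below is a simplified, hypothetical stand-in for that padding step, assuming F.pad-based right padding as in the surrounding code; pad_to_bucketed_len is not a function from this commit.

import torch
import torch.nn.functional as F

from instructlab.training.hpu_utils import bucket, is_torch_hpu_available


def pad_to_bucketed_len(input_ids_list, pad_token_id=0):
    # Simplified sketch of the padding done in pad_collate_fn.
    max_len = max(len(ids) for ids in input_ids_list)
    if is_torch_hpu_available():
        # e.g. a longest sample of 137 tokens pads to 144, so batches reuse shapes
        max_len = bucket(max_len)
    return torch.stack(
        [
            F.pad(torch.tensor(ids), (0, max_len - len(ids)), value=pad_token_id)
            for ids in input_ids_list
        ]
    )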