Commit d1cada3

Switch to "device" argument to enable hpu code path
Signed-off-by: Sergey Plotnikov <[email protected]>
1 parent f9b6d2c commit d1cada3

8 files changed, +119 -36 lines changed

docs/hpu.md

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# InstructLab Training on HPU

## HPU specific changes

The following changes are required to enable training on HPU:

| GPU | HPU |
|---|---|
| `from accelerate import Accelerator` | `from optimum.habana.accelerate import GaudiAccelerator` |
| `from accelerate.utils import FullyShardedDataParallelPlugin` | `from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin` |

It is also recommended to use the HPU-optimized versions of transformers:

```python
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
```
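With this commit the swap is keyed off a `device` string rather than runtime autodetection. A minimal sketch of that pattern (illustrative only; `get_accelerator_cls` is a name chosen here, not part of the training source):

```python
from typing import Optional


def get_accelerator_cls(device: Optional[str] = None):
    """Return the Accelerator class matching the requested device string."""
    if device == "hpu":
        # Gaudi-aware drop-in replacement from optimum-habana
        from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
    else:
        # Default CUDA path
        from accelerate import Accelerator as TransformersAccel
    return TransformersAccel
```

In the actual commit the same `if device == "hpu"` branch lives inside the accelerator setup code (see `src/instructlab/training/accelerator.py` below).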
## Bucketing

The multipack sampler produces a wide range of batches with different sample lengths and sample counts. Each new combination triggers a graph recompilation, and these recompilations take time and slow down training. To reduce the number of recompilations, the HPU implementation uses a bucketing approach in which the maximum sample length in a batch is aligned to a predefined value. It is similar to padding, except that the samples in a batch are padded not to the longest sample but to a slightly larger value.

![bucketing vs. padding](./hpu_pic/bucketing_vs_padding.png)

To compute the bucket size, we use the following algorithm:
- First, find the most significant bit (MSB) of the longest sample in the batch; call it S.
- Then slice the range [2 ** S, 2 ** (S+1)] into 16 buckets of equal size.
- Then use the top boundary of the smallest suitable bucket as the padding value.

This approach limits the bucketing overhead to 1/16th of the longest sample and significantly reduces the number of recompilations; a sketch of the rule follows.
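A minimal, self-contained sketch of this bucketing rule (illustrative only; `bucket_size` and `num_buckets` are names chosen here, not the sampler's actual API):

```python
import math


def bucket_size(longest_sample_len: int, num_buckets: int = 16) -> int:
    """Round a batch's longest sample length up to the nearest bucket boundary."""
    # S = index of the most significant bit of the longest sample length.
    s = longest_sample_len.bit_length() - 1
    low, high = 2**s, 2 ** (s + 1)
    # Slice [2**S, 2**(S+1)] into equally sized buckets (16 by default).
    step = max((high - low) // num_buckets, 1)
    # Pad up to the top boundary of the smallest bucket that fits the sample.
    return low + math.ceil((longest_sample_len - low) / step) * step


# Example: a longest sample of 1000 tokens falls in [512, 1024], which is split
# into 16 buckets of width 32; the padded length becomes 1024.
assert bucket_size(1000) == 1024
```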
## How to run

To run training, make the following changes to the config file:

```yaml
train:
  device: hpu
  distributed_backend: fsdp
  fsdp_cpu_offload_optimizer: false
  is_padding_free: true
  pipeline: accelerated
  disable_flash_attn: true
```

And use this command line:

```bash
ilab --config=./config.yaml model train --pipeline accelerated --data-path ./data.jsonl
```

docs/hpu_pic/bucketing_vs_padding.png

30.9 KB

src/instructlab/training/accelerator.py

Lines changed: 26 additions & 18 deletions
@@ -3,11 +3,6 @@
 from typing import Callable, Optional

 # Third Party
-from instructlab.training.hpu_utils import is_torch_hpu_available
-if is_torch_hpu_available():
-    from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
-else:
-    from accelerate import Accelerator as TransformersAccel

 from torch.utils.data import DataLoader
 from transformers import get_scheduler
@@ -37,6 +32,7 @@ def __init__(
         deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
         deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
         fsdp_cpu_offload_params: Optional[bool] = False,
+        device: Optional[str] = None,
     ):
         self.samples_per_gpu = samples_per_gpu
         self.save_samples = save_samples
@@ -53,6 +49,7 @@ def __init__(
             deepspeed_cpu_offload_optimizer_ratio
         )
         self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
+        self.device_str = device

         if self.distributed_framework == DistributedBackend.DEEPSPEED:
             # Standard
@@ -74,6 +71,12 @@ def __init__(
                 "fsdp_plugin": self.get_fsdp_config(),
                 "mixed_precision": "bf16",
             }
+
+        if device == "hpu":
+            from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
+        else:
+            from accelerate import Accelerator as TransformersAccel
+
         self.accelerator = TransformersAccel(
             **accel_args,
         )
@@ -129,11 +132,6 @@ def get_fsdp_config(self):
         from functools import partial

         # Third Party
-        if is_torch_hpu_available():
-            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
-        else:
-            from accelerate.utils import FullyShardedDataParallelPlugin
-
         from peft.utils.other import fsdp_auto_wrap_policy
         from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -161,17 +159,27 @@ def get_fsdp_config(self):
         prefetch_policy = (
             BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
         )
-        fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)(
-            auto_wrap_policy=wrap_policy,
-            limit_all_gathers=True,
-            backward_prefetch=prefetch_policy,
-            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-        )

-        if is_torch_hpu_available():
+        if self.device_str == "hpu":
+            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
+            fsdp_plugin = GaudiFullyShardedDataParallelPlugin(
+                auto_wrap_policy=wrap_policy,
+                limit_all_gathers=True,
+                backward_prefetch=prefetch_policy,
+                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+            )
             fsdp_plugin.use_orig_params=True
             fsdp_plugin.sync_module_states=True
+        else:
+            from accelerate.utils import FullyShardedDataParallelPlugin
+            fsdp_plugin = FullyShardedDataParallelPlugin(
+                auto_wrap_policy=wrap_policy,
+                limit_all_gathers=True,
+                backward_prefetch=prefetch_policy,
+                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+            )

         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts

src/instructlab/training/config.py

Lines changed: 2 additions & 0 deletions
@@ -245,3 +245,5 @@ class TrainingArgs(BaseModel):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
         default="INFO"
     )
+
+    device: Optional[str] = None
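For illustration, the new field behaves like any other optional pydantic field; a trimmed sketch, not the full `TrainingArgs` definition:

```python
from typing import Optional

from pydantic import BaseModel


class TrainingArgs(BaseModel):
    # ...existing training fields trimmed for brevity...
    device: Optional[str] = None  # e.g. "hpu"; None keeps the default CUDA path


print(TrainingArgs(device="hpu").device)  # -> hpu
```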

src/instructlab/training/main_ds.py

Lines changed: 26 additions & 10 deletions
@@ -131,7 +131,7 @@ def train(
         if local_rank == 0:
             inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}")

-        # blast through the batches in the train loader up to the last step within the epoch.
+        # blast through the batches in the train loader up to the last step within the epoch.
         for batch in accelerator.train_loader:
             if global_step <= args.last_step:
                 # in the case of resuming, last_step > 0
@@ -147,10 +147,10 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+                    batch[k] = batch[k].to('hpu' if args.device == "hpu" else local_rank)

             hpu_args = {}
-            if is_torch_hpu_available():
+            if args.device == "hpu":
                 hpu_args = {
                     "use_flash_attention":True,
                     "lazy_mode":False,
@@ -197,7 +197,7 @@ def train(
             overall_throughput = args.samples_per_gpu * world_size / elapsed_time
             current_lr = accelerator.lr_scheduler.get_last_lr()[0]

-            if is_torch_hpu_available():
+            if args.device == "hpu":
                 mem_allocated = torch.hpu.memory_allocated() / (1024**3)
                 malloc_retries = 0
             else:
@@ -225,8 +225,8 @@ def train(
                     "rank": torch.distributed.get_rank(),
                     "overall_throughput": overall_throughput,
                     "lr": current_lr,
-                    ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
-                    ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
+                    ("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
+                    ("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
                     "num_loss_counted_tokens": int(num_loss_counted_tokens),
                     "num_tokens_rank0": int(total_length),
                     "batch_size": int(micro_batch_size),
@@ -260,7 +260,7 @@ def train(
             if local_rank == 0:
                 inner_pb.update(1)

-            if not is_torch_hpu_available():
+            if args.device != "hpu":
                 torch.cuda.empty_cache()

         if args.checkpoint_at_epoch:
@@ -340,20 +340,20 @@ def main(args):
     args.model_type = model_conf.model_type

     #### distributed init #####
-    if is_torch_hpu_available():
+    if args.device == "hpu":
         torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
     else:
         torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

     args.local_rank = int(os.environ["LOCAL_RANK"])

     timeout = _get_collective_timeout()
-    backend = "hccl" if is_torch_hpu_available() else None
+    backend = "hccl" if args.device == "hpu" else None
     torch.distributed.init_process_group(backend=backend, timeout=timeout)

     args.global_rank = torch.distributed.get_rank()

-    if is_torch_hpu_available():
+    if args.device == "hpu":
         tensor = torch.ByteTensor([False]).to('hpu')
     else:
         tensor = torch.ByteTensor([False]).cuda()
@@ -407,6 +407,7 @@ def main(args):
         flash_enabled=flash_enabled,
         noise_alpha=args.NEFTune_alpha,
         lora_quant_bits=args.lora_quant_bits,
+        device=args.device,
     )

     args.base_model_args = m.base_model_args
@@ -446,6 +447,7 @@ def main(args):
         samples_per_gpu=args.samples_per_gpu,
         sampler=args.sampler,
         seed=args.seed,
+        device=args.device,
     )
     if len(train_loader) == 0:
         # this happens sometimes when we have more GPUs than data to process. In this case
@@ -466,6 +468,7 @@ def main(args):
         samples_per_gpu=args.samples_per_gpu,
         sampler=args.sampler,
         seed=args.seed,
+        device=args.device,
     )

     if args.local_rank == 0:
@@ -497,6 +500,7 @@ def main(args):
         deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
         fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
         save_samples=args.save_samples,
+        device=args.device,
     )
     # optimizer needs model that has been prepared by accelerator
     # and then accelerator needs to be prepared AGAIN once optimizer is initialized
@@ -679,6 +683,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.keep_last_checkpoint_only:
         command.append("--keep_last_checkpoint_only")

+    command.append(
+        f"--device={train_args.device}"
+    )
+
     logger.info("Running training command as subprocess: %s", " ".join(command))
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
@@ -876,6 +884,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="PyTorch device to use.",
+    )
+
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
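Tying the `main_ds.py` changes together: the flag registered above is what the rest of the file branches on. A small sketch of that flow, trimmed to the relevant option (the example value is an assumption, not a tested configuration):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default=None,
    help="PyTorch device to use.",
)
args = parser.parse_args(["--device=hpu"])

# Downstream code keys off the string instead of probing for HPU support,
# e.g. when selecting the distributed backend (see the diff above).
backend = "hccl" if args.device == "hpu" else None
print(args.device, backend)  # hpu hccl
```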

src/instructlab/training/model.py

Lines changed: 11 additions & 4 deletions
@@ -34,7 +34,6 @@
 import torch

 # First Party
-from instructlab.training.hpu_utils import is_torch_hpu_available

 from instructlab.training.config import ( # Adjust this import if needed
     DistributedBackend,
@@ -52,11 +51,13 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,
     ):
         self.lora_config = lora_config
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self.device = device
         bnb_config = None
         if lora_config and lora_config.r > 0 and lora_quant_bits == 4:
             # Third Party
@@ -81,7 +82,7 @@ def __init__(
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""

-        if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
+        if self.device == "hpu" and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
             torch._dynamo.config.cache_size_limit = 10*1000
             torch._dynamo.config.accumulated_cache_size_limit = 20*1000
             self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
@@ -274,7 +275,7 @@ def _is_causal_lm_model(self) -> bool:
             bool: True if the model is a causal language model, False otherwise.
         """
         # Third Party
-        if not is_torch_hpu_available():
+        if self.device != "hpu":
             class_name = self.model.__class__.__name__
         else:
             class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__
@@ -334,7 +335,7 @@ def reconcile_tokenizer(self):
         ):
             self.model.config.eos_token_id = self.tokenizer.eos_token_id

-        if is_torch_hpu_available():
+        if self.device == "hpu":
             model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model
             class_name = model.__class__.__name__

@@ -410,6 +411,7 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,
     ):
         super().__init__(
             model_path=model_path,
@@ -419,6 +421,7 @@ def __init__(
             flash_enabled=flash_enabled,
             lora_config=lora_config,
             lora_quant_bits=lora_quant_bits,
+            device=device,
         )
         try:
             # Third Party
@@ -451,6 +454,7 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,
     ):
         super().__init__(
             model_path=model_path,
@@ -460,6 +464,7 @@ def __init__(
             flash_enabled=flash_enabled,
             lora_config=lora_config,
             lora_quant_bits=lora_quant_bits,
+            device=device,
         )
         # Third Party
         from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
@@ -494,6 +499,7 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,
     ):
         super().__init__(
             model_path=model_path,
@@ -503,6 +509,7 @@ def __init__(
             flash_enabled=flash_enabled,
             lora_config=lora_config,
             lora_quant_bits=lora_quant_bits,
+            device=device,
         )
         # Third Party
         from transformers import AutoModelForCausalLM
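One detail worth noting in the HPU branches above: when `torch.compile` is used, the model is wrapped in an `OptimizedModule`, so class-name checks look through `_orig_mod`. A small helper-style sketch of that unwrapping (the helper name is illustrative, not part of the source):

```python
import torch


def underlying_class_name(model: torch.nn.Module) -> str:
    """Return the class name of the model beneath a torch.compile wrapper, if any."""
    if model.__class__.__name__ == "OptimizedModule":
        return model._orig_mod.__class__.__name__
    return model.__class__.__name__
```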
