
Commit b43cddc

ljx adapt npu

Committed Feb 24, 2025 · 1 parent e692eaa

File tree: 5 files changed, +60 -36 lines

 

internlm/core/context/parallel_context.py

+5 -2

@@ -11,7 +11,7 @@
 import torch
 import torch.distributed as dist

-from internlm.accelerator import get_accelerator
+from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.utils.common import SingletonMeta
 from internlm.utils.config import Config
 from internlm.utils.logger import get_logger
@@ -309,7 +309,10 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str,
             use_cpu (bool): whether to set up cpu process group.
         """
         # initialize the default process group
-        init_method = f"tcp://[{host}]:{port}"
+        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+            init_method = f"tcp://[{host}]:{port}"
+        else:
+            init_method = f"tcp://{host}:{port}"
         dist.init_process_group(
             rank=rank,
             world_size=world_size,
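The bracketed tcp://[host]:port form is the URL syntax for IPv6 host literals; the change keeps it on GPU and falls back to plain tcp://host:port elsewhere, presumably because the NPU/HCCL path does not accept the bracketed form. A minimal standalone sketch of the same dispatch, reusing the imports from the diff above (the helper name build_init_method is illustrative, not part of the commit):

from internlm.accelerator import AcceleratorType, get_accelerator

internlm_accelerator = get_accelerator()

def build_init_method(host: str, port: int) -> str:
    # Hypothetical helper: GPU keeps the bracketed (IPv6-friendly) URL form,
    # other backends (e.g. Ascend NPU) get the plain host:port form.
    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
        return f"tcp://[{host}]:{port}"
    return f"tcp://{host}:{port}"

print(build_init_method("127.0.0.1", 29500))  # tcp://[127.0.0.1]:29500 on GPU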

internlm/core/engine.py

+40 -29

@@ -3,23 +3,31 @@

 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

-from contextlib import nullcontext
 from typing import List, Optional

 import torch
-import transformer_engine.pytorch as te
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim.lr_scheduler import _LRScheduler
-from transformer_engine.common.recipe import DelayedScaling, Format

+from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.gradient_handler import BaseGradientHandler
 from internlm.solver.optimizer import BaseOptimizer
 from internlm.solver.schedulers import Beta2Scheduler
 from internlm.utils.common import get_batch_size, move_to_device

+try:
+    from contextlib import nullcontext
+
+    import transformer_engine.pytorch as te
+    from transformer_engine.common.recipe import DelayedScaling, Format
+except ImportError:
+    pass
+
+internlm_accelerator = get_accelerator()
+

 class Engine:
     """
@@ -83,27 +91,28 @@ def __init__(
         # build gradient handler
         self._gradient_handlers = gradient_handlers if gradient_handlers else []

-        # FP8 GEMM
-        fp8_cfg = gpc.config.get("fp8", None)
-        self.use_fp8 = fp8_cfg is not None
-        self.fp8_recipe = None
-        self.fp8_group = None
-        if self.use_fp8:
-            self.fp8_group = gpc.get_group(ParallelMode.GLOBAL)
-            if fp8_cfg.format == "e4m3":
-                fp8_format = Format.E4M3
-            elif fp8_cfg.format == "hybrid":
-                fp8_format = Format.HYBRID
-            else:
-                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
-            self.fp8_recipe = DelayedScaling(
-                margin=fp8_cfg.margin,
-                interval=fp8_cfg.interval,
-                fp8_format=fp8_format,
-                amax_history_len=fp8_cfg.amax_history_len,
-                amax_compute_algo=fp8_cfg.amax_compute_algo,
-                override_linear_precision=(False, False, not fp8_cfg.fp8_wgrad),
-            )
+        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+            # FP8 GEMM
+            fp8_cfg = gpc.config.get("fp8", None)
+            self.use_fp8 = fp8_cfg is not None
+            self.fp8_recipe = None
+            self.fp8_group = None
+            if self.use_fp8:
+                self.fp8_group = gpc.get_group(ParallelMode.GLOBAL)
+                if fp8_cfg.format == "e4m3":
+                    fp8_format = Format.E4M3
+                elif fp8_cfg.format == "hybrid":
+                    fp8_format = Format.HYBRID
+                else:
+                    raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
+                self.fp8_recipe = DelayedScaling(
+                    margin=fp8_cfg.margin,
+                    interval=fp8_cfg.interval,
+                    fp8_format=fp8_format,
+                    amax_history_len=fp8_cfg.amax_history_len,
+                    amax_compute_algo=fp8_cfg.amax_compute_algo,
+                    override_linear_precision=(False, False, not fp8_cfg.fp8_wgrad),
+                )

     @property
     def model(self):
@@ -193,11 +202,13 @@ def __call__(self, *args, **kwargs):
         Returns:
             torch.Tensor: The output of the model.
         """
-        with te.fp8_autocast(
-            enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group
-        ) if self.use_fp8 else nullcontext():
-            output = self.model(*args, **kwargs)
-        return output
+        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+            with te.fp8_autocast(
+                enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group
+            ) if self.use_fp8 else nullcontext():
+                output = self.model(*args, **kwargs)
+            return output
+        return self.model(*args, **kwargs)

     def load_batch(self, data_iter, to_gpu=True):
         """

internlm/core/trainer_builder.py

+4 -1

@@ -8,6 +8,7 @@
 import torch.distributed as dist
 from torch.utils.data import DataLoader

+from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.checkpoint.checkpoint_manager import CheckpointManager
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -48,6 +49,7 @@

 # global llm logger
 logger = logging.getLogger(__file__)
+internlm_accelerator = get_accelerator()


 class TrainerBuilder(Trainer):
@@ -114,7 +116,8 @@ def __init__(
         criterion = self._initialize_criterion()

         # initialize cpu offload manager for selective checkpoint
-        initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))
+        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+            initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))

         # initialize train state
         train_state = get_train_state(train_dl)

internlm/model/model_implementations/builder.py

+10 -3

@@ -1,9 +1,9 @@
 from typing import List, Union

 import torch
-import transformer_engine.pytorch as te
 from torch import nn

+from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper
@@ -20,7 +20,13 @@
 from internlm.utils.logger import get_logger
 from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp

+try:
+    import transformer_engine.pytorch as te
+except ImportError:
+    pass
+
 logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()


 def simple_swap(model, device):
@@ -156,7 +162,8 @@ def traverse(module):
     else:
         traverse(model)

-    if gpc.config.get("fp8", None) is not None:
-        simple_swap(model, fsdp_init_method)
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+        if gpc.config.get("fp8", None) is not None:
+            simple_swap(model, fsdp_init_method)

     return model

setup.py

+1 -1

@@ -27,7 +27,7 @@ def get_requires() -> List[str]:

 extra_require = {
     "torch": ["torch>=2.1.0"],
-    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3"],
+    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "numpy==1.26.4", "scipy", "decorator"],
 }

 setup(
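With the extended extra, an NPU environment can be prepared with an editable install such as pip install -e ".[torch-npu]" (assuming installation from the repository root). The new entries pin numpy to 1.26.4 (presumably to stay on a NumPy version compatible with torch 2.1) and add scipy and decorator, presumably runtime prerequisites of torch-npu that pip would not otherwise bring in.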
