  3 |   3 | 
  4 |   4 | # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
  5 |   5 | 
  6 |     | -from contextlib import nullcontext
  7 |   6 | from typing import List, Optional
  8 |   7 | 
  9 |   8 | import torch
 10 |     | -import transformer_engine.pytorch as te
 11 |   9 | from torch.nn import Module
 12 |  10 | from torch.nn.modules.loss import _Loss
 13 |  11 | from torch.optim.lr_scheduler import _LRScheduler
 14 |     | -from transformer_engine.common.recipe import DelayedScaling, Format
 15 |  12 | 
    |  13 | +from internlm.accelerator import AcceleratorType, get_accelerator
 16 |  14 | from internlm.core.context import ParallelMode
 17 |  15 | from internlm.core.context import global_context as gpc
 18 |  16 | from internlm.core.gradient_handler import BaseGradientHandler
 19 |  17 | from internlm.solver.optimizer import BaseOptimizer
 20 |  18 | from internlm.solver.schedulers import Beta2Scheduler
 21 |  19 | from internlm.utils.common import get_batch_size, move_to_device
 22 |  20 | 
    |  21 | +try:
    |  22 | +    from contextlib import nullcontext
    |  23 | +
    |  24 | +    import transformer_engine.pytorch as te
    |  25 | +    from transformer_engine.common.recipe import DelayedScaling, Format
    |  26 | +except ImportError:
    |  27 | +    pass
    |  28 | +
    |  29 | +internlm_accelerator = get_accelerator()
    |  30 | +
 23 |  31 | 
 24 |  32 | class Engine:
 25 |  33 |     """
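
Note on the import hunk above: because the except branch is a bare pass, a machine without transformer_engine is left with te, DelayedScaling, and Format unbound rather than set to None, which is only safe because every later use is gated behind the GPU backend check. A minimal sketch of the same optional-dependency pattern with an explicit flag (the name HAS_TE is ours, not part of this patch):

    # Sketch only: optional import with an availability flag instead of a bare pass.
    try:
        import transformer_engine.pytorch as te  # noqa: F401
        HAS_TE = True
    except ImportError:
        HAS_TE = False
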
@@ -83,27 +91,28 @@ def __init__(
 83 |  91 |         # build gradient handler
 84 |  92 |         self._gradient_handlers = gradient_handlers if gradient_handlers else []
 85 |  93 | 
 86 |     | -        # FP8 GEMM
 87 |     | -        fp8_cfg = gpc.config.get("fp8", None)
 88 |     | -        self.use_fp8 = fp8_cfg is not None
 89 |     | -        self.fp8_recipe = None
 90 |     | -        self.fp8_group = None
 91 |     | -        if self.use_fp8:
 92 |     | -            self.fp8_group = gpc.get_group(ParallelMode.GLOBAL)
 93 |     | -            if fp8_cfg.format == "e4m3":
 94 |     | -                fp8_format = Format.E4M3
 95 |     | -            elif fp8_cfg.format == "hybrid":
 96 |     | -                fp8_format = Format.HYBRID
 97 |     | -            else:
 98 |     | -                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
 99 |     | -            self.fp8_recipe = DelayedScaling(
100 |     | -                margin=fp8_cfg.margin,
101 |     | -                interval=fp8_cfg.interval,
102 |     | -                fp8_format=fp8_format,
103 |     | -                amax_history_len=fp8_cfg.amax_history_len,
104 |     | -                amax_compute_algo=fp8_cfg.amax_compute_algo,
105 |     | -                override_linear_precision=(False, False, not fp8_cfg.fp8_wgrad),
106 |     | -            )
    |  94 | +        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
    |  95 | +            # FP8 GEMM
    |  96 | +            fp8_cfg = gpc.config.get("fp8", None)
    |  97 | +            self.use_fp8 = fp8_cfg is not None
    |  98 | +            self.fp8_recipe = None
    |  99 | +            self.fp8_group = None
    | 100 | +            if self.use_fp8:
    | 101 | +                self.fp8_group = gpc.get_group(ParallelMode.GLOBAL)
    | 102 | +                if fp8_cfg.format == "e4m3":
    | 103 | +                    fp8_format = Format.E4M3
    | 104 | +                elif fp8_cfg.format == "hybrid":
    | 105 | +                    fp8_format = Format.HYBRID
    | 106 | +                else:
    | 107 | +                    raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
    | 108 | +                self.fp8_recipe = DelayedScaling(
    | 109 | +                    margin=fp8_cfg.margin,
    | 110 | +                    interval=fp8_cfg.interval,
    | 111 | +                    fp8_format=fp8_format,
    | 112 | +                    amax_history_len=fp8_cfg.amax_history_len,
    | 113 | +                    amax_compute_algo=fp8_cfg.amax_compute_algo,
    | 114 | +                    override_linear_precision=(False, False, not fp8_cfg.fp8_wgrad),
    | 115 | +                )
107 | 116 | 
108 | 117 |     @property
109 | 118 |     def model(self):
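
The FP8 setup above only runs when gpc.config carries an "fp8" section. A hypothetical config fragment that would exercise it could look like the following; the key names come from the code in the hunk, while the values are illustrative and not taken from the patch:

    # Hypothetical "fp8" config section; values are examples only.
    fp8 = dict(
        format="hybrid",          # "e4m3" or "hybrid"; anything else raises the ValueError above
        margin=0,                 # margin passed through to DelayedScaling
        interval=1,               # scaling-factor update interval
        amax_history_len=1024,    # length of the amax history window
        amax_compute_algo="max",  # e.g. "max" or "most_recent"
        fp8_wgrad=True,           # False keeps the wgrad GEMM out of FP8
    )
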
@@ -193,11 +202,13 @@ def __call__(self, *args, **kwargs):
193 | 202 |         Returns:
194 | 203 |             torch.Tensor: The output of the model.
195 | 204 |         """
196 |     | -        with te.fp8_autocast(
197 |     | -            enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group
198 |     | -        ) if self.use_fp8 else nullcontext():
199 |     | -            output = self.model(*args, **kwargs)
200 |     | -            return output
    | 205 | +        if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
    | 206 | +            with te.fp8_autocast(
    | 207 | +                enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group
    | 208 | +            ) if self.use_fp8 else nullcontext():
    | 209 | +                output = self.model(*args, **kwargs)
    | 210 | +                return output
    | 211 | +        return self.model(*args, **kwargs)
201 | 212 | 
202 | 213 |     def load_batch(self, data_iter, to_gpu=True):
203 | 214 |         """
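
For readability, the branch added to __call__ can be restated as explicit cases. The helper _forward below is hypothetical and only illustrates the flow; it is not part of the patch:

    # Simplified sketch of the control flow __call__ follows after this change.
    def _forward(engine, *args, **kwargs):
        if internlm_accelerator.get_accelerator_backend() != AcceleratorType.GPU:
            # Non-GPU backends never touch transformer_engine.
            return engine.model(*args, **kwargs)
        if engine.use_fp8:
            # GPU with an "fp8" config: run the forward under TE's FP8 autocast.
            with te.fp8_autocast(enabled=True, fp8_recipe=engine.fp8_recipe, fp8_group=engine.fp8_group):
                return engine.model(*args, **kwargs)
        # GPU without an "fp8" config: plain forward (the nullcontext branch).
        return engine.model(*args, **kwargs)
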
|