|
| 1 | +.. |
| 2 | + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +
|
| 4 | + See LICENSE for license information. |
| 5 | + |
| 6 | +Getting started |
| 7 | +============== |
| 8 | + |
| 9 | +.. note:: |
| 10 | + |
| 11 | + Precision debug tools with `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ for Transformer Engine are currently supported only for PyTorch. |
| 12 | + |
| 13 | +Transformer Engine provides a set of precision debug tools which allow you to easily: |
| 14 | + |
| 15 | +- log the statistics for each of the tensors in every matrix multiply (GEMM) operation, |
| 16 | +- run selected GEMMs in higher precision, |
| 17 | +- run current scaling - with one scaling factor per tensor - for particular GEMMs, |
| 18 | +- test new precisions and integrate them with FP8 training, |
| 19 | +- ... and many more. |
| 20 | + |
| 21 | +There are 4 things one needs to do to use Transformer Engine debug features: |
| 22 | + |
| 23 | +1. Create a configuration YAML file to configure the desired features. |
| 24 | +2. Import, and initialize the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool, which is installed as the dependency of the Transformer Engine. |
| 25 | +3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically. |
| 26 | +4. Invoke ``debug_api.step()`` at the end of one forward-backward pass. |
| 27 | + |
| 28 | +To start debugging, one needs to create a configuration YAML file. This file lists the features to be used in particular layers. There are 2 kinds of features: |
| 29 | + |
| 30 | +- provided by the Transformer Engine - for example, DisableFP8GEMM or LogTensorStats - they are listed in the :doc:`debug features API <3_api_features>` section |
| 31 | +- defined by the user. For details on how to create a custom feature - please read the :doc:`calls to Nvidia-DL-Framework-Inspect <3_api_te_calls>` section. |
| 32 | + |
| 33 | +.. figure:: ./img/introduction.svg |
| 34 | + :align: center |
| 35 | + |
| 36 | + Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 3 TE Linear Layers. |
| 37 | + ``config.yaml`` contains the specification of the features used for each Linear layer. Some feature classes are provided by TE, |
| 38 | + one - ``UserProvidedPrecision`` - is a custom feature implemented by the user. Nvidia-DL-Framework-Inspect inserts features into the layers according to the config. |
| 39 | + |
| 40 | +Example training script |
| 41 | +---------------------- |
| 42 | + |
| 43 | +Let's look at a simple example of training a Transformer layer using Transformer Engine with FP8 precision. This example demonstrates how to set up the layer, define an optimizer, and perform a few training iterations using synthetic data. |
| 44 | + |
| 45 | +.. code-block:: python |
| 46 | +
|
| 47 | + # train.py |
| 48 | +
|
| 49 | + from transformer_engine.pytorch import TransformerLayer |
| 50 | + import torch |
| 51 | + import torch.nn as nn |
| 52 | + import torch.optim as optim |
| 53 | + import transformer_engine.pytorch as te |
| 54 | +
|
| 55 | + hidden_size = 512 |
| 56 | + num_attention_heads = 8 |
| 57 | +
|
| 58 | + transformer_layer = TransformerLayer( |
| 59 | + hidden_size=hidden_size, |
| 60 | + ffn_hidden_size=hidden_size, |
| 61 | + num_attention_heads=num_attention_heads |
| 62 | + ).cuda() |
| 63 | +
|
| 64 | + dummy_input = torch.randn(10, 32, hidden_size).cuda() |
| 65 | + criterion = nn.MSELoss() |
| 66 | + optimizer = optim.Adam(transformer_layer.parameters(), lr=1e-4) |
| 67 | + dummy_target = torch.randn(10, 32, hidden_size).cuda() |
| 68 | +
|
| 69 | + for epoch in range(5): |
| 70 | + transformer_layer.train() |
| 71 | + optimizer.zero_grad() |
| 72 | + with te.fp8_autocast(enabled=True): |
| 73 | + output = transformer_layer(dummy_input) |
| 74 | + loss = criterion(output, dummy_target) |
| 75 | + loss.backward() |
| 76 | + optimizer.step() |
| 77 | +
|
| 78 | +We will demonstrate two debug features on the code above: |
| 79 | + |
| 80 | +1. Disabling FP8 precision for specific GEMM operations, such as the FC1 and FC2 forward propagation GEMM. |
| 81 | +2. Logging statistics for other GEMM operations, such as gradient statistics for data gradient GEMM within the LayerNormLinear sub-layer of the TransformerLayer. |
| 82 | + |
| 83 | +Config file |
| 84 | +---------- |
| 85 | + |
| 86 | +We need to prepare the configuration YAML file, as below |
| 87 | + |
| 88 | +.. code-block:: yaml |
| 89 | +
|
| 90 | + # config.yaml |
| 91 | +
|
| 92 | + fc1_fprop_to_fp8: |
| 93 | + enabled: True |
| 94 | + layers: |
| 95 | + layer_types: [fc1, fc2] # contains fc1 or fc2 in name |
| 96 | + transformer_engine: |
| 97 | + DisableFP8GEMM: |
| 98 | + enabled: True |
| 99 | + gemms: [fprop] |
| 100 | +
|
| 101 | + log_tensor_stats: |
| 102 | + enabled: True |
| 103 | + layers: |
| 104 | + layer_types: [layernorm_linear] # contains layernorm_linear in name |
| 105 | + transformer_engine: |
| 106 | + LogTensorStats: |
| 107 | + enabled: True |
| 108 | + stats: [max, min, mean, std, l1_norm] |
| 109 | + tensors: [activation] |
| 110 | + freq: 1 |
| 111 | + start_step: 2 |
| 112 | + end_step: 5 |
| 113 | +
|
| 114 | +Further explanation on how to create config files is in the :doc:`next part of the documentation <2_config_file_structure>`. |
| 115 | + |
| 116 | +Adjusting Python file |
| 117 | +-------------------- |
| 118 | + |
| 119 | +.. code-block:: python |
| 120 | +
|
| 121 | + # (...) |
| 122 | +
|
| 123 | + import nvdlfw_inspect.api as debug_api |
| 124 | + debug_api.initialize( |
| 125 | + config_file="./config.yaml", |
| 126 | + feature_dirs=["/path/to/transformer_engine/debug/features"], |
| 127 | + log_dir="./log", |
| 128 | + default_logging_enabled=True) |
| 129 | +
|
| 130 | + # initialization of the TransformerLayer with the name |
| 131 | + transformer_layer = TransformerLayer( |
| 132 | + name="transformer_layer", |
| 133 | + # ...) |
| 134 | +
|
| 135 | + # (...) |
| 136 | + for epoch in range(5): |
| 137 | + # forward and backward pass |
| 138 | + # ... |
| 139 | + debug_api.step() |
| 140 | +
|
| 141 | +In the modified code above, the following changes were made: |
| 142 | +
|
| 143 | +1. Added an import for ``nvdlfw_inspect.api``. |
| 144 | +2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. |
| 145 | +3. Added ``debug_api.step()`` after each of the forward-backward pass. |
| 146 | +
|
| 147 | +Inspecting the logs |
| 148 | +------------------ |
| 149 | +
|
| 150 | +Let's look at the files with the logs. Two files will be created: |
| 151 | +
|
| 152 | +1. debug logs. |
| 153 | +2. statistics logs. |
| 154 | +
|
| 155 | +Let's look inside them! |
| 156 | +
|
| 157 | +In the main log file, you can find detailed information about the transformer layer's GEMMs behavior. You can see that ``fc1`` and ``fc2`` fprop GEMMs are run in high precision, as intended. |
| 158 | +
|
| 159 | +.. code-block:: text |
| 160 | +
|
| 161 | + # log/nvdlfw_inspect_logs/nvdlfw_inspect_globalrank-0.log |
| 162 | +
|
| 163 | + INFO - Default logging to file enabled at ./log |
| 164 | + INFO - Reading config from ./config.yaml. |
| 165 | + INFO - Loaded configs for dict_keys(['fc1_fprop_to_fp8', 'log_tensor_stats']). |
| 166 | + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: activation, gemm fprop - FP8 quantization |
| 167 | + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: activation, gemm wgrad - FP8 quantization |
| 168 | + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: weight, gemm fprop - FP8 quantization |
| 169 | + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: weight, gemm dgrad - FP8 quantization |
| 170 | + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: gradient, gemm dgrad - FP8 quantization |
| 171 | + INFO - transformer_layer.self_attention.layernorm_qkv: Tensor: gradient, gemm wgrad - FP8 quantization |
| 172 | + INFO - transformer_layer.self_attention.proj: Tensor: activation, gemm fprop - FP8 quantization |
| 173 | + INFO - transformer_layer.self_attention.proj: Tensor: activation, gemm wgrad - FP8 quantization |
| 174 | + INFO - transformer_layer.self_attention.proj: Tensor: weight, gemm fprop - FP8 quantization |
| 175 | + INFO - transformer_layer.self_attention.proj: Tensor: weight, gemm dgrad - FP8 quantization |
| 176 | + INFO - transformer_layer.self_attention.proj: Tensor: gradient, gemm dgrad - FP8 quantization |
| 177 | + INFO - transformer_layer.self_attention.proj: Tensor: gradient, gemm wgrad - FP8 quantization |
| 178 | + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: activation, gemm fprop - High precision |
| 179 | + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: activation, gemm wgrad - FP8 quantization |
| 180 | + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: weight, gemm fprop - High precision |
| 181 | + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: weight, gemm dgrad - FP8 quantization |
| 182 | + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: gradient, gemm dgrad - FP8 quantization |
| 183 | + INFO - transformer_layer.layernorm_mlp.fc1: Tensor: gradient, gemm wgrad - FP8 quantization |
| 184 | + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: activation, gemm fprop - High precision |
| 185 | + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: activation, gemm wgrad - FP8 quantization |
| 186 | + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: weight, gemm fprop - High precision |
| 187 | + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: weight, gemm dgrad - FP8 quantization |
| 188 | + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: gradient, gemm dgrad - FP8 quantization |
| 189 | + INFO - transformer_layer.layernorm_mlp.fc2: Tensor: gradient, gemm wgrad - FP8 quantization |
| 190 | + INFO - transformer_layer.self_attention.layernorm_qkv: Feature=LogTensorStats, API=look_at_tensor_before_process: activation |
| 191 | + .... |
| 192 | +
|
| 193 | +The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``) contains statistics for tensors we requested in ``config.yaml``. |
| 194 | +
|
| 195 | +.. code-block:: text |
| 196 | +
|
| 197 | + # log/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log |
| 198 | +
|
| 199 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000002 value=4.3188 |
| 200 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000002 value=-4.3386 |
| 201 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000002 value=0.0000 |
| 202 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000002 value=0.9998 |
| 203 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000002 value=130799.6953 |
| 204 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000003 value=4.3184 |
| 205 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000003 value=-4.3381 |
| 206 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000003 value=0.0000 |
| 207 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000003 value=0.9997 |
| 208 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000003 value=130788.1016 |
| 209 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_max iteration=000004 value=4.3181 |
| 210 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_min iteration=000004 value=-4.3377 |
| 211 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_mean iteration=000004 value=0.0000 |
| 212 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000004 value=0.9996 |
| 213 | + INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 |
| 214 | +
|
| 215 | +Logging using TensorBoard |
| 216 | +------------------------ |
| 217 | +
|
| 218 | +Precision debug tools support logging using `TensorBoard <https://www.tensorflow.org/tensorboard>`_. To enable it, one needs to pass the argument ``tb_writer`` to the ``debug_api.initialize()``. Let's modify ``train.py`` file. |
| 219 | +
|
| 220 | +.. code-block:: python |
| 221 | +
|
| 222 | + # (...) |
| 223 | +
|
| 224 | + from torch.utils.tensorboard import SummaryWriter |
| 225 | + tb_writer = SummaryWriter('./tensorboard_dir/run1') |
| 226 | +
|
| 227 | + # add tb_writer to the Debug API initialization |
| 228 | + debug_api.initialize( |
| 229 | + config_file="./config.yaml", |
| 230 | + feature_dirs=["/path/to/transformer_engine/debug/features"], |
| 231 | + log_dir="./log", |
| 232 | + tb_writer=tb_writer) |
| 233 | +
|
| 234 | + # (...) |
| 235 | +
|
| 236 | +Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_dir/run1``: |
| 237 | +
|
| 238 | +.. figure:: ./img/tensorboard.png |
| 239 | + :align: center |
| 240 | +
|
| 241 | + Fig 2: TensorBoard with plotted stats. |
0 commit comments