Use internal quantizer for input to the modules (NVIDIA#1551)
Internal quantizer for input to the modules

Signed-off-by: Przemek Tredak <[email protected]>
ptrendx authored Mar 10, 2025
1 parent 5bb771e commit b3e7035
Showing 3 changed files with 3 additions and 3 deletions.
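All three modules receive the same one-line change in their _get_quantizers helpers: the forward-pass input quantizer is now flagged internal = True, matching how the weight quantizer was already handled. The sketch below is a hypothetical, self-contained illustration of that pattern only; the Quantizer class and dictionary keys are stand-ins, not the Transformer Engine API, and the exact semantics of the internal flag are library internals.

    from dataclasses import dataclass


    @dataclass
    class Quantizer:
        """Toy stand-in for a quantizer object (illustration only)."""
        name: str
        internal: bool = False  # flag mirrored from the diff; its meaning is internal to TE


    def get_forward_quantizers(quantizers: dict) -> tuple:
        """Mirrors the pattern of the _get_quantizers change below."""
        input_quantizer = quantizers["gemm1_input"]
        input_quantizer.internal = True   # previously False; this is the change in the commit
        weight_quantizer = quantizers["gemm1_weight"]
        weight_quantizer.internal = True  # unchanged: already internal before the commit
        return input_quantizer, weight_quantizer


    if __name__ == "__main__":
        quantizers = {
            "gemm1_input": Quantizer("input"),
            "gemm1_weight": Quantizer("weight"),
        }
        print(get_forward_quantizers(quantizers))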
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -1358,7 +1358,7 @@ def _get_quantizers(self, fp8_output):
         grad_output_quantizer = None
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-        input_quantizer.internal = False
+        input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         if fp8_output:
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1528,7 +1528,7 @@ def _get_quantizers(self):
         ) = [None] * 8
         if self.fp8:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-            fc1_input_quantizer.internal = False  # temporary
+            fc1_input_quantizer.internal = True
             fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
             fc1_weight_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -1136,7 +1136,7 @@ def _get_quantizers(self, fp8_output, fp8_grad):
         grad_output_quantizer = None
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-        input_quantizer.internal = False
+        input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         if fp8_output:
