
[Quantization + FSDP] Support quantize_() for DTensor #803

Open
gau-nernst opened this issue Sep 4, 2024 · 1 comment

@gau-nernst
Collaborator

While trying out INT8 mixed precision pretraining (#748) with torchtitan, I came across an issue: if the model is FSDP-sharded, quantize_() won't work. The fix would be to add extra logic to handle DTensor, similar to what the Float8 code already does (snippet below, with a rough sketch of the analogous quantize_() handling after it):

```python
if isinstance(bits_fp8, DTensor):
    assert isinstance(
        scale, DTensor
    ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
    bits_mesh = bits_fp8.device_mesh
    bits_placements = bits_fp8.placements
    local_bits = bits_fp8.to_local()
    local_scale = scale.to_local()
    inner_float8_tensor = Float8Tensor(
        local_bits,
        local_scale,
        tensor.dtype,
        linear_mm_config=linear_mm_config,
        gemm_input_role=gemm_input_role,
    )
    return DTensor.from_local(
        inner_float8_tensor,
        bits_mesh,
        bits_placements,
        run_check=False,
        shape=bits_fp8.size(),
        stride=bits_fp8.stride(),
    )
```

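For illustration only, here is a minimal sketch of what the analogous handling inside a quantize_() tensor transform could look like, assuming the same unwrap/rewrap pattern as the Float8 code above. `quantize_plain_tensor` and `quantize_maybe_dtensor` are hypothetical names standing in for whatever conversion quantize_() actually applies to a regular tensor, not torchao's API:

```python
# Hypothetical sketch, not torchao's implementation.
import torch
from torch.distributed._tensor import DTensor


def quantize_plain_tensor(t: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real per-tensor quantization transform
    # (e.g. constructing the quantized tensor subclass).
    return t


def quantize_maybe_dtensor(param: torch.Tensor) -> torch.Tensor:
    if isinstance(param, DTensor):
        # Quantize only the local shard, then rewrap so FSDP still sees a
        # DTensor with the original mesh, placements, shape and stride.
        local_q = quantize_plain_tensor(param.to_local())
        return DTensor.from_local(
            local_q,
            param.device_mesh,
            param.placements,
            run_check=False,
            shape=param.size(),
            stride=param.stride(),
        )
    return quantize_plain_tensor(param)
```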
@msaroufim
Member

Yeah, this came up in some discussions with inference providers like SGLang as well.
