add torch.compile + FSDP2 float8 all-gather in CI (#468)
Fixed my bug in float8_experimental; now we can torch.compile transformer blocks with FSDP float8 all-gather (pytorch-labs/float8_experimental#321).

Local test:

`CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

Profiler traces: I can see the compiled region in the CPU thread and the float8 matmul `sm90_xmma_gemm_e4m3bf16...` in the CUDA stream.

<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM" src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">
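For reference, a minimal sketch (not torchtitan's actual parallelize code) of the composition being tested here: compile each transformer block with torch.compile, then shard it with FSDP2's `fully_shard`. The float8 linear swap from float8_experimental is assumed to have happened before this step and is elided; `Block` is a hypothetical stand-in for a real transformer block. Run under `torchrun` so the default process group exists.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


class Block(nn.Module):
    """Stand-in for a transformer block (attention/MLP details omitted)."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.w1 = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.w1(torch.relu(self.wq(x)))


def apply_compile_then_fsdp(model: nn.Sequential) -> nn.Sequential:
    # Compile each block separately so every block is its own compiled region
    # (these show up as compiled regions on the CPU thread in the profiler).
    for idx, block in enumerate(model):
        model[idx] = torch.compile(block)
    # FSDP2: shard each (compiled) block, then the root module. With float8
    # linears and fsdp float8 all-gather enabled, the all-gather communicates
    # float8 shards instead of bf16 ones.
    for block in model:
        fully_shard(block)
    fully_shard(model)
    return model


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = nn.Sequential(*[Block() for _ in range(2)]).cuda()
    model = apply_compile_then_fsdp(model)
    out = model(torch.randn(8, 256, device="cuda"))
    out.sum().backward()
    dist.destroy_process_group()
```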