Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
unit test for precomputing scales
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed Jul 10, 2024
1 parent 9ef67fb commit fa2f08a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp
from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp


def check_parity_no_mp(
Expand All @@ -31,7 +31,7 @@ def check_parity_no_mp(
# TODO(future): add amax syncing once delayed scaling is supported
optim.step()
if model is fsdp_model and precompute:
precompute_float8_amax_for_fsdp(model)
precompute_float8_scale_for_fsdp(model)
test_cls.assertEqual(losses[0], losses[1])


Expand Down

0 comments on commit fa2f08a

Please sign in to comment.