This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Changes on top of upstream to get rid of type errors #248

Closed
wants to merge 3 commits into main from fnuz_typing

Conversation

alugorey
Contributor

@alugorey alugorey commented Apr 1, 2024

Fixes the class of unit tests in test_base.py that fail on ROCm with the internal assertion Cannot convert ScalarType Float8_e4m3fn to hipDataType.

Note: We are aware of the outstanding numerical issues and are looking into them internally.
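
For context, a minimal sketch of what backend-dependent float8 dtype helpers along the lines of the fp8_e4m3_t()/fp8_e5m2_t() calls in the diff below could look like; this is an illustration only, not necessarily the exact implementation in this PR, and it requires a torch version that defines the fnuz float8 dtypes:

import torch

# Same backend check used in the diff below: a HIP build of torch with a visible GPU.
IS_AMD = torch.cuda.is_available() and torch.version.hip is not None


def fp8_e4m3_t() -> torch.dtype:
    # ROCm currently supports only the 'fnuz' float8 variants, so pick the
    # e4m3 flavor that matches the backend.
    return torch.float8_e4m3fnuz if IS_AMD else torch.float8_e4m3fn


def fp8_e5m2_t() -> torch.dtype:
    # Same idea for the e5m2 flavor.
    return torch.float8_e5m2fnuz if IS_AMD else torch.float8_e5m2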

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 1, 2024
@alugorey
Contributor Author

alugorey commented Apr 3, 2024

@drisspg :)

@drisspg
Contributor

drisspg commented Apr 4, 2024

Awesome! I will take a look at this tomorrow

@@ -28,12 +28,25 @@
IS_AMD = torch.cuda.is_available() and torch.version.hip is not None


# Helper functions to get individual F8 types based on backend architecture
Contributor

Would it be possible to put this into configuration instead of setting it dynamically? It can be unexpected for numerics to change based on the environment. It would also be good to support numerical emulation of all of these types regardless of whether the user's machine supports a float8 matmul.

Contributor Author

I'm afraid I don't understand your question. These helper functions are simply intended to grab the "right" version of the prebuilt torch F8 types. Could you elaborate on the change you'd like to see?

Contributor

Sure, it's just about encoding the dtype flavors in configuration instead of making them environment dependent. Having this in configuration would make it easier to debug numerics without having the target hardware.

# float8 dtypes have a default which can be changed explicitly
config = ...
config.float8_flavors = 'nuz'
do_float8_things(..., config)

versus

# float8 dtypes magically change based on the environment
do_float8_things(...)

That said, my comment is not high pri, feel free to land and we can adjust this later if it becomes important.

Contributor Author

@vkuzo Sorry, I got wrapped up in other work recently and just circled back to this. Okay, I will add an option in float8_experimental/float8_experimental/config.py, and instead of checking the backend architecture, the code will check against this user-settable config variable.
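
A rough sketch of the config-driven approach described above, assuming a use_fnuz_dtype flag in config.py; the module path and helper shape here are illustrative, based on the snippets later in this thread:

# float8_experimental/config.py
# If True, use the 'fnuz' float8 dtype variants instead of the default ones.
use_fnuz_dtype = False

# Helpers elsewhere would then consult the config flag rather than the backend:
import torch

import float8_experimental.config as config


def fp8_e4m3_t() -> torch.dtype:
    return torch.float8_e4m3fnuz if config.use_fnuz_dtype else torch.float8_e4m3fn


def fp8_e5m2_t() -> torch.dtype:
    return torch.float8_e5m2fnuz if config.use_fnuz_dtype else torch.float8_e5m2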

@@ -350,7 +374,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype):


class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @pytest.mark.parametrize("float8_dtype", [fp8_e4m3_t(), fp8_e5m2_t()])
Contributor

I would recommend testing all cases on all hardware types instead. For things not requiring a matmul, it should just work. For things requiring a matmul, we have an emulation mode to at least help approximate it.
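
For illustration, testing every float8 flavor regardless of hardware might look roughly like the sketch below; the test name and shapes are made up, and only a non-matmul case is shown, since matmul-based tests would additionally enable the emulation mode mentioned above:

import pytest
import torch

ALL_FP8_DTYPES = [
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
]


@pytest.mark.parametrize("float8_dtype", ALL_FP8_DTYPES)
def test_cast_roundtrip(float8_dtype):
    # Plain casts do not require a float8 matmul, so they should run on any
    # hardware where torch defines the dtype.
    x = torch.randn(16, 16)
    x_fp8 = x.to(float8_dtype)
    assert x_fp8.dtype == float8_dtype
    assert x_fp8.float().shape == x.shape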

@@ -47,7 +51,10 @@ class TestFloat8Tensor(unittest.TestCase):
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
- lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+ fp8_dtypes = (
+     FP8Dtypes()
Contributor

all dtypes would be nice, there should not be anything in Float8Tensor which is hardware dependent

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
- m = get_float8_linear(linear_type, m, emulate, False)
+ m = get_float8_linear(linear_type, m, emulate, False, fp8_dtypes)
Contributor

you can enable emulation here if your hardware doesn't support the dtype under test

@alugorey
Contributor Author

@vkuzo Ready for another review. Also, I wanted to ask: is there an ETA or roadmap for when this functionality will be pulled into PyTorch proper?

@vkuzo
Contributor

vkuzo commented Apr 25, 2024

So sorry, I am on holiday right now; I will take a look late next week when I return, unless @drisspg wants to get to it sooner.

@drisspg
Contributor

drisspg commented Apr 26, 2024

Yeah will review tomorrow

@drisspg drisspg left a comment

What was the output of test_everything.sh?

@alugorey
Contributor Author

@drisspg
test_everything.log

Fails in test_compile. However, I was aware of this failure and found that it is unrelated to my changes; rather, it is an issue with torch.compile on ROCm. This failure was next on my TODOs to address. I will upload a follow-up PR.

# If True, use 'fnuz' float8 types for calculations. If the backend
# hardware does not support a particular type, the emulated implementation
# of the dtype will be used. Currently, ROCm only supports the fnuz variants.
use_fnuz_dtype = True
Contributor

default to False?

Contributor Author

Ah, yes, that's an oversight; it was easier for me to test that way. Will change.

@@ -128,7 +128,7 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
- out = NoopFwToFloat8E5M2Bw.apply(out, False)
+ out = NoopFwToFloat8E5M2Bw.apply(out, False, fp8_e5m2_t())
Contributor

is the new last arg here expected?

Contributor Author

Oversight, fixed in latest commit

@drisspg
Contributor

drisspg commented Jun 5, 2024

Just an update here: the base PR should have landed yesterday.

@alugorey alugorey force-pushed the fnuz_typing branch 2 times, most recently from 57120fa to 4997c19 on June 14, 2024 18:39
@alugorey alugorey changed the base branch from amd-support to main June 14, 2024 18:40
@alugorey
Contributor Author

@vkuzo @drisspg
Revised/rebased drop of AMD support after the amd-support branch was merged to main

@@ -19,3 +19,8 @@
# implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
# Only dynamic scaling is supported for now.
enable_fsdp_fp8_all_gather = False

# If True, use 'fnuz' float8 types for calculations. If the backend
# hardware does not support a particular dtype, the emulated implementation
Contributor

nit: currently the user is responsible for toggling the emulation setting, we don't do that automatically

Contributor Author

I thought, per previous comments, that we wanted to go the emulated route if the backend hardware didn't support the type? In general the user can force emulation, but I was under the impression that when a dtype isn't supported on the underlying hardware, we wanted to fall back to the emulated route automatically?

Contributor

Yep, that is correct! This is just currently done by the user explicitly, and there is no support to handle that automatically.

Contributor Author

oh, you're saying the comment is wrong

Contributor Author

fixed!
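
For reference, a paraphrase of what the corrected comment presumably conveys (not necessarily the exact committed text):

# If True, use the 'fnuz' float8 dtypes for calculations. If the backend
# hardware does not support a particular dtype, the user is responsible for
# enabling emulation explicitly; it is not done automatically.
use_fnuz_dtype = False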

@vkuzo vkuzo left a comment

looks great! thanks for helping!

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@alugorey
Contributor Author

@drisspg Fixed lint, but I'm not sure what the ufmt errors are about.

@drisspg
Contributor

drisspg commented Jun 17, 2024

@alugorey I think if you apply this patch it should work: https://gist.github.com/drisspg/2a87d54521a0b2312ac44d070f63350d

@alugorey
Contributor Author

@drisspg Looks like you beat me to it. Still failing on 2 files, though. Is there documentation on what ufmt expects? Some of those changes seem purely cosmetic.

@drisspg
Contributor

drisspg commented Jun 17, 2024

The reason for this formatting is some internal notes on code styling. TBH I just run ufmt format . and don't care about the actual format.

I also have some 'pre-commit' hooks for the repo that should be set and forget. There do seem to be 2 more lint fixes needed.

There seems to be 1 real test failure:

./test/test_everything.sh: line 5: rocm-smi: command not found

I found that this patch passes test_everything:

https://gist.github.com/drisspg/47a29d6bf3fcca2a2c48d09b74c564aa

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@drisspg merged this pull request in 0bd374d.
