
Integrating generic_float struct for adding datatypes #3522

Merged: 64 commits, Nov 8, 2024
Conversation

@richagadgil (Contributor)

No description provided.

@richagadgil self-assigned this Oct 10, 2024
@pfultz2 (Collaborator) commented Oct 10, 2024

A couple of things:

  • Move generic_float to migraphx/generic_float.hpp
  • Add a specialization for std::numeric_limits
  • Add a specialization for migraphx::is_floating_point (and remove the half one)
  • Add a specialization for std::common_type (and remove the half one)
  • Add a test/generic_float.cpp with the fp32 tests I wrote, and add some tests for operator overloads
  • Add pragmas to disable the duplicate-branch warnings in the bit_cast function (see the sketch at the end of this comment)

The specializations should use the template type like:

template<unsigned int E, unsigned int M, unsigned int F>
class numeric_limits<migraphx::generic_float<E, M, F>>
{
...
};
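For the pragma item above, a minimal sketch of what the guards around bit_cast could look like; the diagnostic name -Wduplicated-branches is an assumption about which warning is being silenced (it is GCC-only, hence the clang exclusion):

// hypothetical guard; adjust the diagnostic name to whatever the build actually reports
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wduplicated-branches"
#endif

// ... bit_cast implementation ...

#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif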

@pfultz2 (Collaborator) commented Oct 10, 2024

Also, we can use the fp8 template type to reduce the number of common_type overloads:

template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
struct common_type<migraphx::generic_float<E, M, F>, migraphx::fp8::float8<T, FNUZ>> : std::common_type<float, float>
{};

template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::generic_float<E, M, F>> : std::common_type<float, float>
{};
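As a usage sketch (template parameters left generic rather than assuming concrete instantiations), these specializations make float the common type of any generic_float/float8 mix:

template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
void check_common_type()
{
    // sketch only: verifies the promotion provided by the specializations above
    static_assert(std::is_same<std::common_type_t<migraphx::generic_float<E, M, F>,
                                                   migraphx::fp8::float8<T, FNUZ>>,
                               float>{},
                  "mixed generic_float/float8 expressions have float as their common type");
}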

@pfultz2 (Collaborator) commented Oct 18, 2024

For the fp16 tests, we want tests similar to the fp8 ones, but instead of having an array lookup table, we would sample some values into a map and test that:

TEST_CASE(check_half_values)
{
    for(auto [x, f] : half_lut)
    {
        auto h = migraphx::bit_cast<migraphx::half>(x);
        if(std::isnan(f))
        {
            CHECK(std::isnan(h));
        }
        else if(std::isinf(f))
        {
            CHECK(std::isinf(h));
            CHECK((h < 0) == (f < 0));
            CHECK(bit_equal(x, migraphx::half(f)));
        }
        else
        {
            CHECK(migraphx::float_equal(float(h), f));
            CHECK(bit_equal(x, migraphx::half(f)));
        }
    }
}

I have a map of a thousand or so values we can use for this test. We will also want to test the numeric limits by checking that the bits match what we would expect:

TEST_CASE(check_numeric_limits)
{
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::min(), uint16_t{0x0400}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::lowest(), uint16_t{0xfbff}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::max(), uint16_t{0x7bff}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::epsilon(), uint16_t{0x1400}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::denorm_min(), uint16_t{0x0001}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::infinity(), uint16_t{0x7c00}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::quiet_NaN(), uint16_t{0x7fff}));
    CHECK(bit_equal(std::numeric_limits<migraphx::half>::signaling_NaN(), uint16_t{0x7dff}));
}

In addition, it would be good to have some tests for overflow and underflow, such as std::numeric_limits<half>::max() + std::numeric_limits<float>::epsilon().
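A minimal sketch of what such checks could look like, assuming the float-to-half conversion rounds to nearest and overflows to infinity (the expected results here are assumptions, not taken from the merged test):

TEST_CASE(check_half_overflow_underflow)
{
    float fmax = float(std::numeric_limits<migraphx::half>::max());

    // float epsilon is far below half precision at this magnitude, so the sum
    // should round back to max() rather than overflow
    CHECK(migraphx::float_equal(float(migraphx::half(fmax + std::numeric_limits<float>::epsilon())), fmax));

    // doubling max() is past the half overflow threshold, so it should convert to infinity
    CHECK(std::isinf(float(migraphx::half(fmax * 2.0f))));

    // a value far below denorm_min() should underflow to zero
    CHECK(migraphx::float_equal(float(migraphx::half(1e-30f)), 0.0f));
}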

@causten requested a review from CharlieL7 October 22, 2024 16:45
test/half.cpp Outdated
CHECK(bit_equal(std::numeric_limits<migraphx::half>::signaling_NaN(), uint16_t{0x7d00}));
}

static const std::map<uint16_t, float> half_lut = {

We probably need to wrap this in a function to fix the tidy warning:

const std::map<uint16_t, float>& half_lut()
{
    static const std::map<uint16_t, float> result = { ... };
    return result;
}
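With that change, the loop in the earlier test would iterate over half_lut() (i.e. for(auto [x, f] : half_lut())) instead of a global, and the function-local static presumably also satisfies the tidy check against non-trivially-constructed globals while avoiding static initialization order issues.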

@pfultz2 reopened this Nov 4, 2024
@pfultz2 (Collaborator) commented Nov 4, 2024

Overall this looks good; we just need to fix the tidy warnings.

@pfultz2 (Collaborator) left a comment

Looks good, just need to fix the CI checks.


constexpr float32_parts get_parts(float f) { return migraphx::bit_cast<float32_parts>(f); }

#pragma pack(push, 1)

This needs to be surrounded by #ifdef _MSC_VER.
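A sketch of the suggested guard; the matching pop at the end of the header would be wrapped the same way:

#ifdef _MSC_VER
#pragma pack(push, 1)
#endif

// ... generic_float definition ...

#ifdef _MSC_VER
#pragma pack(pop)
#endif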


#pragma pack(push, 1)
template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct alignas(1) __attribute__((may_alias)) generic_float

You need the packed attribute for gcc/clang, but that probably breaks Windows. Instead of using macros, [[gnu::packed, gnu::may_alias]] may work.
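A sketch of the attribute-based alternative (whether MSVC merely warns about the unknown gnu:: attributes rather than rejecting them is not verified here):

template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct [[gnu::packed, gnu::may_alias]] generic_float
{
    // ...
};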


#pragma pack(push, 1)
template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct alignas(1) __attribute__((may_alias)) generic_float

Also, the alignment is wrong. It should be alignas((MantissaSize + ExponentSize + 1) / 8); I don't know if that compiles.
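For reference, a sketch of that suggestion; a dependent alignas expression is valid in a template, though every instantiation would need to yield a nonzero power of two:

template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct alignas((MantissaSize + ExponentSize + 1) / 8) generic_float
{
    // ...
};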

return temp;
}
};
#pragma pack(pop)

Also needs an #ifdef _MSC_VER.

@migraphx-bot (Collaborator)

Test | Batch | Rate new (9ae05a) | Rate old (624c8d) | Diff | Compare
torchvision-resnet50 64 3,260.10 3,261.66 -0.05%
torchvision-resnet50_fp16 64 6,987.01 6,990.31 -0.05%
torchvision-densenet121 32 2,435.66 2,436.87 -0.05%
torchvision-densenet121_fp16 32 4,094.32 4,089.75 0.11%
torchvision-inceptionv3 32 1,637.98 1,637.96 0.00%
torchvision-inceptionv3_fp16 32 2,763.22 2,767.26 -0.15%
cadene-inceptionv4 16 775.91 776.97 -0.14%
cadene-resnext64x4 16 811.84 811.00 0.10%
slim-mobilenet 64 7,536.11 7,537.82 -0.02%
slim-nasnetalarge 64 211.57 211.51 0.03%
slim-resnet50v2 64 3,506.73 3,504.05 0.08%
bert-mrpc-onnx 8 1,149.43 1,146.98 0.21%
bert-mrpc-tf 1 467.46 502.37 -6.95% 🔴
pytorch-examples-wlang-gru 1 421.12 421.14 -0.00%
pytorch-examples-wlang-lstm 1 389.03 402.51 -3.35% 🔴
torchvision-resnet50_1 1 775.94 800.87 -3.11% 🔴
cadene-dpn92_1 1 401.24 435.76 -7.92% 🔴
cadene-resnext101_1 1 382.72 383.41 -0.18%
onnx-taau-downsample 1 342.84 342.91 -0.02%
dlrm-criteoterabyte 1 33.34 33.35 -0.02%
dlrm-criteoterabyte_fp16 1 52.71 52.74 -0.06%
agentmodel 1 7,955.64 8,416.52 -5.48% 🔴
unet_fp16 2 58.96 58.97 -0.02%
resnet50v1_fp16 1 942.21 940.39 0.19%
resnet50v1_int8 1 1,021.08 1,022.33 -0.12%
bert_base_cased_fp16 64 1,170.08 1,171.13 -0.09%
bert_large_uncased_fp16 32 363.65 363.49 0.04%
bert_large_fp16 1 200.42 200.63 -0.11%
distilgpt2_fp16 16 2,202.27 2,202.77 -0.02%
yolov5s 1 549.41 538.22 2.08%
tinyllama 1 43.45 43.49 -0.08%
vicuna-fastchat 1 171.45 171.92 -0.27%
whisper-tiny-encoder 1 418.63 418.87 -0.06%
whisper-tiny-decoder 1 425.61 428.54 -0.68%

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@causten merged commit f5df004 into develop Nov 8, 2024
28 of 34 checks passed
@causten deleted the generic_float branch November 8, 2024 17:59
V6ser pushed a commit to V6ser/AMDMIGraphX that referenced this pull request Feb 3, 2025