Discrepancy with Optimizer States and Model State Dict when using store_param_remainders==True #1842

Open
alxzhang-amazon opened this issue Sep 5, 2024 · 8 comments



alxzhang-amazon commented Sep 5, 2024

Context:

  • mixed precision training
  • model: bf16
  • optimizer: fp32

When we dump a V2-format optimizer state dict, it contains param fields, which we expect to hold high-precision (FP32) versions of the model parameters.

When we train with store_params==True and store_param_remainders==False, we see the expected behavior: when we downcast the optimizer state dict params to BF16, they match the model state dict parameters exactly.

However, when we train with store_params==False and store_param_remainders==True, we do not. Instead, there are differences between the downcast optimizer state parameters and the actual model parameters.

I am wondering whether this is intentional, or whether this function is perhaps not lossless: https://github.com/NVIDIA/apex/blob/master/apex/contrib/optimizers/distributed_fused_adam.py#L247

I have provided a sample script that showcases this issue.
With store_params==True and store_param_remainders==False, all tensors match.
With store_params==False and store_param_remainders==True, we see mismatches.

store_param_remainders_dfa_test.txt
store_params_dfa_test.txt

The script used to highlight this issue:

import io
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

DEVICE = "cuda"


class InjectedException(Exception):
    pass


def cleanup():
    dist.barrier()
    dist.destroy_process_group()





def helper_test_v2_params_match_model_params(rank, world_size, store_params=True, store_param_remainders=False):
    # =======================================
    # 1. Create a model and DFA optimizer.
    # =======================================
    # Default rendezvous info so the script can be launched directly (single-node assumption).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)

    model = torch.nn.Transformer(nhead=16, num_encoder_layers=12).to(rank)
    model = model.to(torch.bfloat16)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    # model.eval()

    optimizer = DistributedFusedAdam(
        model.parameters(), lr=1.2e-2, 
        store_param_remainders=store_param_remainders, store_params=store_params, 
        param_sync_dtype=torch.bfloat16, overlap_grad_sync=True, bucket_cap_mb=250, overlap_param_sync=True, 
        contiguous_grad_buffer=True, adam_w_mode=True, eps=1e-08, weight_decay=0.1, betas=[0.9, 0.95])
    

    # =======================================
    # 2. Run a training step
    # =======================================
    src = torch.rand((10, 32, 512), device=rank, dtype=torch.bfloat16)
    tgt = torch.rand((20, 32, 512), device=rank, dtype=torch.bfloat16)
    tgt_y = torch.randint(0, 512, (20, 32), device=rank, dtype=torch.long).view(-1)
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer.zero_grad()

    output = model(src, tgt)
    loss = criterion(output.view(-1, output.size(-1)), tgt_y)
    loss.backward()
    optimizer.step()

    # =======================================
    # 3. Save the optimizer state 
    # =======================================
    # The optimizer state dict is stored as a single-element list.
    osd = [optimizer.state_dict(state_dict_format=2)]
    msd = model.state_dict()

    osd = osd[0]
    state = osd['state']

    param_index = 0
    for k, v in msd.items():
        print(f"Verifying Key: {k}")

        # Compare the model's bf16 parameter against the optimizer's master copy,
        # downcast from fp32 to bf16.
        master_param = state[param_index]['param']
        downcasted_param = master_param.to(dtype=torch.bfloat16, device=DEVICE)
        assert_equal(v, downcasted_param)

        param_index += 1
    



    print(f"Finished Helper V2 Params Match Model Params Test")
    cleanup()


def assert_equal(a, b):
    assert type(a) == type(b), f"Object types do not match: {type(a)} != {type(b)}"
    if isinstance(a, (int, float, str)):
        assert a == b, f"Values do not match: {a} != {b}"
    elif isinstance(a, torch.Tensor):
        compare_tensor(a, b)
    elif isinstance(a, list):
        assert len(a) == len(b), "Lists differ in length"
        for item1, item2 in zip(a, b):
            assert_equal(item1, item2)
    elif isinstance(a, dict):
        assert a.keys() == b.keys(), "Dictionaries differ in keys"
        for k in a:
            assert_equal(a[k], b[k])
    elif isinstance(a, io.BytesIO):
        a.seek(0)
        b.seek(0)
        assert a.read() == b.read()
    else:
        raise TypeError(f"Unsupported data type: {type(a)}")
    


def compare_tensor(a, b):
    if torch.equal(a, b):
        return

    differences = a != b
    num_differences = differences.sum().item()
    print(f"Tensors do not match. Number of differing elements: {num_differences}")
    
    diff_indices = differences.nonzero(as_tuple=True)

    for i in range(num_differences):  # print every differing element
        idx = tuple(d[i].item() for d in diff_indices)
        a_value = a[idx].item()
        b_value = b[idx].item()
        print(f"Difference at index {idx}: a={a_value}, b={b_value}")
        print(f"Percent Difference: {abs(b_value - a_value) / b_value * 100}\n")




def test_v2_params_match_model_params(tmp_path):
    """
    Checks whether the consolidated checkpoints contain the high-precision master weights of the
    model parameters.
    We see that they do when store_params==True, but when using store_param_remainders==True, this is not the case.
    """
    world_size = 1

    # ======================================================================
    # Run the test with the chosen optimizer flag config
    # ======================================================================
    store_params = False
    store_param_remainders = True
    mp.spawn(
        helper_test_v2_params_match_model_params,
        args=(world_size, store_params, store_param_remainders),
        nprocs=world_size,
        join=True,
    )




if __name__ == "__main__":
    test_v2_params_match_model_params(None)






crcrpar commented Sep 6, 2024

cc @timmoon10

alxzhang-amazon (Author) commented:

Update:
We found that some loss is induced by the _bf16_rem_to_fp32 method, occurring with roughly a 1e-6 chance per element.
https://github.com/NVIDIA/apex/blob/master/apex/contrib/optimizers/distributed_fused_adam.py#L247

import torch

# this function comes from DistributedFusedAdam optimizer in apex 
# https://github.com/NVIDIA/apex/blob/master/apex/contrib/optimizers/distributed_fused_adam.py#L247
def _bf16_rem_to_fp32(bf16: torch.Tensor, rem: torch.Tensor, fp32: torch.Tensor) -> None:
    """Pack BF16 tensor and 16-bit remainders into FP32 tensor"""

    # Check inputs
    assert bf16.size() == rem.size() == fp32.size(), (
        "Tensor dimensions do not match: "
        f"bf16={list(bf16.size())}, "
        f"rem={list(rem.size())}, "
        f"fp32={list(fp32.size())}, "
    )
    assert bf16.dtype is torch.bfloat16, f"bf16 buffer has invalid dtype ({bf16.dtype})"
    assert rem.dtype is torch.int16, f"rem buffer has invalid dtype ({rem.dtype})"
    assert fp32.dtype is torch.float32, f"fp32 buffer has invalid dtype ({fp32.dtype})"

    # Undo bf16 rounding
    bf16 = bf16.view(torch.int16) - torch.where(rem < 0, 1, 0)

    # Pack bf16 and remainder into little-endian fp32
    fp32 = fp32.unsqueeze(-1).view(torch.int16)
    fp32 = torch.stack((rem, bf16), dim=-1, out=fp32)

def split_fp32_to_bf16_and_rem(fp32_tensor):
    bf16_tensor = fp32_tensor.to(torch.bfloat16)

    # The remainder is the raw lower 16 bits of the FP32 bit pattern,
    # reinterpreted as a signed int16.
    rem_tensor = (fp32_tensor.view(torch.int32) & 0xFFFF).to(torch.int16)
    return bf16_tensor, rem_tensor

def compare_tensor(a, b):
    if torch.equal(a, b):
        return
    differences = a != b
    num_differences = differences.sum().item()
    print(f"Tensors do not match. Number of differing elements: {num_differences}")
    diff_indices = differences.nonzero(as_tuple=True)
    for i in range(num_differences):
        idx = tuple(d[i].item() for d in diff_indices)
        a_value = a[idx].item()
        b_value = b[idx].item()
        diff_perc = abs(b_value - a_value) / abs(a_value) * 100
        print(f"Difference at index {idx}: a={a[idx].item()}, b={b[idx].item()}")
        print(f"Percent Difference: {diff_perc}\n")

def test_lossless_conversion(num_tests=100, tensor_size=(100, 100)):
    for i in range(num_tests):
        print(f'Test number {i}---')
        original_fp32 = torch.randn(tensor_size, dtype=torch.float32)
        bf16_tensor, rem_tensor = split_fp32_to_bf16_and_rem(original_fp32)
        reconstructed_fp32 = torch.empty_like(original_fp32)
        _bf16_rem_to_fp32(bf16_tensor, rem_tensor, reconstructed_fp32)
        compare_tensor(reconstructed_fp32, original_fp32)

test_lossless_conversion()

This function is invoked during checkpoint dump, but we are also concerned about what happens on every optimizer step.
When the master weights are reconstructed during the optimizer step, is the logic the same as _bf16_rem_to_fp32, and if so, does this loss also occur during training steps?

shuaitang5 commented:

The loss is inevitable if we use the fp32_tensor.to(torch.bfloat16) operation to get the low-precision param from the full-precision optimizer state, because PyTorch's .to rounds behind the scenes. The rounding rule is "round to nearest, ties to even": the last bit of the upper 16 bits is rounded up if the 17th bit is 1, except when the lower 16 bits are exactly 0x8000 (a tie to the midpoint), in which case it rounds up only when the last bit of the upper 16 bits is odd. That tie case is where the conversion incurs loss, and we have no way to recover it: we cannot tell whether the even last bit of the bf16 value was rounded up from an odd bit or was already even in the original fp32 number. Precisely, this rounding error / tensor mismatch happens only when the 32-bit pattern looks like xxxx xxxx xxxx xxx0 1000 0000 0000 0000 (https://github.com/NVIDIA/apex/blob/24.04.01/apex/contrib/optimizers/distributed_fused_adam.py#L265-L266 — the undo-rounding is done incorrectly in this case).
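
A minimal standalone repro of that tie case (hypothetical snippet, not taken from the Apex sources; the bit pattern 0x3F808000 is chosen so the lower half is exactly 0x8000 and the last retained bf16 bit is even):

import torch

# Tie case: upper 16 bits = 0x3F80 (last retained bit even), lower 16 bits = 0x8000.
bits = torch.tensor([0x3F808000], dtype=torch.int32)
fp32 = bits.view(torch.float32)                              # 1 + 2**-8, exactly halfway between two bf16 values

bf16 = fp32.to(torch.bfloat16)                               # ties-to-even keeps 0x3F80 (1.0); no round-up happened
rem = (bits & 0xFFFF).to(torch.int16)                        # 0x8000 wraps to -32768, i.e. rem < 0

# Undo-rounding as in _bf16_rem_to_fp32: subtract 1 whenever rem < 0,
# even though the cast above did not round up in this ties-to-even case.
high = bf16.view(torch.int16) - torch.where(rem < 0, 1, 0).to(torch.int16)
reconstructed = torch.stack((rem, high), dim=-1).view(torch.int32).view(torch.float32)

print(fp32.item(), reconstructed.item())                     # 1.00390625 vs 0.998046875: off by one ULP of the upper half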

The fix would be, instead of doing a simple cast to get the bf16 param, to rely on bit manipulation to separate the upper and lower 16 bits into the bf16 param and the 16-bit remainder, avoiding rounding/undo-rounding entirely.
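
A rough sketch of that approach (hypothetical helper names, not part of Apex): split by reinterpreting the bits, so the round trip is exact by construction.

import torch

def split_fp32_bitwise(fp32: torch.Tensor):
    # Hypothetical split: the bf16 half is the raw upper 16 bits (truncation, no rounding)
    # and the remainder is the raw lower 16 bits.
    bits = fp32.view(torch.int32)
    bf16 = (bits >> 16).to(torch.int16).view(torch.bfloat16)
    rem = (bits & 0xFFFF).to(torch.int16)
    return bf16, rem

def merge_bf16_rem(bf16: torch.Tensor, rem: torch.Tensor) -> torch.Tensor:
    # Exact reconstruction: pack the two 16-bit halves back together; no undo-rounding needed.
    packed = torch.stack((rem, bf16.view(torch.int16)), dim=-1)
    return packed.view(torch.int32).view(torch.float32).squeeze(-1)

x = torch.randn(1000)
b, r = split_fp32_bitwise(x)
assert torch.equal(merge_bf16_rem(b, r), x)   # lossless round trip

The trade-off is that the truncated bf16 half can differ from fp32_tensor.to(torch.bfloat16) by one ULP, which is exactly the bit-wise-equivalence question raised below.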


timmoon10 commented Sep 17, 2024

Great debugging. It's tricky that round-to-nearest-even rounding is irreversible unless we store an extra bit, which seems excessive given that these errors are just at the level of machine epsilon. @alxzhang-amazon How important is bit-wise equivalence to fp32_tensor.to(torch.bfloat16)?
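
To illustrate why it is irreversible, here is a hypothetical pair of inputs (values chosen to hit the tie case) whose cast-based split produces the identical (bf16, remainder) pair, so no undo-rounding rule can recover both:

import torch

a = torch.tensor([0x3F808000], dtype=torch.int32).view(torch.float32)  # upper half 0x3F80 (even), lower half 0x8000
b = torch.tensor([0x3F7F8000], dtype=torch.int32).view(torch.float32)  # upper half 0x3F7F (odd),  lower half 0x8000

for x in (a, b):
    bf16 = x.to(torch.bfloat16)                              # both ties round to the even value 0x3F80
    rem = (x.view(torch.int32) & 0xFFFF).to(torch.int16)     # both remainders are 0x8000
    print(hex(bf16.view(torch.int16).item() & 0xFFFF), hex(rem.item() & 0xFFFF))   # prints 0x3f80 0x8000 twice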

If we do think of a good approach to reproduce round-to-nearest-even rounding, then we'll also need to modify the Adam kernel (struct DistAdamWithParamRemaindersFunctor), in addition to _bf16_rem_to_fp32.


alxzhang-amazon commented Sep 17, 2024

Hi @timmoon10.

We are not reliant on using store_param_remainders==True, so we can switch the flag if bit-wise equivalence is required.

To clarify my understanding, is it the case that this loss occurs in the Adam kernel as well, and therefore during each training step, or is it localized to just the _bf16_rem_to_fp32 function? The kernel's rounding line in question:

if (local_p_rem[ii] < 0) local_p_bf16[ii]++; // Round up


alxzhang-amazon commented Sep 17, 2024

To my understanding, the rounding methods differ between _bf16_rem_to_fp32 and the Adam kernel code
(round to nearest, ties to even vs. round to nearest, ties away from zero).

Please correct me if I am mistaken or misunderstanding.

timmoon10 (Contributor) commented:

Both _bf16_rem_to_fp32 and the Adam kernel use "round to nearest, ties away from zero", so you should get bit-wise exact results when saving/loading state dicts. However, direct type casts (e.g. with Tensor.to) use "round to nearest, ties to even", so you should expect small numerical differences between store_param_remainders=False and store_param_remainders=True.
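
A quick illustration of the difference (hypothetical snippet; the second path just mimics in Python the "round up when the remainder is negative" rule from the kernel line quoted above). The two modes only disagree at exact ties, where the results differ by one bf16 ULP:

import torch

x = torch.tensor([1.0 + 2**-8], dtype=torch.float32)          # exactly halfway between bf16 1.0 and 1.0078125

rne = x.to(torch.bfloat16)                                     # round to nearest, ties to even -> 1.0

bits = x.view(torch.int32)
high = (bits >> 16).to(torch.int16)                            # raw upper 16 bits
rem = (bits & 0xFFFF).to(torch.int16)                          # raw lower 16 bits (0x8000 -> negative int16)
away = torch.where(rem < 0, high + 1, high).view(torch.bfloat16)   # ties away from zero -> 1.0078125

print(rne.item(), away.item())                                 # 1.0 vs 1.0078125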


alxzhang-amazon commented Sep 17, 2024

Thank you @timmoon10.

So, to summarize my current understanding:

  1. NVIDIA Apex implemented its conversion logic (used when store_param_remainders==True) to follow "round to nearest, ties away from zero" rounding. This is implemented in both the DFA _bf16_rem_to_fp32 function and the kernel code.

This has no impact during training steps, since between training steps the kernel code does the conversion, and it is consistent in its rounding logic.
The parameters and param_remainders are formed using the kernel logic and are therefore consistent.

  2. Torch casts such as Tensor.to use "round to nearest, ties to even", so there is a mismatch when the tie condition is met.
    https://docs.nvidia.com/cuda/floating-point/index.html#rounding-modes

To my understanding, the Distributed Fused Adam file does not use fp32_tensor.to(torch.bfloat16) to split into the bf16 param and remainders; splitting is done in the kernel code. The Python code only reconstructs the fp32 parameters when dumping the optimizer state dict.

When dumping the optimizer states, the fp32 optimizer states are reconstructed with _bf16_rem_to_fp32, which follows the "round to nearest, ties away from zero" rounding logic.

Therefore, when using store_param_remainders==True, the tensor elements are constructed using a different rounding method than Tensor.to uses, hence the differences shown when comparing the Tensor.to output against the originals.

Does this summary seem correct to you?
