Skip to content

Commit

Permalink
FIX TST NaN issue with HQQ GPU test (#2143)
Browse files Browse the repository at this point in the history
This test calculates the correlation coefficient of HQQ model outputs.
Although the model outputs are finite, the resulting matrix contains
NaNs. Casting the outputs from 16 to 32 bit precision resolves the
issue.
  • Loading branch information
BenjaminBossan authored Oct 10, 2024
1 parent 5758a7e commit 0aa7e3a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2917,18 +2917,18 @@ def test_hqq_lora_model_outputs(self):
output_hqq = model(**inputs).logits

# check that outputs of HQQ are highly correlated; there are outliers, so don't check for equality
cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_hqq.flatten())))
cc_matrix = torch.corrcoef(torch.stack((output_normal.float().flatten(), output_hqq.float().flatten())))
assert cc_matrix.min() > 0.97

# check that outputs are the same after merging
cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_hqq.flatten())))
cc_matrix = torch.corrcoef(torch.stack((output_normal.float().flatten(), output_hqq.float().flatten())))
assert cc_matrix.min() > 0.97

# check outputs are the same after unmerging
model.unmerge_adapter()
with torch.inference_mode():
output_unmerged = model(**inputs).logits
cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_unmerged.flatten())))
cc_matrix = torch.corrcoef(torch.stack((output_normal.float().flatten(), output_unmerged.float().flatten())))
assert cc_matrix.min() > 0.97

# check that the results are the same after saving and loading
Expand Down Expand Up @@ -2957,7 +2957,9 @@ def test_hqq_lora_model_outputs(self):
model = model.merge_and_unload()
with torch.inference_mode():
output_merged_unloaded = model(**inputs).logits
cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_merged_unloaded.flatten())))
cc_matrix = torch.corrcoef(
torch.stack((output_normal.float().flatten(), output_merged_unloaded.float().flatten()))
)
assert cc_matrix.min() > 0.97


Expand Down

0 comments on commit 0aa7e3a

Please sign in to comment.