Fixed bug where bfloat16 datatype wasn't handled correctly when converting to numpy (responding to issue 26)
JohnMark Taylor committed Aug 30, 2024
1 parent 2a25882 commit 813c834
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions torchlens/helper_funcs.py
@@ -522,7 +522,10 @@ def get_tensor_memory_amount(t: torch.Tensor) -> int:
     Returns:
         Size of tensor in bytes.
     """
-    return getsizeof(np.array(clean_cpu(t.data)))
+    cpu_data = clean_cpu(t.data)
+    if cpu_data.dtype == torch.bfloat16:
+        cpu_data = cpu_data.to(torch.float16)
+    return getsizeof(np.array(cpu_data))
 
 
 def human_readable_size(size: int, decimal_places: int = 1) -> str:
@@ -564,7 +567,10 @@ def print_override(t: torch.Tensor, func_name: str):
     Returns:
         The string representation of the tensor.
     """
-    n = np.array(clean_cpu(t.data))
+    cpu_data = clean_cpu(t.data)
+    if cpu_data.dtype == torch.bfloat16:
+        cpu_data = cpu_data.to(torch.float16)
+    n = np.array(cpu_data)
     np_str = getattr(n, func_name)()
     np_str = np_str.replace("array", "tensor")
     np_str = np_str.replace("\n", "\n ")
@@ -590,7 +596,10 @@ def safe_copy(x, detach_tensor: bool = False):
     if issubclass(type(x), (torch.Tensor, torch.nn.Parameter)):
         if not detach_tensor:
             return clean_clone(x)
-        vals_np = clean_cpu(x.data).numpy()
+        vals_cpu = clean_cpu(x.data)
+        if vals_cpu.dtype == torch.bfloat16:
+            vals_cpu = vals_cpu.to(torch.float16)
+        vals_np = vals_cpu.numpy()
         vals_tensor = clean_from_numpy(vals_np)
         if hasattr(x, "tl_tensor_label_raw"):
             vals_tensor.tl_tensor_label_raw = x.tl_tensor_label_raw
