add deepcopy and copy for Param4bit #1060

Merged
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6

# format tests/linear_4bit.py
34735ba89de8235ea9da6ef409f814dcea9e2038
15 changes: 15 additions & 0 deletions bitsandbytes/functional.py
@@ -706,6 +706,21 @@ def to(self, device):
self.state2.absmax = self.state2.absmax.to(device)
self.state2.code = self.state2.code.to(device)

def __eq__(self, other):
if not isinstance(other, QuantState):
return False

return (
torch.allclose(self.absmax, other.absmax, atol=1e-6) and
self.shape == other.shape and
torch.allclose(self.code, other.code, atol=1e-6) and
self.dtype == other.dtype and
self.blocksize == other.blocksize and
self.quant_type == other.quant_type and
(self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
(self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
)


def quantize_blockwise(
A: Tensor,
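As a quick usage sketch of the new `QuantState.__eq__` (not part of the diff itself; it assumes a CUDA device and an install that includes this change), quantizing the same tensor twice with identical settings should produce states that compare equal:

```python
import torch
import bitsandbytes as bnb

# Quantize the same half-precision weights twice with identical settings.
w = torch.randn(64, 64, dtype=torch.float16, device="cuda")
_, state_a = bnb.functional.quantize_4bit(w, quant_type="nf4")
_, state_b = bnb.functional.quantize_4bit(w, quant_type="nf4")

# Tensor fields (absmax, code) are compared with torch.allclose;
# scalar fields (shape, dtype, blocksize, quant_type) with ==.
assert state_a == state_b

# Comparing against anything that is not a QuantState returns False.
assert not (state_a == "not a quant state")
```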
43 changes: 40 additions & 3 deletions bitsandbytes/nn/modules.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
def __new__(
cls,
data: Optional[torch.Tensor] = None,
requires_grad=True,
requires_grad=False, # quantized weights should be frozen by default
quant_state: Optional[QuantState] = None,
blocksize: int = 64,
compress_statistics: bool = True,
@@ -214,6 +215,37 @@ def __new__(
self.module = module
return self

def __getstate__(self):
state = self.__dict__
state["data"] = self.data
state["requires_grad"] = self.requires_grad
return state

def __setstate__(self, state):
self.requires_grad = state["requires_grad"]
self.blocksize = state["blocksize"]
self.compress_statistics = state["compress_statistics"]
self.quant_type = state["quant_type"]
self.quant_state = state["quant_state"]
self.data = state["data"]
self.quant_storage = state["quant_storage"]
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]

def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
new_instance.quant_state = copy.deepcopy(state["quant_state"])
new_instance.data = copy.deepcopy(state["data"])
return new_instance

def __copy__(self):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
return new_instance

@classmethod
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -227,8 +259,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],

def _quantize(self, device):
w = self.data.contiguous().cuda(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type, quant_storage=self.quant_storage)
w_4bit, quant_state = bnb.functional.quantize_4bit(
w,
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
quant_storage=self.quant_storage,
)
self.data = w_4bit
self.quant_state = quant_state
if self.module is not None:
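Beyond the parameter-level tests added below, the practical effect of the new `__getstate__`/`__setstate__`/`__copy__`/`__deepcopy__` methods is that modules holding a `Params4bit` can be copied as a whole. A minimal sketch (not part of the diff; it assumes a CUDA device and mirrors the construction pattern used in the tests):

```python
import copy

import torch
import bitsandbytes as bnb

# Build a 4-bit linear layer and quantize it by moving it to the GPU,
# following the same pattern as tests/test_linear4bit.py.
fp16_linear = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
linear_q = bnb.nn.Linear4bit(
    64, 64, bias=False, compute_dtype=torch.float16, quant_type="nf4", device="meta"
)
linear_q.weight = bnb.nn.Params4bit(
    data=fp16_linear.weight, quant_type="nf4", requires_grad=False
)
linear_q = linear_q.to("cuda")

# With __deepcopy__ defined on Params4bit, the copy owns its own quantized
# data and quant_state instead of sharing (or losing) the original's.
linear_q_copy = copy.deepcopy(linear_q)
assert linear_q_copy.weight.data.data_ptr() != linear_q.weight.data.data_ptr()

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
assert torch.allclose(linear_q(x), linear_q_copy(x))
```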
77 changes: 63 additions & 14 deletions tests/test_linear4bit.py
@@ -1,4 +1,6 @@
import copy
import os
import pickle
from tempfile import TemporaryDirectory

import pytest
@@ -8,13 +10,14 @@
from tests.helpers import TRUE_FALSE

storage = {
'uint8': torch.uint8,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
'float32': torch.float32
"uint8": torch.uint8,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}

@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])

@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
device = "cuda"
layer_shape = (300, 400)

linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
linear = torch.nn.Linear(
*layer_shape, dtype=original_dtype, device="cpu"
) # original layer

# Quantizing original layer
linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_type=quant_type,
device="meta",
)
new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
new_weight = bnb.nn.Params4bit(
data=linear.weight, quant_type=quant_type, requires_grad=False
)
linear_q.weight = new_weight
if bias:
linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_storage=storage[quant_storage],
device="meta",
)
linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
linear_qs.weight = bnb.nn.Params4bit(
data=linear.weight,
requires_grad=False,
quant_type=quant_type,
quant_storage=storage[quant_storage],
)
if bias:
linear_qs.bias = torch.nn.Parameter(linear.bias)
linear_qs = linear_qs.to(device)
@@ -91,15 +103,15 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora

q0 = a.quant_state
q1 = b.quant_state
for attr in ('code', 'dtype', 'blocksize', 'absmax'):
for attr in ("code", "dtype", "blocksize", "absmax"):
c, d = getattr(q0, attr), getattr(q1, attr)
if isinstance(c, torch.Tensor):
assert torch.equal(c, d)
else:
assert c == d, f"{c} != {d}"

if q0.state2 is not None:
for attr in ('code', 'dtype', 'blocksize', 'absmax'):
for attr in ("code", "dtype", "blocksize", "absmax"):
c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
if isinstance(c, torch.Tensor):
assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert torch.equal(a, c)

# Test moving to CPU and back to GPU
linear_q2.to('cpu')
linear_q2.to("cpu")
linear_q2.to(device)
d = linear_qs(x)
assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
torch.save(linear.state_dict(), state_path)
torch.save(linear_q.state_dict(), state_path_4bit)

size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
state_path_4bit
size_orig, size_4 = (
os.path.getsize(state_path),
os.path.getsize(state_path_4bit),
)
size_ratio = size_4 / size_orig
target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases
target_compression = (
0.143 if original_dtype == torch.float32 else 0.29
) # these numbers get lower as weight shape increases
ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
assert size_ratio < target_compression, ratio_error_msg


def test_copy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

shallow_copy_param = copy.copy(param)
assert param.quant_state is shallow_copy_param.quant_state
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


def test_deepcopy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
copy_param = copy.deepcopy(param)
assert param.quant_state is not copy_param.quant_state
assert param.data.data_ptr() != copy_param.data.data_ptr()


def test_params4bit_real_serialization():
original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")

original_param.cuda(0) # move to CUDA to trigger quantization

serialized_param = pickle.dumps(original_param)
deserialized_param = pickle.loads(serialized_param)

assert torch.equal(original_param.data, deserialized_param.data)
assert original_param.requires_grad == deserialized_param.requires_grad == False
assert original_param.quant_type == deserialized_param.quant_type
assert original_param.blocksize == deserialized_param.blocksize
assert original_param.compress_statistics == deserialized_param.compress_statistics
assert original_param.quant_state == deserialized_param.quant_state