add deepcopy and copy for Params4bit (#1060)
* fix deepcopy and copy

* add tests

* remove line

* ruff fix

* ruff

* Update tests/test_linear4bit.py

Co-authored-by: Aarni Koskela <[email protected]>

* add missing state

* ruff format

* ignore formatting commit for git blame

* Params4bit should be initialized as frozen by default

* add test for serialization round-tripping

* add comparison capability for QuantState

* add back accidentally removed line

---------

Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus von Koeller <[email protected]>
3 people committed Feb 21, 2024
1 parent b0730f4 commit cfd6ac7
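In short, Params4bit values can now be shallow-copied, deep-copied, and pickled without losing their quantization state, and QuantState objects can be compared with ==. A minimal usage sketch, mirroring the new tests in this commit (assumes a CUDA device is available):

import copy
import pickle

import torch
import bitsandbytes as bnb

tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, quant_type="fp4").cuda(0)  # moving to CUDA triggers quantization

shallow = copy.copy(param)       # shares the same QuantState and underlying storage
deep = copy.deepcopy(param)      # gets its own QuantState and data
restored = pickle.loads(pickle.dumps(param))  # serialization round-trip

assert shallow.quant_state is param.quant_state
assert deep.quant_state is not param.quant_state
assert restored.quant_state == param.quant_state  # uses the new QuantState.__eq__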
Showing 4 changed files with 121 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
 
 # Remove f-prefix from strings that don't use formatting
 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
+
+# format tests/linear_4bit.py
+34735ba89de8235ea9da6ef409f814dcea9e2038
15 changes: 15 additions & 0 deletions bitsandbytes/functional.py
@@ -706,6 +706,21 @@ def to(self, device):
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)
 
+    def __eq__(self, other):
+        if not isinstance(other, QuantState):
+            return False
+
+        return (
+            torch.allclose(self.absmax, other.absmax, atol=1e-6) and
+            self.shape == other.shape and
+            torch.allclose(self.code, other.code, atol=1e-6) and
+            self.dtype == other.dtype and
+            self.blocksize == other.blocksize and
+            self.quant_type == other.quant_type and
+            (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
+            (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
+        )
+
 
 def quantize_blockwise(
     A: Tensor,
43 changes: 40 additions & 3 deletions bitsandbytes/nn/modules.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings
 
@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
     def __new__(
         cls,
         data: Optional[torch.Tensor] = None,
-        requires_grad=True,
+        requires_grad=False,  # quantized weights should be frozen by default
         quant_state: Optional[QuantState] = None,
         blocksize: int = 64,
         compress_statistics: bool = True,
@@ -214,6 +215,37 @@ def __new__(
         self.module = module
         return self
 
+    def __getstate__(self):
+        state = self.__dict__
+        state["data"] = self.data
+        state["requires_grad"] = self.requires_grad
+        return state
+
+    def __setstate__(self, state):
+        self.requires_grad = state["requires_grad"]
+        self.blocksize = state["blocksize"]
+        self.compress_statistics = state["compress_statistics"]
+        self.quant_type = state["quant_type"]
+        self.quant_state = state["quant_state"]
+        self.data = state["data"]
+        self.quant_storage = state["quant_storage"]
+        self.bnb_quantized = state["bnb_quantized"]
+        self.module = state["module"]
+
+    def __deepcopy__(self, memo):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        new_instance.quant_state = copy.deepcopy(state["quant_state"])
+        new_instance.data = copy.deepcopy(state["data"])
+        return new_instance
+
+    def __copy__(self):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance
+
     @classmethod
     def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -227,8 +259,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],
 
     def _quantize(self, device):
         w = self.data.contiguous().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(
+            w,
+            blocksize=self.blocksize,
+            compress_statistics=self.compress_statistics,
+            quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
+        )
         self.data = w_4bit
         self.quant_state = quant_state
         if self.module is not None:
77 changes: 63 additions & 14 deletions tests/test_linear4bit.py
@@ -1,4 +1,6 @@
+import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -8,13 +10,14 @@
 from tests.helpers import TRUE_FALSE
 
 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }
 
-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
         quant_storage=storage[quant_storage],
         device="meta",
     )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -91,15 +103,15 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
 
     q0 = a.quant_state
     q1 = b.quant_state
-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
         else:
             assert c == d, f"{c} != {d}"
 
     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
     assert torch.equal(a, c)
 
     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
 
         torch.save(linear.state_dict(), state_path)
         torch.save(linear_q.state_dict(), state_path_4bit)
-        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-            state_path_4bit
+        size_orig, size_4 = (
+            os.path.getsize(state_path),
+            os.path.getsize(state_path_4bit),
         )
         size_ratio = size_4 / size_orig
-        target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+        target_compression = (
+            0.143 if original_dtype == torch.float32 else 0.29
+        )  # these numbers get lower as weight shape increases
         ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
         assert size_ratio < target_compression, ratio_error_msg
+
+
+def test_copy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+
+    shallow_copy_param = copy.copy(param)
+    assert param.quant_state is shallow_copy_param.quant_state
+    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+
+
+def test_deepcopy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+    copy_param = copy.deepcopy(param)
+    assert param.quant_state is not copy_param.quant_state
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+
+def test_params4bit_real_serialization():
+    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state
