From af7b492f427c6688767bdbb58c2ec87ae4308764 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 12 Feb 2024 23:19:35 +0100
Subject: [PATCH 01/13] fix deepcopy and copy

---
 bitsandbytes/nn/modules.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 6eeecc273..81fa9aa1d 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -5,6 +5,7 @@
 from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings

+import copy
 import torch
 from torch import Tensor, device, dtype, nn
 import torch.nn.functional as F
@@ -213,6 +214,34 @@ def __new__(
         self.data = data
         self.module = module
         return self
+
+    def __getstate__(self):
+        state = self.__dict__
+        state["data"] = self.data
+        state["requires_grad"] = self.requires_grad
+        return state
+
+    def __setstate__(self, state):
+        self.requires_grad = state["requires_grad"]
+        self.blocksize = state["blocksize"]
+        self.compress_statistics = state["compress_statistics"]
+        self.quant_type = state["quant_type"]
+        self.quant_state = state["quant_state"]
+        self.data = state["data"]
+
+    def __deepcopy__(self,memo):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        new_instance.quant_state = copy.deepcopy(state["quant_state"])
+        new_instance.data = copy.deepcopy(state["data"])
+        return new_instance
+
+    def __copy__(self):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance

     @classmethod
     def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
@@ -235,6 +264,7 @@ def _quantize(self, device):
             self.module.quant_state = quant_state
         self.bnb_quantized = True
         return self
+

     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)

From 6c8871b2153f3fe3d65b08e04541fa95c27e96ab Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 12 Feb 2024 23:51:01 +0100
Subject: [PATCH 02/13] add tests

---
 tests/test_linear4bit.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 13db28ed4..da442f63a 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,6 +1,7 @@
 import os
 from tempfile import TemporaryDirectory

+import copy
 import pytest
 import torch

@@ -146,3 +147,18 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
+
+def test_copy_param():
+    tensor = torch.tensor([1.,2.,3.,4.])
+    param = bnb.nn.Params4bit(data = tensor, requires_grad=False).cuda(0)
+
+    shallow_copy_param = copy.copy(param)
+    assert param.quant_state is shallow_copy_param.quant_state
+    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+
+def test_deepcopy_param():
+    tensor = torch.tensor([1.,2.,3.,4.])
+    param = bnb.nn.Params4bit(data = tensor, requires_grad=False).cuda(0)
+    copy_param = copy.deepcopy(param)
+    assert param.quant_state is not copy_param.quant_state
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
\ No newline at end of file

From 1482d93ba84ec18a0788c2c11bc4256bae65dcde Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 12 Feb 2024 23:53:59 +0100
Subject: [PATCH 03/13] remove line

---
 bitsandbytes/nn/modules.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 81fa9aa1d..bc91cc8f0 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -264,7 +264,6 @@ def _quantize(self, device):
             self.module.quant_state = quant_state
         self.bnb_quantized = True
         return self
-

     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)

From ae0fcdfa77d42d1284d425414a47d2dfdeb47131 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Tue, 13 Feb 2024 15:17:45 +0100
Subject: [PATCH 04/13] ruff fix

---
 bitsandbytes/nn/modules.py | 2 +-
 tests/test_functional.py   | 2 +-
 tests/test_linear4bit.py   | 2 +-
 tests/test_linear8bitlt.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bc91cc8f0..da08855c1 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -2,10 +2,10 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings

-import copy
 import torch
 from torch import Tensor, device, dtype, nn
 import torch.nn.functional as F
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 2d4e959ad..ab997b0f5 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -9,8 +9,8 @@
 from scipy.stats import norm
 import torch

-import bitsandbytes as bnb
 from bitsandbytes import functional as F
+import bitsandbytes as bnb
 from tests.helpers import (
     BOOLEAN_TUPLES,
     TRUE_FALSE,
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index da442f63a..0a0eef5c3 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,7 +1,7 @@
+import copy
 import os
 from tempfile import TemporaryDirectory

-import copy
 import pytest
 import torch

diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 6fa7efb8d..b46f985d7 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -5,8 +5,8 @@
 import pytest
 import torch

-import bitsandbytes as bnb
 from bitsandbytes import functional as F
+import bitsandbytes as bnb
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import TRUE_FALSE, id_formatter

From 587e7c2d452ab7a4c62616aea220e9a83429b8b9 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Thu, 15 Feb 2024 16:52:41 +0100
Subject: [PATCH 05/13] ruff

---
 tests/test_functional.py   | 2 +-
 tests/test_linear8bitlt.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_functional.py b/tests/test_functional.py
index ab997b0f5..2d4e959ad 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -9,8 +9,8 @@
 from scipy.stats import norm
 import torch

-from bitsandbytes import functional as F
 import bitsandbytes as bnb
+from bitsandbytes import functional as F
 from tests.helpers import (
     BOOLEAN_TUPLES,
     TRUE_FALSE,
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index b46f985d7..6fa7efb8d 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -5,8 +5,8 @@
 import pytest
 import torch

-from bitsandbytes import functional as F
 import bitsandbytes as bnb
+from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import TRUE_FALSE, id_formatter

From b4f938440d3c6696ddb15f1790d070b0a29aee8d Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Thu, 15 Feb 2024 10:57:16 -0500
Subject: [PATCH 06/13] Update tests/test_linear4bit.py

Co-authored-by: Aarni Koskela
---
 tests/test_linear4bit.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 0a0eef5c3..9451592db 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -149,16 +149,17 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert size_ratio < target_compression, ratio_error_msg

 def test_copy_param():
-    tensor = torch.tensor([1.,2.,3.,4.])
-    param = bnb.nn.Params4bit(data = tensor, requires_grad=False).cuda(0)
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

     shallow_copy_param = copy.copy(param)
     assert param.quant_state is shallow_copy_param.quant_state
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()

+
 def test_deepcopy_param():
-    tensor = torch.tensor([1.,2.,3.,4.])
-    param = bnb.nn.Params4bit(data = tensor, requires_grad=False).cuda(0)
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
     copy_param = copy.deepcopy(param)
     assert param.quant_state is not copy_param.quant_state
     assert param.data.data_ptr() != copy_param.data.data_ptr()
\ No newline at end of file

From 9e32d6809bdbb970578880342cce2fff4381a7dd Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Wed, 21 Feb 2024 18:51:55 +0100
Subject: [PATCH 07/13] add missing state

---
 bitsandbytes/nn/modules.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 706950fc8..21d9eb8cf 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -228,6 +228,9 @@ def __setstate__(self, state):
         self.quant_type = state["quant_type"]
         self.quant_state = state["quant_state"]
         self.data = state["data"]
+        self.quant_storage = state["quant_storage"]
+        self.bnb_quantized = state["bnb_quantized"]
+        self.module = state["module"]

     def __deepcopy__(self,memo):
         new_instance = type(self).__new__(type(self))

From 34735ba89de8235ea9da6ef409f814dcea9e2038 Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Wed, 21 Feb 2024 19:45:54 +0000
Subject: [PATCH 08/13] ruff format

---
 tests/test_linear4bit.py | 46 ++++++++++++++++++++++++++--------------
 1 file changed, 30 insertions(+), 16 deletions(-)

diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 9451592db..8a2fdc9dc 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -9,13 +9,14 @@
 from tests.helpers import TRUE_FALSE

 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }

-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -25,7 +26,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     device = "cuda"
     layer_shape = (300, 400)

-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer

     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -37,7 +40,9 @@
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -81,7 +86,12 @@
         quant_storage=storage[quant_storage],
         device="meta",
     )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -92,7 +102,7 @@
     q0 = a.quant_state
     q1 = b.quant_state

-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
@@ -100,7 +110,7 @@
             assert c == d, f"{c} != {d}"

     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -126,7 +136,7 @@
     assert torch.equal(a, c)

     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -140,14 +150,18 @@
     torch.save(linear.state_dict(), state_path)
     torch.save(linear_q.state_dict(), state_path_4bit)

-    size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-        state_path_4bit
+    size_orig, size_4 = (
+        os.path.getsize(state_path),
+        os.path.getsize(state_path_4bit),
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+    target_compression = (
+        0.143 if original_dtype == torch.float32 else 0.29
+    )  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
-
+
+
 def test_copy_param():
     tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
     param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
@@ -162,4 +176,4 @@ def test_deepcopy_param():
     param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
     copy_param = copy.deepcopy(param)
     assert param.quant_state is not copy_param.quant_state
-    assert param.data.data_ptr() != copy_param.data.data_ptr()
\ No newline at end of file
+    assert param.data.data_ptr() != copy_param.data.data_ptr()

From eead51f96fb20f084cc044994ddc49bf04ad2acc Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Wed, 21 Feb 2024 19:46:58 +0000
Subject: [PATCH 09/13] ignore formatting commit for git blame

---
 .git-blame-ignore-revs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index f7dd01bdf..c0386dc9f 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

 # Remove f-prefix from strings that don't use formatting
 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
+
+# format tests/linear_4bit.py
+34735ba89de8235ea9da6ef409f814dcea9e2038
\ No newline at end of file

From c06437301d61de74172c8366d0b9ff51b5b588f4 Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Wed, 21 Feb 2024 20:21:35 +0000
Subject: [PATCH 10/13] Params4bit should be initialized as frozen by default

---
 bitsandbytes/nn/modules.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 21d9eb8cf..6e4b279f4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -192,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
     def __new__(
         cls,
         data: Optional[torch.Tensor] = None,
-        requires_grad=True,
+        requires_grad=False,  # quantized weights should be frozen by default
         quant_state: Optional[QuantState] = None,
         blocksize: int = 64,
         compress_statistics: bool = True,
@@ -214,7 +214,7 @@ def __new__(
         self.data = data
         self.module = module
         return self
-    
+
     def __getstate__(self):
         state = self.__dict__
         state["data"] = self.data
@@ -231,7 +231,7 @@ def __setstate__(self, state):
         self.quant_storage = state["quant_storage"]
         self.bnb_quantized = state["bnb_quantized"]
         self.module = state["module"]
-    
+
     def __deepcopy__(self,memo):
         new_instance = type(self).__new__(type(self))
         state = self.__getstate__()
@@ -259,9 +259,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],

     def _quantize(self, device):
         w = self.data.contiguous().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
-        self.data = w_4bit
+        w_4bit, quant_state = bnb.functional.quantize_4bit(
+            w,
+            blocksize=self.blocksize,
+            compress_statistics=self.compress_statistics,
+            quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
+        )
         self.quant_state = quant_state
         if self.module is not None:
             self.module.quant_state = quant_state

From 00b6f3180beaf81447cd095d8bc29897bf6a1c95 Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Wed, 21 Feb 2024 20:22:03 +0000
Subject: [PATCH 11/13] add test for serialization round-tripping

---
 tests/test_linear4bit.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 8a2fdc9dc..3e62bdf3b 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory

 import pytest
@@ -177,3 +178,20 @@ def test_deepcopy_param():
     copy_param = copy.deepcopy(param)
     assert param.quant_state is not copy_param.quant_state
     assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+
+def test_params4bit_real_serialization():
+    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state

From ad87fc496729b70fae2813d48123343df53c3657 Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Wed, 21 Feb 2024 20:22:36 +0000
Subject: [PATCH 12/13] add comparison capability for QuantSate

---
 bitsandbytes/functional.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 9fc5e08f0..f0de962e1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -706,6 +706,21 @@ def to(self, device):
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)

+    def __eq__(self, other):
+        if not isinstance(other, QuantState):
+            return False
+
+        return (
+            torch.allclose(self.absmax, other.absmax, atol=1e-6) and
+            self.shape == other.shape and
+            torch.allclose(self.code, other.code, atol=1e-6) and
+            self.dtype == other.dtype and
+            self.blocksize == other.blocksize and
+            self.quant_type == other.quant_type and
+            (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
+            (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
+        )
+

 def quantize_blockwise(
     A: Tensor,

From b3a9bd51bbb922fd25756d9eb07b8b41302a91ee Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Wed, 21 Feb 2024 20:31:24 +0000
Subject: [PATCH 13/13] add back accidentally remove line

---
 bitsandbytes/nn/modules.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 6e4b279f4..bd2bd5832 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -266,6 +266,7 @@ def _quantize(self, device):
             quant_type=self.quant_type,
             quant_storage=self.quant_storage,
         )
+        self.data = w_4bit
         self.quant_state = quant_state
         if self.module is not None:
             self.module.quant_state = quant_state
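
A usage sketch of the behavior this series establishes (not itself part of the
patches; it assumes a CUDA device and mirrors the tests above, where quantization
is triggered lazily by the .cuda(0) call in _quantize):

    import copy
    import pickle

    import torch
    import bitsandbytes as bnb

    # Params4bit is frozen by default after patch 10, so no requires_grad needed;
    # .cuda(0) quantizes the data and populates quant_state
    param = bnb.nn.Params4bit(
        data=torch.tensor([1.0, 2.0, 3.0, 4.0]), quant_type="fp4"
    ).cuda(0)

    # copy.copy shares quant_state and storage (patch 01's __copy__)
    shallow = copy.copy(param)
    assert shallow.quant_state is param.quant_state

    # copy.deepcopy duplicates both (patch 01's __deepcopy__)
    deep = copy.deepcopy(param)
    assert deep.quant_state is not param.quant_state
    assert deep.data.data_ptr() != param.data.data_ptr()

    # pickle round-trips via __getstate__/__setstate__ (patches 01 and 07);
    # the equality check relies on QuantState.__eq__ from patch 12
    restored = pickle.loads(pickle.dumps(param))
    assert restored.quant_state == param.quant_state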