add deepcopy and copy for Param4bit #1060

Merged

Conversation

@SunMarc (Contributor) commented Feb 12, 2024:

What does this PR do?

This PR makes it possible to deepcopy and copy the Params4bit class. With this feature, you can deepcopy/copy a 4-bit model.
The tests related to 4-bit in transformers have also passed successfully.

Fixes huggingface/accelerate#2248

Deepcopy/copy Params4bit

from bitsandbytes.nn import Params4bit
import torch
import copy


t = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Moving the parameter to CUDA triggers the 4-bit quantization.
param = Params4bit(data=t, requires_grad=False).cuda(0)

# Deep copy: the quant_state and the underlying storage are duplicated.
copy_param = copy.deepcopy(param)
assert param.quant_state is not copy_param.quant_state
assert param.data.data_ptr() != copy_param.data.data_ptr()

# Shallow copy: the quant_state and storage are shared with the original.
shallow_copy_param = copy.copy(param)
assert param.quant_state is shallow_copy_param.quant_state
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()

Deepcopy/copy a 4-bit model

import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

fp4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=fp4_config,
    torch_dtype=torch.float16,
)

def generate(model):
    prompts = ["I would like to"]
    token_dict = tokenizer(prompts, return_tensors="pt").to(0)
    output_ids = model.generate(**token_dict, max_new_tokens=10)
    print(tokenizer.batch_decode(output_ids))
 
generate(model)

# deepcopy -> memory increases (the quantized weights are duplicated)
model_copy = copy.deepcopy(model)
generate(model_copy)

# shallow copy -> no memory increase (the weights are shared)
model_shallow_copy = copy.copy(model)
generate(model_shallow_copy)

Cc @Titus-von-Koeller


@Titus-von-Koeller (Collaborator) commented:
Thanks @SunMarc for taking the lead on this, greatly appreciated! I'll take a look and get back to you. Looks really good at a first glance 🤗

Resolved review thread on tests/test_functional.py (outdated).
Review comment on lines 218 to 244 of the diff:
def __getstate__(self):
    state = self.__dict__
    state["data"] = self.data
    state["requires_grad"] = self.requires_grad
    return state

def __setstate__(self, state):
    self.requires_grad = state["requires_grad"]
    self.blocksize = state["blocksize"]
    self.compress_statistics = state["compress_statistics"]
    self.quant_type = state["quant_type"]
    self.quant_state = state["quant_state"]
    self.data = state["data"]

def __deepcopy__(self, memo):
    new_instance = type(self).__new__(type(self))
    state = self.__getstate__()
    new_instance.__setstate__(state)
    new_instance.quant_state = copy.deepcopy(state["quant_state"])
    new_instance.data = copy.deepcopy(state["data"])
    return new_instance

def __copy__(self):
    new_instance = type(self).__new__(type(self))
    state = self.__getstate__()
    new_instance.__setstate__(state)
    return new_instance
@akx (Contributor) commented Feb 15, 2024:
Is having to do this dance common in Torch world? 🤔

I'm a little worried that someone adding a new field in __init__ will inevitably miss adding it here...

@SunMarc (Contributor, Author) commented Feb 15, 2024:

> Is having to do this dance common in Torch world? 🤔

I don't think so, but I wasn't able to find a better solution. I based my solution on this specific code from torch.

> I'm a little worried that someone adding a new field in __init__ will inevitably miss adding it here...

Yeah, that's true :/. I tried modifying __setstate__ so that we update the instance's __dict__ from the state, but some attributes were not copied properly.
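For reference, a minimal sketch of the generic __dict__-based approach alluded to above (an assumed shape, not the actual attempt). On a torch.Tensor subclass, a plain __dict__ round-trip silently drops data and requires_grad, because those live on the C-level tensor base rather than in the Python-level __dict__, which would explain why some attributes were not copied properly:

# Hypothetical sketch of the generic __dict__-based approach; not the merged code.
def __getstate__(self):
    # Captures only Python-level attributes (blocksize, quant_state, ...).
    return self.__dict__.copy()

def __setstate__(self, state):
    # Restores Python-level attributes, but misses `data` and `requires_grad`,
    # which are stored on the C-level torch.Tensor base, not in __dict__.
    self.__dict__.update(state)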

Resolved review threads (outdated): tests/test_linear8bitlt.py, tests/test_linear4bit.py (two threads).
@prathikr commented:
@SunMarc can this be merged soon?

@SunMarc (Contributor, Author) commented Feb 20, 2024:

> @SunMarc can this be merged soon?

Yes, I'm waiting for the review from @Titus-von-Koeller.

@Titus-von-Koeller (Collaborator) commented:
I'll review this tomorrow, but the release will be next week at the earliest. The PR looks great as is, and there's really no reason not to merge it other than that I need to verify it first. As a side note, we currently have a lot in the pipeline, and the release process is still in transition and under review. Also, there's no PR-based CI pipeline yet, which would make this easier to validate and green-light. Other than that, the cross-platform effort and FSDP come first priority-wise this week.

@prathikr commented:
@Titus-von-Koeller I appreciate the transparency on prioritization of this task.

However, I would like to point out that I raised the original issue over 2 months ago and several people on the ONNX Runtime team have encountered this problem. If there is truly no harm in merging it, please do so ASAP. Thank you.

@SunMarc (Contributor, Author) commented Feb 21, 2024:

As Titus said, we will most probably merge it tomorrow. @prathikr, can you also double-check that this PR actually solves the ONNX Runtime team's issue without surfacing other problems? Thank you for your patience 🤗

@Titus-von-Koeller self-assigned this on Feb 21, 2024.
@Titus-von-Koeller (Collaborator) commented Feb 21, 2024:

Alright, I took a deep look this afternoon and also reran all the tests. Great work on this PR, and thanks again for taking the initiative. Overall, everything looks perfect and is very polished; really appreciate this! The only thing that needs fixing is the last comment I made in the review: __getstate__ and __setstate__ should match. Then it's ready to merge.

@SunMarc (Contributor, Author) commented Feb 21, 2024:

> Alright, I took a deep look this afternoon and also reran all the tests. Great work on this PR, and thanks again for taking the initiative. Overall, everything looks perfect and is very polished; really appreciate this! The only thing that needs fixing is the last comment I made in the review: __getstate__ and __setstate__ should match. Then it's ready to merge.

Fixed! Thanks again for your review!
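For illustration, a matching pair might look like the following sketch (this mirrors the diff above rather than quoting the exact merged commit): __getstate__ returns precisely the keys that __setstate__ consumes.

# Sketch of a matching __getstate__/__setstate__ pair; details may differ from the merged code.
def __getstate__(self):
    state = self.__dict__.copy()
    # Tensor-level attributes are not part of __dict__, so add them explicitly.
    state["data"] = self.data
    state["requires_grad"] = self.requires_grad
    return state

def __setstate__(self, state):
    self.requires_grad = state["requires_grad"]
    self.blocksize = state["blocksize"]
    self.compress_statistics = state["compress_statistics"]
    self.quant_type = state["quant_type"]
    self.quant_state = state["quant_state"]
    self.data = state["data"]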

@prathikr commented:
> As Titus said, we will most probably merge it tomorrow. @prathikr, can you also double-check that this PR actually solves the ONNX Runtime team's issue without surfacing other problems? Thank you for your patience 🤗

Yes @SunMarc this indeed resolves the issue, thank you for the efforts to merge this ASAP.

@Titus-von-Koeller merged commit cfd6ac7 into bitsandbytes-foundation:main on Feb 21, 2024 (9 of 10 checks passed).
@Titus-von-Koeller (Collaborator) commented Feb 21, 2024:

Ok, added another test round-tripping the serialization. For that, I added the capability to compare QuantState instances with each other. I also reran the tests, including the Transformers BNB integration ones.

Happy to have this sorted now. Thanks @prathikr for raising it. Maybe you could check if things work for you now with BNB installed from source? Would be good to become aware of potential issues, given your particular use-case, before doing the release.
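For reference, such a round-trip test could look roughly like this sketch (assuming pickle-based serialization and the QuantState comparison mentioned above; not necessarily the exact committed test):

import pickle

import torch
from bitsandbytes.nn import Params4bit

original = Params4bit(data=torch.tensor([1.0, 2.0, 3.0, 4.0]), requires_grad=False).cuda(0)

# A pickle round trip should restore both the data and the quant_state.
restored = pickle.loads(pickle.dumps(original))

# Relies on QuantState instances being comparable, as added in this follow-up.
assert restored.quant_state == original.quant_state
assert torch.equal(restored.data, original.data)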

@akx (Contributor) commented Feb 22, 2024:

@Titus-von-Koeller FWIW, no point adding a commit to .git-blame-ignore-revs if you do a squash merge... 😅

Merging this pull request closes: deepcopy fails after accelerate==0.23.0