
How to save memory when loading weights? #51

Open
KaneGreen opened this issue Mar 23, 2024 · 8 comments
Assignees
Labels
bug Something isn't working

Comments


KaneGreen commented Mar 23, 2024

OS: Windows 11 22631.3296
Python: 3.11.8
PyTorch: 2.2.1 (installed in conda env)
CUDA: 12.1 (installed in conda env)
NV Driver: 551.76
Gemma Model: 7b-it

I was trying to run inference. Before I started, about 6GB of memory was in use and about 26GB was free.

I observed that when the code reached the load_weights function, memory usage rose to 98% of my total 32GB of RAM, stayed there for about a minute, and then dropped back to normal. During that time, the to(device) call on the next line had not yet been executed.

From Task Manager, at the time of peak usage, I saw that python.exe had a working set of about 28GB, while its active private working set was about 14GB. At that point, the Windows page file was already being used to keep the system running.

[Screenshot: Task Manager during the memory spike]

However, the 7b-it model (16-bit float) should not exceed 16GB, so allocating 28GB of memory for this step seems pointless.
Remember what I said above: the memory usage eventually dropped back to normal without to(device) ever being called. This shows that loading does not actually require that much memory.

Sorry, I don't know the details of how Python or PyTorch manage memory, but I'm wondering whether it's possible to improve this line to smooth out the memory usage spike?

@pengchongjin
Collaborator

My guess is that torch.load creates copies of the weights as temporary variables, which doubles the memory usage, but these copies eventually get gc'ed.

self.load_state_dict(
    torch.load(
        model_path, mmap=True, weights_only=True,
    )['model_state_dict'],
    strict=False,
)

Maybe one workaround would be to load the weights layer by layer in sequence and gc each layer's weights immediately after they are loaded. I think that way the peak memory usage would be lower.
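
A minimal sketch of that layer-by-layer idea, assuming the same checkpoint layout as the snippet above (load_weights_low_mem is a hypothetical helper, not part of the repo):

import gc
import torch

def load_weights_low_mem(model, model_path):
    # mmap=True keeps the checkpoint tensors memory-mapped on disk, so they
    # are paged in lazily rather than all materialized in RAM at once.
    state_dict = torch.load(
        model_path, mmap=True, weights_only=True,
    )['model_state_dict']

    params = dict(model.named_parameters())
    for name in list(state_dict.keys()):
        if name in params:
            with torch.no_grad():
                # Copy a single tensor into the existing parameter.
                params[name].copy_(state_dict[name])
        # Drop the checkpoint tensor right away so it can be reclaimed
        # before the next one is paged in.
        state_dict[name] = None
    del state_dict
    gc.collect()

On PyTorch 2.1+, another option is load_state_dict(..., assign=True) together with mmap=True, which keeps the memory-mapped tensors instead of copying them into freshly allocated parameters, so the extra copy never happens.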

@michaelmoynihan you have investigated this before, do you have any insights?

@Gopi-Uppari
Collaborator

Hi @KaneGreen,

Could you please confirm whether this issue is resolved for you by the above comment? Please feel free to close the issue if it is resolved.

Thank you.

@Gopi-Uppari Gopi-Uppari self-assigned this Sep 26, 2024
@KaneGreen
Author

@Gopi-Uppari
This issue has not been resolved.

I just tried the 9B model (the model weights file is about 17.2GB).
My computer has 32GB of memory installed. About 5.5GB was occupied by the operating system and other programs before the model was loaded, so the available memory should have been above 26GB.
With Windows virtual memory turned on (similar to Linux swap), I observed memory usage spikes and almost all physical memory was used. There were noticeable writes to the disk containing the Windows page file (similar to a Linux swap file).
With Windows virtual memory turned off, the model cannot be loaded at all.

@Gopi-Uppari
Collaborator

Hi @KaneGreen,

The gemma-2-9b model size should be over 35GB, not 17.2GB, as shown in the attached screenshot.
[Screenshot: gemma-2-9b model file sizes]

If virtual memory is disabled, the system cannot load the model because there is not enough physical memory; it needs the page file to absorb the overflow, which is why loading fails in that configuration. If you enable virtual memory, the model might load, but performance will suffer because the disk used for paging is much slower than RAM. One potential solution is to use quantization to reduce the model's size so that it fits into the available memory without relying heavily on virtual memory.
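
For illustration, generic PyTorch dynamic quantization looks like the sketch below. Note that it quantizes a model that has already been loaded, so it lowers steady-state memory rather than the peak during loading, and it is a generic PyTorch technique rather than a gemma_pytorch-specific option:

import torch
import torch.nn as nn

# `model` is assumed to be an already-loaded float (FP32) model on the CPU.
quantized_model = torch.ao.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # quantize the Linear layers, where most of the weights live
    dtype=torch.qint8,  # 8-bit integer weights: roughly half the size of FP16
)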

Thank you.

@KaneGreen
Author

KaneGreen commented Oct 29, 2024

@Gopi-Uppari Well, I used the model downloaded from Kaggle, not from Hugging Face.

.safetensors is the file format from Hugging Face.

As far as I know, 9B means 9 billion parameters, and Gemma uses FP16. Even if a model had 9.999 billion parameters, 9.999 * (10^9) * 16 bit ≈ 18.6 GiB.
So I think the model size cannot be over 20GB.
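
For reference, a quick back-of-the-envelope check of that estimate, assuming 2 bytes per FP16 parameter:

params = 9.999e9            # upper-bound parameter count used above
bytes_per_param = 2         # FP16 = 16 bits = 2 bytes
size_bytes = params * bytes_per_param
print(size_bytes / 1e9)     # ~20.0 GB (decimal gigabytes)
print(size_bytes / 2**30)   # ~18.6 GiB (what file managers usually report)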

@Gopi-Uppari
Collaborator

Hi @KaneGreen,

Actually, I was referring to the Gemma-2-9B model with a size of 35GB, not the Gemma-2-9B-IT model, which has a size of 17.23GB when downloaded from Kaggle and 18.48GB from HuggingFace. The slight difference in sizes is due to variations in file saving formats.

In your case, even though there is sufficient space, the model fails to load because, apart from the 17.2GB required for the model weights, additional memory is needed for activations, gradients, and computations during model execution. This can easily exceed the available 26GB, especially for large models. Furthermore, inefficient memory management may also be contributing, as large models are often not optimized for efficient resource usage. These might be the reasons.

Thank you.

@KaneGreen
Author

@Gopi-Uppari
I have tried the loading method in #56, and it seems helpful.
But I don't know whether this method has any side effects.

@Gopi-Uppari
Collaborator

Hi @KaneGreen,

Generally, that loading method should not cause side effects during inference.
The requires_grad attribute determines whether gradients are computed for a parameter during backpropagation. Setting it to True allows gradients to be calculated, which is essential for training.
During inference, that is, when you are generating text or using the model for evaluation, gradient computation is typically disabled with torch.no_grad(), which ensures that no additional memory is used for storing gradients.
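
As a minimal sketch of that point (the generate helper and its arguments are placeholders, not the exact gemma_pytorch API):

import torch

@torch.no_grad()  # or wrap the call in `with torch.no_grad():`
def generate(model, input_ids):
    # No autograd graph is recorded here, so inference does not spend
    # extra memory on gradient bookkeeping.
    return model(input_ids)

# Freezing the parameters explicitly is also harmless for inference:
# for p in model.parameters():
#     p.requires_grad_(False)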

Thank you.
