Initialize buffers per init callbacks
kwen2501 committed Aug 5, 2024
1 parent 5958928 · commit 1348e65
Showing 2 changed files with 36 additions and 5 deletions.
examples/llama/load_weights.py (33 additions & 4 deletions)
@@ -3,7 +3,7 @@
 import json
 import torch
 
-from typing import Optional
+from typing import Callable, Dict, Optional
 
 
 def load_weights(
@@ -37,14 +37,19 @@ def load_weights(
     # files because the stage module is a partition of the full model.
     needed_files = set()
     for param in state_dict.keys():
-        file = weight_map[param]
-        needed_files.add(file)
+        # The file a param is saved in
+        file = weight_map.setdefault(param, None)
+        if file:
+            needed_files.add(file)
 
     # Now we load the needed binary files
     for file in needed_files:
         checkpoint = torch.load(file, weights_only=True)
         for param in state_dict.keys():
-            if weight_map[param] == file:
+            file_having_param = weight_map[param]
+            if file_having_param is None:
+                print(f"Cannot find checkpoint file for {param}, skipping")
+            elif file_having_param == file:
                 state_dict[param] = checkpoint[param]
                 updated_states.setdefault(param, None)
 
@@ -59,3 +64,27 @@ def load_weights(
     # the current module are preserved.
     stage_module.load_state_dict(state_dict, assign=True)
+
+
+def init_buffers(
+    stage_module: torch.nn.Module,
+    device: torch.device,
+    init_callbacks: Dict[str, Callable],
+):
+    """
+    Initialize buffers of `stage_module` per the callbacks in `init_callbacks`.
+    `init_callbacks` is a dictionary from a buffer's FQN to its init function.
+    """
+    for name, buf in stage_module.named_buffers():
+        if name in init_callbacks:
+            cb = init_callbacks[name]
+            buf_val = cb(device)
+            # Find the parent module
+            splits = name.split(".")
+            mod = stage_module
+            for atom in splits[:-1]:
+                mod = getattr(mod, atom)
+            mod.register_buffer(
+                splits[-1], buf_val, persistent=False,
+            )
+            print(f"Initialized buffer {name}")

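To see the new `init_buffers` contract in isolation, here is a minimal self-contained sketch; the `Demo` module, the `rotary.inv_freq` buffer, and the frequency formula are illustrative RoPE-style stand-ins, not part of this commit:

import torch

from load_weights import init_buffers

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rotary = torch.nn.Module()
        # Placeholder buffer, as it would look after meta-device init
        self.rotary.register_buffer(
            "inv_freq", torch.empty(64, device="meta"), persistent=False
        )

demo = Demo()

# Buffer FQN -> callback that rebuilds the real value on `device`
callbacks = {
    "rotary.inv_freq": lambda device: 1.0
    / (10000.0 ** (torch.arange(0, 128, 2, device=device).float() / 128)),
}

init_buffers(demo, torch.device("cpu"), callbacks)
assert demo.rotary.inv_freq.device.type == "cpu"

Each callback receives only the target device and returns the materialized tensor; `init_buffers` then re-registers it (non-persistent) on the submodule that owns it, which is why the FQN is split to walk down to the parent module.
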
examples/llama/meta_init.py (3 additions & 1 deletion)
@@ -28,7 +28,7 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from load_weights import load_weights
+from load_weights import load_weights, init_buffers
 
 # Grab the model in meta/fake mode
 fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
@@ -87,5 +87,7 @@
 stage_module = pipe.get_stage_module(rank)
 print(f"Loading weights into stage {rank}")
 load_weights(stage_module)
+if hasattr(llama, "buf_init_callbacks"):
+    init_buffers(stage_module, "cpu", llama.buf_init_callbacks)
 stage_module.print_readable()
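
The `hasattr` guard means a model opts in by exposing a `buf_init_callbacks` dict; this commit does not show how `llama` gains that attribute, so the following opt-in is only a guess at the intended shape (the attribute name comes from the diff, while the buffer FQN and rebuild function are illustrative):

# Somewhere before pipelining, attach the callbacks to the model object:
def rebuild_inv_freq(device):
    # Illustrative stand-in for a real buffer rebuild (e.g. RoPE frequencies)
    return 1.0 / (10000.0 ** (torch.arange(0, 128, 2, device=device).float() / 128))

llama.buf_init_callbacks = {"model.rotary_emb.inv_freq": rebuild_inv_freq}

Note also that the call passes the string "cpu" where the signature annotates `torch.device`; the callbacks only forward it to tensor factories, which accept either form.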
