Commit
Ensure that parameters are leaf nodes when loading a model (#362)
There was a subtle bug where we populated models with parameters that were
not leaf nodes, because we called `to` on them for device placement.

This change fixes the issue and makes the model tests validate that all
model parameters are leaf nodes.
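
For context, here is a minimal sketch of the failure mode (not part of the
commit; it uses a dtype change instead of device placement so it runs on a
CPU-only machine, but `to(device=...)` behaves the same way): calling `to` on
a `Parameter` that requires gradients goes through autograd, so the result is
a non-leaf tensor.

    import torch
    from torch.nn import Parameter

    # Buggy order: wrap first, then move. The `to` call is recorded by
    # autograd, so the result has a grad_fn and is no longer a leaf (it is
    # also a plain Tensor rather than a Parameter).
    bad = Parameter(torch.zeros(2)).to(torch.float64)
    print(bad.is_leaf)  # False

    # Fixed order: move the raw tensor first, then wrap it in a Parameter.
    good = Parameter(torch.zeros(2).to(torch.float64))
    print(good.is_leaf)  # True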
danieldk authored Feb 8, 2024
1 parent f9da3b5 commit af59d3f
Showing 2 changed files with 18 additions and 9 deletions.
25 changes: 17 additions & 8 deletions curated_transformers/tests/models/util.py
@@ -98,8 +98,7 @@ def assert_causal_lm_output_equals_hf(
     )
     orig_model.eval()

-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)

     hf_model = transformers.AutoModelForCausalLM.from_pretrained(
         model_name,
@@ -155,8 +154,7 @@ def assert_decoder_output_equals_hf(
     )
     orig_model.eval()

-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)

     hf_model = transformers.AutoModel.from_pretrained(
         model_name, revision=model_revision, trust_remote_code=trust_remote_code
@@ -219,8 +217,7 @@ def assert_encoder_output_equals_hf(
     orig_model = model_class.from_hf_hub(name=model_name, device=torch_device)
     orig_model.eval()

-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)

     hf_model = transformers.AutoModel.from_pretrained(model_name)
     hf_model.to(torch_device)
@@ -384,8 +381,7 @@ def assert_model_hf_serialization_roundtrip(
     )
     orig_model.eval()

-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)

     auto_cls = (
         transformers.AutoModelForCausalLM
@@ -424,3 +420,16 @@ def assert_model_hf_serialization_roundtrip(
     assert (
         hf_config[k] == v
     ), f"Key '{k}' value '{v}' is different in the Hugging Face model config ('{hf_config[k]}')"
+
+
+def check_params_buffers(model: Module, device: torch.device):
+    """
+    Check that parameters/buffers are placed on the correct device and that
+    parameters are leaf nodes.
+    """
+    for buffer in model.buffers():
+        assert buffer.device == device
+
+    for param in model.parameters():
+        assert param.device == device
+        assert param.is_leaf
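
Leaf-ness is not just a bookkeeping detail: non-leaf tensors do not
accumulate `.grad`, and PyTorch optimizers reject them outright, so a model
loaded with non-leaf parameters cannot be fine-tuned. A small illustration
(again not part of the commit, and using a dtype change to force a copy):

    import torch

    # A non-leaf "parameter" produced by the buggy wrap-then-move order.
    w = torch.nn.Parameter(torch.zeros(2)).to(torch.float64)

    try:
        torch.optim.SGD([w], lr=0.1)
    except ValueError as err:
        print(err)  # "can't optimize a non-leaf Tensor"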
2 changes: 1 addition & 1 deletion curated_transformers/util/serde/load.py
@@ -103,7 +103,7 @@ def default_tensor_to_parameter_converter(
     old_param = module._parameters[parameter_name]
     assert old_param is not None
     _validate_replacement(old_param, tensor, module_prefix)
-    return Parameter(tensor, requires_grad=old_param.requires_grad).to(device=device)  # type: ignore
+    return Parameter(tensor.to(device=device), requires_grad=old_param.requires_grad)  # type: ignore


 def _emplace_module_state_dict(
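
The one-line fix above reorders the operations: the raw tensor is moved to
the target device first and only then wrapped in a `Parameter`. Because the
plain tensor does not require gradients, the `to` call is not recorded by
autograd, and the resulting `Parameter` stays a leaf node with the original
`requires_grad` flag.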
