Ensure that parameters are leaf nodes when loading a model (#364)
There was a subtle bug where we populated models with parameters that were not leaf nodes, because we called `to` on them for device placement.

This change fixes the issue, and the model tests now validate that all model parameters are leaf nodes.
danieldk authored Feb 11, 2024
1 parent 4055d7e commit 581316d
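
For background, `Tensor.to` is an autograd-tracked operation: applied to a `Parameter` (which has `requires_grad=True`), it returns a new, non-leaf tensor, so gradients no longer accumulate in its `.grad` and optimizers typically reject it. Below is a minimal sketch of this PyTorch behaviour, not code from the repository; a dtype change stands in for the device move so it also reproduces on a CPU-only machine.

import torch
from torch.nn import Parameter

param = Parameter(torch.zeros(3))
print(param.is_leaf)   # True: a freshly constructed Parameter is a leaf

# `to` with a different dtype/device is recorded by autograd, so the result
# has a grad_fn and is no longer a leaf.
moved = param.to(torch.float64)
print(moved.is_leaf)   # False

# Converting the plain tensor first and only then wrapping it in a Parameter
# keeps the result a leaf; this is the ordering the commit switches to.
fixed = Parameter(torch.zeros(3).to(torch.float64))
print(fixed.is_leaf)   # True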
Showing 2 changed files with 17 additions and 7 deletions.
22 changes: 16 additions & 6 deletions curated_transformers/tests/models/util.py

@@ -96,8 +96,7 @@ def assert_causal_lm_output_equals_hf(
     )
     orig_model.eval()
 
-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)
 
     hf_model = transformers.AutoModelForCausalLM.from_pretrained(
         model_name,
@@ -153,8 +152,7 @@ def assert_decoder_output_equals_hf(
     )
     orig_model.eval()
 
-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)
 
     hf_model = transformers.AutoModel.from_pretrained(
         model_name, revision=model_revision, trust_remote_code=trust_remote_code
@@ -217,8 +215,7 @@ def assert_encoder_output_equals_hf(
     orig_model = model_class.from_hf_hub(name=model_name, device=torch_device)
     orig_model.eval()
 
-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)
 
     hf_model = transformers.AutoModel.from_pretrained(model_name)
     hf_model.to(torch_device)
@@ -362,3 +359,16 @@ def assert_model_config(model: TransformerModule, model_output: Tensor):
 
     hidden_width = model_output.size(-1)
     assert config.layer.feedforward.hidden_width == hidden_width
+
+
+def check_params_buffers(model: Module, device: torch.device):
+    """
+    Check that parameters/buffers are placed on the correct device and that
+    parameters are leaf nodes.
+    """
+    for buffer in model.buffers():
+        assert buffer.device == device
+
+    for param in model.parameters():
+        assert param.device == device
+        assert param.is_leaf
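
The previous check iterated over `state_dict()`, which mixes parameters and buffers, so leafness could not be asserted for parameters alone; the new helper walks the two collections separately. A small illustration of the split, using a hypothetical `Toy` module that is not part of the library:

import torch
from torch import nn

class Toy(nn.Module):
    """Hypothetical module, used only to show the parameter/buffer split."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 2))          # trainable, must stay a leaf
        self.register_buffer("running_mean", torch.zeros(2))   # non-trainable state

toy = Toy()
print([n for n, _ in toy.named_parameters()])  # ['weight']
print([n for n, _ in toy.named_buffers()])     # ['running_mean']
print(list(toy.state_dict()))                  # ['weight', 'running_mean']

Checking the collections separately lets the test assert device placement for both, while asserting `is_leaf` only for parameters, where gradient bookkeeping matters.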
2 changes: 1 addition & 1 deletion curated_transformers/util/serde.py

@@ -126,7 +126,7 @@ def default_tensor_to_parameter_converter(
     old_param = module._parameters[parameter_name]
     assert old_param is not None
     _validate_replacement(old_param, tensor, module_prefix)
-    return Parameter(tensor, requires_grad=old_param.requires_grad).to(device=device)  # type: ignore
+    return Parameter(tensor.to(device=device), requires_grad=old_param.requires_grad)  # type: ignore
 
 
 def _emplace_module_state_dict(
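
For the one-line change above, a hedged before/after sketch of the two orderings (stand-in inputs, not the converter's actual call site): wrapping first and then moving yields a non-leaf result whenever `to` actually copies, while moving the raw checkpoint tensor first and then wrapping preserves both leafness and the requested `requires_grad`.

import torch
from torch.nn import Parameter

# Stand-ins: a checkpoint tensor (no grad history) and a target conversion;
# a dtype change substitutes for the device move so this runs without a GPU.
tensor = torch.randn(4)
target = torch.float64

before = Parameter(tensor, requires_grad=True).to(target)  # old ordering: non-leaf
after = Parameter(tensor.to(target), requires_grad=True)   # new ordering: leaf

print(before.is_leaf, after.is_leaf)              # False True
print(before.requires_grad, after.requires_grad)  # True True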
