Skip to content

Commit

Permalink
Remove excess load_state_dict() leading to meta tensor warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jul 24, 2024
1 parent a65205d commit 5321af8
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions src/petals/server/from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ def load_pretrained_block(
max_disk_space=max_disk_space,
)

# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=False)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

for param_name, _ in block.named_parameters():
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
Expand All @@ -76,7 +72,6 @@ def load_pretrained_block(
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}")
return block


Expand Down

0 comments on commit 5321af8

Please sign in to comment.