
Weight tying breaks with meta device initialisation #357

Open
@le1nux

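Weight tying between the token embedding (`transformer.wte`) and the LM head (`transformer.lm_head`) is not preserved when the model is initialised on the meta device. The two weights end up with different values, and `lm_head.weight` is no longer the same object as `wte.weight`: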

With `meta_device: true` and `use_weight_tying: true`:

```
model.state_dict()["transformer.wte.weight"]
tensor([[-0.0046,  0.0123,  0.0271,  ...,  0.0234,  0.0303, -0.0038],
        [ 0.0213, -0.0267,  0.0017,  ...,  0.0034, -0.0085,  0.0428],
        [ 0.0097,  0.0419, -0.0100,  ..., -0.0108,  0.0084,  0.0299],
        ...,
        [ 0.0030,  0.0022,  0.0153,  ...,  0.0186,  0.0095,  0.0403],
        [-0.0084,  0.0180, -0.0303,  ..., -0.0155,  0.0102,  0.0063],
        [-0.0093, -0.0123, -0.0340,  ..., -0.0071, -0.0019,  0.0271]],
       device='cuda:3')

model.state_dict()["transformer.lm_head.weight"]
tensor([[-0.0395, -0.0022, -0.0460,  ...,  0.0032,  0.0038,  0.0101],
        [-0.0349,  0.0033,  0.0124,  ...,  0.0008,  0.0231,  0.0205],
        [ 0.0297, -0.0023,  0.0014,  ..., -0.0123, -0.0142,  0.0070],
        ...,
        [ 0.0085,  0.0019, -0.0131,  ..., -0.0101,  0.0244, -0.0056],
        [-0.0298,  0.0023, -0.0147,  ..., -0.0191,  0.0283, -0.0460],
        [-0.0351, -0.0034, -0.0132,  ..., -0.0081,  0.0177, -0.0136]],
       device='cuda:3')

model.transformer.lm_head.weight is model.transformer.wte.weight
False
```

With `meta_device: false` and `use_weight_tying: true`, the two weights are identical and tied as expected:

```
model.state_dict()["transformer.lm_head.weight"]
tensor([[ 0.0177, -0.0254,  0.0643,  ..., -0.0025,  0.0320, -0.0013],
        [-0.0360, -0.0042,  0.0136,  ...,  0.0157,  0.0195, -0.0083],
        [-0.0148, -0.0136, -0.0029,  ..., -0.0040, -0.0242,  0.0247],
        ...,
        [ 0.0212,  0.0066,  0.0017,  ...,  0.0098, -0.0088, -0.0030],
        [-0.0012, -0.0052, -0.0008,  ..., -0.0174,  0.0121,  0.0071],
        [-0.0104,  0.0250, -0.0229,  ...,  0.0005,  0.0099, -0.0061]])

model.state_dict()["transformer.wte.weight"]
tensor([[ 0.0177, -0.0254,  0.0643,  ..., -0.0025,  0.0320, -0.0013],
        [-0.0360, -0.0042,  0.0136,  ...,  0.0157,  0.0195, -0.0083],
        [-0.0148, -0.0136, -0.0029,  ..., -0.0040, -0.0242,  0.0247],
        ...,
        [ 0.0212,  0.0066,  0.0017,  ...,  0.0098, -0.0088, -0.0030],
        [-0.0012, -0.0052, -0.0008,  ..., -0.0174,  0.0121,  0.0071],
        [-0.0104,  0.0250, -0.0229,  ...,  0.0005,  0.0099, -0.0061]])

model.transformer.lm_head.weight is model.transformer.wte.weight
True
```
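Comparing printed values only hints at tying; a more direct check is object identity plus shared storage. A minimal sketch, assuming a toy model (the hypothetical `TinyLM` and `weights_are_tied` below) that only mirrors the attribute names `transformer.wte` and `transformer.lm_head` from the report, not the actual modalities model:

```python
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical toy model mirroring the attribute names from the report."""

    def __init__(self, vocab_size: int = 16, d_model: int = 8, tie: bool = True):
        super().__init__()
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, d_model),
                "lm_head": nn.Linear(d_model, vocab_size, bias=False),
            }
        )
        if tie:
            # Weight tying: lm_head reuses the embedding's Parameter object.
            self.transformer.lm_head.weight = self.transformer.wte.weight


def weights_are_tied(model: nn.Module) -> bool:
    """True only if both modules hold the same Parameter backed by the same storage."""
    wte = model.transformer.wte.weight
    head = model.transformer.lm_head.weight
    return head is wte and head.data_ptr() == wte.data_ptr()


if __name__ == "__main__":
    print(weights_are_tied(TinyLM(tie=True)))   # True
    print(weights_are_tied(TinyLM(tie=False)))  # False
    # parameters() deduplicates shared Parameters; state_dict() does not.
    print(len(list(TinyLM(tie=True).parameters())))   # 1
    print(len(list(TinyLM(tie=False).parameters())))  # 2
    print(len(TinyLM(tie=True).state_dict()))         # 2
```

`state_dict()` lists both keys even when the weights are tied, so the presence of both entries says nothing about tying; `parameters()` deduplicates shared parameters, so a tied model reports one fewer parameter tensor.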

Config to reproduce:
https://github.com/Modalities/modalities/blob/2c5f78d023c0cfd15a6104245282b9c8720ddc1f/tutorials/warmstart/configs/pre_training_config.yaml
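One defensive pattern for meta-device initialisation is to re-establish the tie explicitly after materialising the parameters and before initialising them, rather than relying on materialisation to keep it. A minimal sketch, assuming a plain PyTorch model (the hypothetical `TinyLM`, `tie_weights`, and `build_on_meta_and_materialise` below) built with the `torch.device("meta")` context manager and materialised with `Module.to_empty`; it does not reproduce the modalities code path, and the report does not establish where exactly the tie is lost there. In a sharded (e.g. FSDP) setup the re-tie would presumably also have to happen before wrapping, which is an assumption here.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical toy model; attribute names mirror the report."""

    def __init__(self, vocab_size: int = 16, d_model: int = 8):
        super().__init__()
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, d_model),
                "lm_head": nn.Linear(d_model, vocab_size, bias=False),
            }
        )
        self.tie_weights()

    def tie_weights(self) -> None:
        # Make lm_head hold the very same Parameter object as the embedding.
        self.transformer.lm_head.weight = self.transformer.wte.weight


def build_on_meta_and_materialise(device: str = "cpu") -> TinyLM:
    # 1. Construct on the meta device: shapes only, no storage allocated.
    with torch.device("meta"):
        model = TinyLM()
    # 2. Allocate real, uninitialised storage on the target device.
    model = model.to_empty(device=device)
    # 3. Re-tie explicitly instead of relying on step 2 to keep the tie.
    model.tie_weights()
    # 4. Initialise only after tying, so both modules see the same values.
    nn.init.normal_(model.transformer.wte.weight, mean=0.0, std=0.02)
    return model


if __name__ == "__main__":
    model = build_on_meta_and_materialise("cpu")
    head, wte = model.transformer.lm_head.weight, model.transformer.wte.weight
    print(head is wte, head.is_meta)  # True False
```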

I set the breakpoint here:

Labels: bug (Something isn't working)