Description
Weight tying between the embedding (transformer.wte) and the LM head (transformer.lm_head) is not preserved when using meta device initialisation:
With meta_device: true and use_weight_tying: true:
model.state_dict()["transformer.wte.weight"]
tensor([[-0.0046, 0.0123, 0.0271, ..., 0.0234, 0.0303, -0.0038],
[ 0.0213, -0.0267, 0.0017, ..., 0.0034, -0.0085, 0.0428],
[ 0.0097, 0.0419, -0.0100, ..., -0.0108, 0.0084, 0.0299],
...,
[ 0.0030, 0.0022, 0.0153, ..., 0.0186, 0.0095, 0.0403],
[-0.0084, 0.0180, -0.0303, ..., -0.0155, 0.0102, 0.0063],
[-0.0093, -0.0123, -0.0340, ..., -0.0071, -0.0019, 0.0271]],
device='cuda:3')
model.state_dict()["transformer.lm_head.weight"]
tensor([[-0.0395, -0.0022, -0.0460, ..., 0.0032, 0.0038, 0.0101],
[-0.0349, 0.0033, 0.0124, ..., 0.0008, 0.0231, 0.0205],
[ 0.0297, -0.0023, 0.0014, ..., -0.0123, -0.0142, 0.0070],
...,
[ 0.0085, 0.0019, -0.0131, ..., -0.0101, 0.0244, -0.0056],
[-0.0298, 0.0023, -0.0147, ..., -0.0191, 0.0283, -0.0460],
[-0.0351, -0.0034, -0.0132, ..., -0.0081, 0.0177, -0.0136]],
device='cuda:3')
model.transformer.lm_head.weight is model.transformer.wte.weight
False
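For reference, the same conclusion follows from a storage-level check at that breakpoint (generic PyTorch, with the transformer.* attribute paths taken from the output above and the model object assumed to be in scope):

```python
p_wte = model.transformer.wte.weight
p_head = model.transformer.lm_head.weight
print(p_wte is p_head)                        # False: different Parameter objects
print(p_wte.data_ptr() == p_head.data_ptr())  # False: they do not even share storage
```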
With meta_device: false and use_weight_tying: true:
model.state_dict()["transformer.lm_head.weight"]
tensor([[ 0.0177, -0.0254, 0.0643, ..., -0.0025, 0.0320, -0.0013],
[-0.0360, -0.0042, 0.0136, ..., 0.0157, 0.0195, -0.0083],
[-0.0148, -0.0136, -0.0029, ..., -0.0040, -0.0242, 0.0247],
...,
[ 0.0212, 0.0066, 0.0017, ..., 0.0098, -0.0088, -0.0030],
[-0.0012, -0.0052, -0.0008, ..., -0.0174, 0.0121, 0.0071],
[-0.0104, 0.0250, -0.0229, ..., 0.0005, 0.0099, -0.0061]])
model.state_dict()["transformer.wte.weight"]
tensor([[ 0.0177, -0.0254, 0.0643, ..., -0.0025, 0.0320, -0.0013],
[-0.0360, -0.0042, 0.0136, ..., 0.0157, 0.0195, -0.0083],
[-0.0148, -0.0136, -0.0029, ..., -0.0040, -0.0242, 0.0247],
...,
[ 0.0212, 0.0066, 0.0017, ..., 0.0098, -0.0088, -0.0030],
[-0.0012, -0.0052, -0.0008, ..., -0.0174, 0.0121, 0.0071],
[-0.0104, 0.0250, -0.0229, ..., 0.0005, 0.0099, -0.0061]])
model.transformer.lm_head.weight is model.transformer.wte.weight
True
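For context, this looks like generic PyTorch behaviour rather than something specific to this config: under default settings, nn.Module.to_empty() (and any _apply-based materialisation of meta parameters) replaces each parameter entry with a fresh tensor per module, so a parameter shared between two submodules ends up as two independent copies. A minimal standalone sketch (plain PyTorch with a hypothetical TinyLM module, not Modalities code):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Hypothetical stand-in for the embedding / lm_head pair, not Modalities code.
    def __init__(self, vocab: int = 8, dim: int = 4):
        super().__init__()
        self.wte = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.wte.weight  # tie the weights

with torch.device("meta"):
    model = TinyLM()
print(model.lm_head.weight is model.wte.weight)  # True: still tied on the meta device

model.to_empty(device="cpu")  # materialise real storage for the meta parameters
print(model.lm_head.weight is model.wte.weight)  # False: the tie is lost
```

If that is indeed what happens here, re-assigning the tie (lm_head.weight = wte.weight) after the meta parameters have been materialised/initialised would restore the expected behaviour.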
Config to reproduce:
https://github.com/Modalities/modalities/blob/2c5f78d023c0cfd15a6104245282b9c8720ddc1f/tutorials/warmstart/configs/pre_training_config.yaml
I set the breakpoint here: