# -*- coding: utf-8 -*-
# Copyright 2021 National Institute of Information and Communication Technology (Raj Dabre)
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the
# Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
# The above copyright notice and this permission notice shall
# be included in all copies or substantial portions of the
# Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
# KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
# OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
## Basic imports
import os
import argparse
import time
import sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
##
## Huggingface imports
import transformers
from transformers import AutoTokenizer, MBartTokenizer, MBart50Tokenizer, BartTokenizer, AlbertTokenizer
from transformers import MBartForConditionalGeneration, BartForConditionalGeneration, MBartConfig, BartConfig, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from transformers import AdamW
##
## Pytorch imports
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim import Adam
from torch.nn.functional import cosine_similarity
from torch.utils.tensorboard import SummaryWriter
try:
import wandb
except ImportError:
raise ImportError("wandb is not installed. Install it with: pip install wandb")
try:
import bitsandbytes as bnb
except ImportError:
bnb=None
print("bitsandbytes is not installed. Do not use the flag --adam_8bit")
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP, MixedPrecision, FullStateDictConfig, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType, LocalStateDictConfig
from torch.distributed._shard.checkpoint import (
FileSystemReader,
FileSystemWriter,
save_state_dict,
load_state_dict,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
from torch.distributed.fsdp import BackwardPrefetch
from functools import partial
try:
from torchdistx import deferred_init
except ImportError:
print("torchdistx is not installed. Large models will load really slowly!")
deferred_init = None
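# Note (hedged): torchdistx's deferred_init is meant to delay parameter materialization so that, when the
# model is later wrapped with FSDP, parameters can be materialized shard-by-shard instead of first building
# the full model in CPU RAM on every rank; without it, very large models are constructed fully before
# sharding, hence the warning above.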
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
##
## Our imports
from common_utils import *
##
## Other imports
import random
import numpy as np
import math
import sacrebleu
import functools
import shutil
from contextlib import nullcontext
##
## Seed setting here
torch.manual_seed(621311)
##
## Get torch version
torch_version = torch.__version__
##
def model_create_load_run_save(gpu, args, files, train_files, ewc_files):
"""The main function which does the overall training. Should be split into multiple parts in the future. Currently monolithc intentionally."""
rank = args.nr * args.gpus + gpu ## The rank of the current process out of the total number of processes indicated by world_size.
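# Worked example of the rank formula above: with 4 GPUs per node (args.gpus=4), the process on node
# args.nr=1 using local GPU index 2 gets global rank 1*4 + 2 = 6 out of args.world_size processes.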
print("Launching process:", rank)
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
if args.use_official_pretrained_tokenizer or args.use_official_pretrained: # If we use an official model then we are using its tokenizer by default.
if "mbart" in args.pretrained_model or "IndicBART" in args.pretrained_model:
if "50" in args.pretrained_model:
tok = MBart50Tokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=False)
elif "IndicBART" in args.pretrained_model:
tok = AlbertTokenizer.from_pretrained(args.tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
else:
tok = MBartTokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=False)
else:
tok = BartTokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=False)
else:
if "albert" in args.tokenizer_name_or_path:
tok = AlbertTokenizer.from_pretrained(args.tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
elif "mbart" in args.tokenizer_name_or_path:
tok = MBartTokenizer.from_pretrained(args.tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
## Fast tokenizers are avoided because their behavior can differ from the slow ones. Accents must be kept (keep_accents=True) or segmentation will be broken for languages with accented characters. do_lower_case is False because we want to train on the original casing; change it only if you are fine with the model ignoring case.
tok.save_pretrained(args.model_path+"_deploy") ## Save the tokenizer for future use.
# Copy the specially_added_tokens file into the deploy folder. This file exists when we aren't using official pretrained models. We are not going to support this for separate tokenizers.
if os.path.exists(args.tokenizer_name_or_path+"/specially_added_tokens"):
shutil.copyfile(args.tokenizer_name_or_path+"/specially_added_tokens", args.model_path+"_deploy/specially_added_tokens")
print("Tokenizer is:", tok)
if args.shard_files and rank == 0: ## First shard the data using process 0 aka the prime process or master process. Other processes will wait.
shard_files_mono(files, tok, args)
shard_files_bi(train_files, tok, args, additional_tokenizer=None)
if args.ewc_importance != 0.0:
shard_files_bi(ewc_files, tok, args, additional_tokenizer=None)
# dist.barrier() ## Stop other processes from proceeding till sharding is done. ## Barriers are bad before loading a model; they occupy memory for no reason.
if args.supported_languages is not None:
args.supported_languages = args.supported_languages.split(",")
with open(args.model_path+"_deploy/supported_languages.txt", "w") as f:
for supported_pair in args.supported_languages:
f.write(supported_pair.replace("-", " ")+"\n")
print(f"Running DDP/FSDP checkpoint example on rank {rank}.") ## Unlike the FT script this will always be distributed
if args.fp16: ## Although the code supports FP16/AMP training, it tends to be unstable in distributed setups so use this carefully.
print("We will do fp16 training")
if args.use_fsdp:
scaler = ShardedGradScaler(args.init_scale) ## This is the scaler for FSDP. It is different from the one in torch.cuda.amp.
else:
scaler = torch.cuda.amp.GradScaler(args.init_scale) ## Gradient scaler which will be used with torch's automatic mixed precision
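# Note (hedged): the scaler created above is expected to be driven with the standard torch.cuda.amp pattern
# later in the training loop, roughly: scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update().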
# Get scaler info
scaler_info = scaler.state_dict()
# Print scaler info neatly
print("AMP scaler info:")
for key, value in scaler_info.items():
print(f"{key}: {value}")
# Store current scale value
scale_value = scaler.get_scale()
mixed_precision_policy = MixedPrecision(
# Param precision
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
)
else:
print("We will do fp32 training")
mixed_precision_policy = None
if args.use_fsdp:
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers.
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
optimizer_fullstate_save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
from transformers.models.mbart.modeling_mbart import MBartEncoderLayer, MBartDecoderLayer
if args.auto_wrap_policy == "transformer": ## A block will be kept on a single device. No minimum number of params.
print("We will use transformer auto wrap policy")
mbart_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
MBartEncoderLayer, MBartDecoderLayer,
},
)
else:
print("We will use size based auto wrap policy with min params:", args.fsdp_min_params)
mbart_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(args.fsdp_min_params))
if args.activation_checkpointing:
print("We will use activation checkpointing for FSDP.")
non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
check_fn = lambda submodule: (isinstance(submodule, MBartEncoderLayer) or isinstance(submodule, MBartDecoderLayer))
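# Note (hedged): non_reentrant_wrapper and check_fn above select MBart encoder/decoder layers for
# activation (gradient) checkpointing; they are presumably passed to apply_activation_checkpointing
# (imported earlier) once the model has been wrapped with FSDP.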
if args.sharding_strategy == "FULL_SHARD":
print("We will use full sharding")
sharding_strategy = ShardingStrategy.FULL_SHARD
elif args.sharding_strategy == "SHARD_GRAD_OP":
print("We will use gradient and optimizer sharding")
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
elif args.sharding_strategy == "HYBRID_SHARD":
print("We will use hybrid sharding. Model is sharded on a node and then each node forms a replica.")
sharding_strategy = ShardingStrategy.HYBRID_SHARD
elif args.sharding_strategy == "_HYBRID_SHARD_ZERO2":
print("Similar to hybrid sharding except that only optimizer and gradient sharding is done over a node.")
sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
else:
raise ValueError("Invalid sharding strategy")
backward_prefetch_policy = BackwardPrefetch.BACKWARD_PRE if args.backward_prefetch else None
torch.backends.cuda.matmul.allow_tf32 = args.allow_tf32_matmul
if args.nodes_per_hsdp_group > 1:
print("We will use HSDP replicas of size:", args.nodes_per_hsdp_group*args.gpus, "GPUs and there are total", args.nodes//args.nodes_per_hsdp_group, "HSDP replicas")
assert args.nodes % args.nodes_per_hsdp_group == 0, "The number of nodes should be divisible by the number of nodes per HSDP group"
hsdp_replica_id = rank//(args.nodes_per_hsdp_group*args.gpus)
print("HSDP replica id is:", hsdp_replica_id)
intranode_process_group, _ = dist.new_subgroups(group_size=args.nodes_per_hsdp_group*args.gpus)
for local_rank in range(args.nodes_per_hsdp_group*args.gpus):
internode_ranks = [hsdp_replica_num*args.nodes_per_hsdp_group*args.gpus + local_rank for hsdp_replica_num in range(args.nodes//args.nodes_per_hsdp_group)] ## All processes must enumerate every inter-node group, so the member list is built from the loop variable local_rank; the group matching this process's own position is selected below.
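# Worked example of the group construction: with args.nodes=4, args.gpus=8 and args.nodes_per_hsdp_group=2,
# the intranode (shard) groups have 16 ranks each and, for local_rank=3, internode_ranks = [3, 19]; every
# process enumerates all inter-node groups but only keeps the one matching its own rank % group_size
# (checked just below).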
grp = dist.new_group(ranks=internode_ranks, backend='nccl')
if local_rank == (rank%(args.nodes_per_hsdp_group*args.gpus)):
internode_process_group = grp
print("Process groups are:", dist.get_process_group_ranks(intranode_process_group), dist.get_process_group_ranks(internode_process_group))
process_group = (intranode_process_group, internode_process_group)
else:
print("We will use a single node for each HSDP group")
process_group = None
cpu_offload = dist.fsdp.CPUOffload(offload_params=args.fsdp_cpu_offload)
if args.fsdp_cpu_offload:
print("We will use CPU offloading for FSDP")
# dist.barrier()
if args.encoder_tying_config is not None:
print("We will use recurrently stacked layers for the encoder with configuration:", args.encoder_tying_config)
if args.decoder_tying_config is not None:
print("We will use recurrently stacked layers for the decoder with configuration:", args.decoder_tying_config)
if args.unidirectional_encoder:
print("Using unidirectional encoder.")
torch.cuda.set_device(gpu) ## Set the device to the current GPU. This is different from the rank so keep this in mind.
if rank == 0:
writer = SummaryWriter(args.model_path+".tflogs")
if args.wb:
run = wandb.init(
project=args.wb_project,
name=args.wb_run,
config=vars(args),
save_code=True,
)
print("Initialization scheme is:", args.initialization_scheme)
if args.initialization_scheme == "static":
print("Static initialization scheme is used. We will use the init_std value of:", args.init_std)
dist.barrier()
if args.use_official_pretrained:
if "mbart" in args.pretrained_model or "IndicBART" in args.pretrained_model:
config = MBartConfig.from_pretrained(args.pretrained_model)
config.init_std = args.init_std # We should set the init_std to be different when using adaptors or newer params.
config.initialization_scheme = args.initialization_scheme # We should set the initialization_scheme to be different when using adaptors or newer params.
config.dropout = args.dropout ## We should set dropouts manually
config.attention_dropout = args.attention_dropout ## We should set dropouts manually
config.activation_dropout = args.activation_dropout ## We should set dropouts manually
config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.prompt_tuning = args.prompt_tuning ## We should set prompt_tuning_info manually
config.prompt_projection_hidden_size=args.prompt_projection_hidden_size
config.prompt_init_std=args.prompt_init_std ## We should set prompt_init_std manually
config.layernorm_prompt_projection=args.layernorm_prompt_projection ## We should set layernorm_prompt_projection manually
config.no_projection_prompt=args.no_projection_prompt ## We should set no_projection_prompt manually
config.use_tanh_activation_prompt=args.use_tanh_activation_prompt ## We should set use_tanh_activation_prompt manually
config.residual_connection_prompt=args.residual_connection_prompt ## We should set residual_connection_prompt manually
config.num_prompts = args.num_prompts ## We should set num_prompts manually
config.prompt_dropout = args.prompt_dropout ## We should set prompt_dropout manually
config.recurrent_projections = args.recurrent_projections ## We should set recurrent_projections manually
config.adaptor_tuning = args.adaptor_tuning ## We should set adaptor_tuning_info manually
config.deep_adaptor_tuning = args.deep_adaptor_tuning ## We should set deep_adaptor_tuning_info manually
config.deep_adaptor_tuning_ffn_only = args.deep_adaptor_tuning_ffn_only ## We should set deep_adaptor_tuning_info manually
config.adaptor_dropout = args.adaptor_dropout ## We should set adaptor_dropout manually
config.adaptor_activation_function = args.adaptor_activation_function ## We should set adaptor_activation_function manually
config.parallel_adaptors = args.parallel_adaptors ## We should set parallel_adaptors_info manually
config.layernorm_adaptor_input = args.layernorm_adaptor_input ## We should set layernorm_adaptor_input_info manually
config.adaptor_scaling_factor = args.adaptor_scaling_factor ## We should set adaptor_scaling_factor_info manually
config.residual_connection_adaptor = args.residual_connection_adaptor ## We should set residual_connection_adaptor_info manually
config.encoder_adaptor_tying_config = args.encoder_adaptor_tying_config ## We should set encoder_tying_config manually
config.decoder_adaptor_tying_config = args.decoder_adaptor_tying_config ## We should set decoder_tying_config manually
config.adaptor_hidden_size = args.adaptor_hidden_size ## We should set adaptor_hidden_size manually
config.moe_adaptors=args.moe_adaptors ## We should set moe_adaptors_info manually
config.num_moe_adaptor_experts=args.num_moe_adaptor_experts ## We should set num_moe_adaptor_experts_info manually
config.hypercomplex = args.hypercomplex ## We should set hypercomplex manually
config.hypercomplex_n = args.hypercomplex_n ## We should set hypercomplex_n manually
config.ia3_adaptors = args.ia3_adaptors ## We should set ia3_adaptors info manually
config.lora_adaptors = args.lora_adaptors ## We should set lora_adaptors info manually
config.lora_adaptor_rank = args.lora_adaptor_rank ## We should set lora_adaptor_rank info manually
config.softmax_bias_tuning = args.softmax_bias_tuning ## We should set softmax_bias_tuning_info manually
config.gradient_checkpointing = args.gradient_checkpointing ## We should set gradient_checkpointing_info manually
config.sparsify_attention = args.sparsify_attention
config.sparsify_ffn = args.sparsify_ffn
config.num_sparsify_blocks = args.num_sparsify_blocks
config.sparsification_temperature = args.sparsification_temperature
model = deferred_init.deferred_init(MBartForConditionalGeneration.from_pretrained, args.pretrained_model, config=config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration.from_pretrained(args.pretrained_model, config=config) ## We may use FBs official model and fine-tune it for our purposes.
config.architectures = ["MBartForConditionalGeneration"]
config.save_pretrained(args.model_path+"_deploy") ## Save the config as a json file to ensure easy loading during future fine tuning of the model.
elif "bart" in args.pretrained_model:
config = BartConfig.from_pretrained(args.pretrained_model)
config.init_std = args.init_std # We should set the init_std to be different when using adaptors or newer params.
config.initialization_scheme = args.initialization_scheme # We should set the initialization_scheme to be different when using adaptors or newer params.
config.dropout = args.dropout ## We should set dropouts manually
config.attention_dropout = args.attention_dropout ## We should set dropouts manually
config.activation_dropout = args.activation_dropout ## We should set dropouts manually
config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.gradient_checkpointing = args.gradient_checkpointing ## We should set gradient_checkpointing_info manually
model = deferred_init.deferred_init(BartForConditionalGeneration.from_pretrained, args.pretrained_model, config=config, force_bos_token_to_be_generated=True) if (args.use_fsdp and deferred_init is not None) else BartForConditionalGeneration.from_pretrained(args.pretrained_model, config=config, force_bos_token_to_be_generated=True) ## We may use FBs official model and fine-tune it for our purposes.
config.architectures = ["BartForConditionalGeneration"]
config.save_pretrained(args.model_path+"_deploy") ## Save the config as a json file to ensure easy loading during future fine tuning of the model.
else: ## We are going to manually specify our own model config.
config = MBartConfig(vocab_size=len(tok), init_std=args.init_std, initialization_scheme=args.initialization_scheme, encoder_layers=args.encoder_layers, decoder_layers=args.decoder_layers, dropout=args.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, encoder_attention_heads=args.encoder_attention_heads, decoder_attention_heads=args.decoder_attention_heads, encoder_ffn_dim=args.encoder_ffn_dim, decoder_ffn_dim=args.decoder_ffn_dim, d_model=args.d_model, embed_low_rank_dim=args.embed_low_rank_dim, no_embed_norm=args.no_embed_norm, scale_embedding=args.scale_embedding, pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], encoder_tying_config=args.encoder_tying_config, decoder_tying_config=args.decoder_tying_config, gradient_checkpointing=args.gradient_checkpointing, multilayer_softmaxing=args.multilayer_softmaxing, wait_k=args.wait_k, unidirectional_encoder=args.unidirectional_encoder, softmax_temperature=args.softmax_temperature, temperature_calibration=args.temperature_calibration, encoder_layerdrop=args.layerdrop, decoder_layerdrop=args.layerdrop, no_scale_attention_embedding=args.no_scale_attention_embedding, positional_encodings=args.positional_encodings, alibi_encoding=args.alibi_encoding, asymmetric_alibi_encoding=args.asymmetric_alibi_encoding, rope_encoding=args.rope_encoding, num_domains_for_domain_classifier=args.num_domains_for_domain_classifier, gradient_reversal_for_domain_classifier=args.gradient_reversal_for_domain_classifier, activation_function=args.activation_function, no_positional_encoding_encoder=args.no_positional_encoding_encoder, no_positional_encoding_decoder=args.no_positional_encoding_decoder, postnorm_encoder=args.postnorm_encoder, postnorm_decoder=args.postnorm_decoder, use_moe=args.use_moe, num_experts=args.num_experts, expert_ffn_size=args.expert_ffn_size, prompt_tuning=args.prompt_tuning, prompt_dropout=args.prompt_dropout, prompt_projection_hidden_size=args.prompt_projection_hidden_size, prompt_init_std=args.prompt_init_std, layernorm_prompt_projection=args.layernorm_prompt_projection, no_projection_prompt=args.no_projection_prompt, use_tanh_activation_prompt=args.use_tanh_activation_prompt, residual_connection_prompt=args.residual_connection_prompt, num_prompts=args.num_prompts, recurrent_projections=args.recurrent_projections, adaptor_tuning=args.adaptor_tuning, deep_adaptor_tuning=args.deep_adaptor_tuning, deep_adaptor_tuning_ffn_only=args.deep_adaptor_tuning_ffn_only, adaptor_dropout=args.adaptor_dropout, adaptor_activation_function=args.adaptor_activation_function, parallel_adaptors = args.parallel_adaptors, layernorm_adaptor_input = args.layernorm_adaptor_input, adaptor_scaling_factor = args.adaptor_scaling_factor, residual_connection_adaptor = args.residual_connection_adaptor, encoder_adaptor_tying_config=args.encoder_adaptor_tying_config, decoder_adaptor_tying_config=args.decoder_adaptor_tying_config, adaptor_hidden_size=args.adaptor_hidden_size, moe_adaptors=args.moe_adaptors, num_moe_adaptor_experts=args.num_moe_adaptor_experts, hypercomplex=args.hypercomplex, hypercomplex_n=args.hypercomplex_n, ia3_adaptors=args.ia3_adaptors, lora_adaptors=args.lora_adaptors, lora_adaptor_rank=args.lora_adaptor_rank, softmax_bias_tuning=args.softmax_bias_tuning, sparsify_attention=args.sparsify_attention, sparsify_ffn=args.sparsify_ffn, num_sparsify_blocks=args.num_sparsify_blocks, 
sparsification_temperature=args.sparsification_temperature, tokenizer_class="AlbertTokenizer" if "albert" in args.tokenizer_name_or_path else "MBartTokenizer") ## Configuration built entirely from command-line arguments; it is saved to the deploy folder just below.
config.architectures = ["MBartForConditionalGeneration"]
config.save_pretrained(args.model_path+"_deploy") ## Save the config as a json file to ensure easy loading during future fine tuning of the model.
model = deferred_init.deferred_init(MBartForConditionalGeneration, config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration(config)
model.train()
if args.distillation: ## When distilling we need a parent model. The creation of the model is in the same way as the child. This model is immediately loaded with some pretrained params and then loaded into the GPU.
print("We will do distillation from a parent model.")
if args.use_official_parent_pretrained:
if "mbart" in args.parent_pretrained_model or "IndicBART" in args.pretrained_model:
parent_config = MBartConfig.from_pretrained(args.parent_pretrained_model)
parent_config.dropout = args.parent_dropout ## We should set dropouts manually
parent_config.attention_dropout = args.parent_attention_dropout ## We should set dropouts manually
parent_config.activation_dropout = args.parent_activation_dropout ## We should set dropouts manually
parent_config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_model = deferred_init.deferred_init(MBartForConditionalGeneration.from_pretrained, args.parent_pretrained_model, config=parent_config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration.from_pretrained(args.parent_pretrained_model, config=parent_config) ## We may use FBs official model and fine-tune it for our purposes.
elif "bart" in args.parent_pretrained_model:
parent_config = BartConfig.from_pretrained(args.parent_pretrained_model)
parent_config.dropout = args.parent_dropout ## We should set dropouts manually
parent_config.attention_dropout = args.parent_attention_dropout ## We should set dropouts manually
parent_config.activation_dropout = args.parent_activation_dropout ## We should set dropouts manually
parent_config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_model = deferred_init.deferred_init(BartForConditionalGeneration.from_pretrained, args.parent_pretrained_model, config=parent_config, force_bos_token_to_be_generated=True) if (args.use_fsdp and deferred_init is not None) else BartForConditionalGeneration.from_pretrained(args.parent_pretrained_model, config=parent_config, force_bos_token_to_be_generated=True) ## We may use FB's official model and fine-tune it for our purposes.
else: ## We are going to manually specify our own parent model config.
parent_config = MBartConfig(vocab_size=len(tok), encoder_layers=args.parent_encoder_layers, decoder_layers=args.parent_decoder_layers, dropout=args.parent_dropout, attention_dropout=args.parent_attention_dropout, activation_dropout=args.parent_activation_dropout, encoder_attention_heads=args.parent_encoder_attention_heads, decoder_attention_heads=args.parent_decoder_attention_heads, encoder_ffn_dim=args.parent_encoder_ffn_dim, decoder_ffn_dim=args.parent_decoder_ffn_dim, d_model=args.parent_d_model, no_embed_norm=args.no_embed_norm, scale_embedding=args.scale_embedding, pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], encoder_tying_config=args.encoder_tying_config, decoder_tying_config=args.decoder_tying_config, multilayer_softmaxing=args.multilayer_softmaxing, wait_k=args.wait_k, unidirectional_encoder=args.unidirectional_encoder, softmax_temperature=args.softmax_temperature, temperature_calibration=args.temperature_calibration, encoder_layerdrop=args.layerdrop, decoder_layerdrop=args.layerdrop, no_scale_attention_embedding=args.no_scale_attention_embedding, positional_encodings=args.positional_encodings, alibi_encoding=args.alibi_encoding, asymmetric_alibi_encoding=args.asymmetric_alibi_encoding, rope_encoding=args.rope_encoding, activation_function=args.activation_function, no_positional_encoding_encoder=args.no_positional_encoding_encoder, no_positional_encoding_decoder=args.no_positional_encoding_decoder, postnorm_encoder=args.postnorm_encoder, postnorm_decoder=args.postnorm_decoder, use_moe=args.use_moe, num_experts=args.num_experts, expert_ffn_size=args.expert_ffn_size)
parent_model = deferred_init.deferred_init(MBartForConditionalGeneration, parent_config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration(parent_config)
parent_model.train() ## We do this to enable dropout, but we won't have an optimizer for the parent, so we won't train it for now. Future implementations should ask whether to co-distill, i.e., let the parent learn together with the child.
if not args.use_fsdp:
parent_model.cuda(gpu) ## Move the model to the GPU.
print("Memory consumed after moving parent model to GPU", round(torch.cuda.memory_allocated(gpu)/(1024**3), 2), "GB")
else:
print("When FSDP is used, the parent model is not moved to the GPU. This is because FSDP does not support moving the model to the GPU. Instead, it moves the model to the CPU and then to the GPU. This is done to save memory. This is done in the FSDP wrapper itself.")
if args.use_fsdp:
parent_model = FSDP(parent_model, mixed_precision=mixed_precision_policy, device_id=torch.cuda.current_device(), auto_wrap_policy=mbart_auto_wrap_policy, sharding_strategy=sharding_strategy, backward_prefetch=backward_prefetch_policy, process_group=process_group, cpu_offload=cpu_offload) #, forward_prefetch=args.forward_prefetch
else:
parent_model = DistributedDataParallel(parent_model, device_ids=[gpu], output_device=gpu)
print("Loading a parent model from which distillation will be done.")
dist.barrier()
# configure map_location properly
if not args.use_official_parent_pretrained:
if args.use_fsdp:
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers.
reader = FileSystemReader(args.parent_pretrained_model + "_sharded")
with FSDP.state_dict_type(parent_model, StateDictType.LOCAL_STATE_DICT):
state_dict = parent_model.state_dict()
load_state_dict(state_dict, reader)
parent_model.load_state_dict(state_dict)
del state_dict
else:
reader = FileSystemReader(args.parent_pretrained_model + "_sharded")
with FSDP.state_dict_type(parent_model, StateDictType.LOCAL_STATE_DICT):
state_dict = parent_model.state_dict()
load_state_dict(state_dict, reader)
parent_model.load_state_dict(state_dict)
del state_dict
else:
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
parent_checkpoint_dict = torch.load(args.parent_pretrained_model, map_location=map_location)
if type(parent_checkpoint_dict) == dict:
parent_model.load_state_dict(parent_checkpoint_dict['model']) # We never do any remapping of the parent. We always reuse it as it is.
else:
parent_model.module.load_state_dict(parent_checkpoint_dict) # We never do any remapping of the parent. We always reuse it as it is.
del parent_checkpoint_dict
parent_model.train()
torch.cuda.empty_cache()
freeze_params(model, args.freeze_exception_list, rank)
### NOTE: Please freeze params before wrapping the model in DDP. Mandem almost had a stroke trying to figure this out.
if not args.use_fsdp:
model.cuda(gpu) ## Move the model to the GPU.
print("Memory consumed after moving model to GPU", round(torch.cuda.memory_allocated(gpu)/(1024**3), 2), "GB")
else:
print("When FSDP is used, the model is not moved to the GPU. This is because FSDP does not support moving the model to the GPU. Instead, it moves the model to the CPU and then to the GPU. This is done to save memory. This is done in the FSDP wrapper itself.")
print("Optimizing", [n for n, p in model.named_parameters() if p.requires_grad])
if args.gradient_checkpointing:
print("Using gradient checkpointing")
num_params_to_optimize = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_model_params = sum(p.numel() for p in model.parameters())
print("Number of model parameters:", num_model_params)
print("Total number of params to be optimized are: ", num_params_to_optimize)
print("Percentage of parameters to be optimized: ", 100*num_params_to_optimize/num_model_params)
if args.use_fsdp:
model = FSDP(model, mixed_precision=mixed_precision_policy, device_id=torch.cuda.current_device(), auto_wrap_policy=mbart_auto_wrap_policy, sharding_strategy=sharding_strategy, backward_prefetch=backward_prefetch_policy, process_group=process_group, cpu_offload=cpu_offload) ## This wrapper around the model will enable sharded distributed training. , forward_prefetch=args.forward_prefetch
if args.sharding_strategy == "HYBRID_SHARD":
print("Process groups are", torch.distributed.get_process_group_ranks(model.process_group), torch.distributed.get_process_group_ranks(model._inter_node_pg))
else:
model = DistributedDataParallel(model, device_ids=[gpu], output_device=gpu) ## This wrapper around the model will enable distributed training.
print("Memory consumed after wrapping with DDP/FSDP", round(torch.cuda.memory_allocated(gpu)/(1024**3), 2), "GB")
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
"weight_decay": 0.0,
},
] ## We suppose that weight decay will be used except for biases and layer norm weights.
if args.prompt_tuning:
print("Although the percentage of parameters to be optimized is high, during training the number of actual params during decoding are way way lower.")
if args.adam_8bit:
print("Using an 8-bit AdamW optimizer.")
optimizer = bnb.optim.AdamW8bit(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_eps, betas=(0.9, 0.995)) # Our glorious 8 bit optimizer. All hail our lord and savior Tim Dettmers.
else:
print("Using an 32-bit AdamW optimizer.")
if args.rms_adam:
print("Using RMSAdam optimizer.")
optimizer = AdamWScale(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_eps, betas=(0.9, 0.995))
else:
optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_eps) ## Our glorious optimizer.
model.train()
if args.lr_scheduler == "linear":
scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, args.num_batches) ## Linear warmup followed by linear decay.
elif args.lr_scheduler == "cosine":
scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, args.num_batches, num_cycles=args.cosine_scheduler_num_cycles) ## Linear warmup followed by cosine decay.
elif args.lr_scheduler == "cosine_with_restarts":
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, args.warmup_steps, args.num_batches, num_cycles=args.cosine_scheduler_num_cycles)
else:
raise ValueError("Invalid LR scheduler")
while scheduler.get_lr()[0] < 1e-7: ## We want to keep a minimum learning rate else for the initial batch or initial few batches barely anything will be learned which is a waste of computation. This minimum value is kept to 1e-7 by default in accordance with previous literature, other implementations and the Paris peace accords.
scheduler.step()
if rank == 0:
print("Initial LR is:", scheduler.get_lr()[0], ", max LR is:", args.lr, ", warmup steps are:", args.warmup_steps, ", total number of batches/steps are:", args.num_batches)
if args.pretrained_model != "" and (not args.use_official_pretrained or args.locally_fine_tuned_model_path is not None): ## Here we load a previous checkpoint in case training crashed. Note the args.locally_fine_tuned_model_path. This is in case we were tuning an official mbart or indicbart or bart model but want to further tine tune it or it crashed and we want to resume training it.
print("Loading from checkpoint. Strict loading by default but if there are missing or non matching keys or if we use prompt or adaptor tuning, they will be ignored when layer remapping or component selection is done. In case of prompt and adaptor tuning, new params are added to the model and hence strict matching of keys is not possible.")
dist.barrier()
sys.stdout.flush()
if args.locally_fine_tuned_model_path is not None: ## Now that the pretrained_model argument was used to instantiate the model, it can be replaced with the local model path. Remember to specify pure model or the model with the optimizer and scheduler states depending on your requirement by relying on the flag --no_reload_optimizer_ctr_and_scheduler.
args.pretrained_model = args.locally_fine_tuned_model_path
if args.use_fsdp: # With FSDP models I would rather not risk pruning or layer remapping. So I am not going to do it. I am going to load the model as it is. Consider doing this externally before loading the model.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers.
reader = FileSystemReader(args.pretrained_model + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)
del state_dict
else:
reader = FileSystemReader(args.pretrained_model + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict) # Check if strict loading is required here. We ideally dont want it to be so if we add prompts or adaptors.
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings: # This might fail for FSDP so please check. TODO.
model.module.initialize_prompt_params_with_random_embeddings()
if not args.no_reload_optimizer_ctr_and_scheduler:
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
reader = FileSystemReader(args.pretrained_model+ "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
optim_dict = load_sharded_optimizer_state_dict(
model_state_dict=model.state_dict(),
optimizer_key="optim",
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(model, optimizer, optim_dict["optim"])
optimizer.load_state_dict(flattened_osd)
del flattened_osd
del optim_dict
else:
full_optimizer = None
if rank == 0:
full_optimizer = torch.load(args.pretrained_model+ "_optim") ## We now load only the optimizer and scheduler.
sharded_optimizer = FSDP.scatter_full_optim_state_dict(full_optimizer, model)
optimizer.load_state_dict(sharded_optimizer)
scheduler_and_ctr = torch.load(args.pretrained_model + "_scheduler_and_ctr")
scheduler.load_state_dict(scheduler_and_ctr['scheduler'])
ctr = scheduler_and_ctr['ctr']
del scheduler_and_ctr
del sharded_optimizer
del full_optimizer
else:
ctr = 0
del state_dict
else:
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
checkpoint_dict = torch.load(args.pretrained_model, map_location=map_location)
if type(checkpoint_dict) == dict:
model.load_state_dict(remap_embeddings_eliminate_components_and_eliminate_mismatches(model.state_dict(), remap_layers(checkpoint_dict['model'], 4, args, rank), args), strict=True if (args.remap_encoder == "" and args.remap_decoder == "" and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization and not args.prompt_tuning and not args.adaptor_tuning and not args.deep_adaptor_tuning and not args.deep_adaptor_tuning_ffn_only and not args.ia3_adaptors and not args.lora_adaptors and not args.softmax_bias_tuning and not args.sparsify_attention) else False)
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings:
model.module.initialize_prompt_params_with_random_embeddings()
if not args.no_reload_optimizer_ctr_and_scheduler and args.remap_encoder == '' and args.remap_decoder == '' and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization: ## Do not load optimizers, ctr and schedulers when remapping or resuming training.
if 'optimizer' in checkpoint_dict:
print("Reloading optimizer")
optimizer.load_state_dict(checkpoint_dict['optimizer']) ## Dubious
if 'scheduler' in checkpoint_dict:
print("Reloading scheduler")
scheduler.load_state_dict(checkpoint_dict['scheduler']) ## Dubious
if 'ctr' in checkpoint_dict:
print("Reloading ctr. This means we resume training.")
ctr = checkpoint_dict['ctr']
else:
ctr = 0
else:
model.module.load_state_dict(remap_embeddings_eliminate_components_and_eliminate_mismatches(model.state_dict(), remap_layers(checkpoint_dict, 3, args, rank), args), strict=True if (args.remap_encoder == "" and args.remap_decoder == "" and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization and not args.prompt_tuning and not args.adaptor_tuning and not args.deep_adaptor_tuning and not args.deep_adaptor_tuning_ffn_only and not args.ia3_adaptors and not args.lora_adaptors and not args.softmax_bias_tuning and not args.sparsify_attention) else False)
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings:
model.module.initialize_prompt_params_with_random_embeddings()
ctr = 0
del checkpoint_dict
else:
if args.use_official_pretrained:
print("Training from official pretrained model")
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings:
model.module.initialize_prompt_params_with_random_embeddings()
else:
print("Training from scratch")
CHECKPOINT_PATH = args.model_path
if args.use_fsdp: # For FSDP we will save the model params, optimizer, scheduler and ctr in separate files. This is because FSDP saving everything in a single file is too heavy.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
model_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
optim_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
save_state_dict(state_dict, model_shard_writer)
optim_dict = {"optim": FSDP.optim_state_dict(model, optimizer)}
save_state_dict(optim_dict, optim_shard_writer)
del state_dict
del optim_dict
## Also save the full state dict for the model and optimizer.
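# Note: LOCAL_STATE_DICT above saves each rank's own shard through the distributed-checkpoint
# FileSystemWriter, while FULL_STATE_DICT below gathers the complete unsharded weights, offloaded to CPU
# and kept only on rank 0 as per fullstate_save_policy, so a regular torch.save checkpoint and the
# HuggingFace-style pytorch_model.bin copy can be produced.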
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy, optimizer_fullstate_save_policy):
state_dict = model.state_dict()
optim_dict = {"optim": FSDP.full_optim_state_dict(model, optimizer)}
if rank == 0:
torch.save(state_dict, CHECKPOINT_PATH)
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': 0}
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
save_state_dict(state_dict, shard_writer)
del state_dict
state_dict = None
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy): ## This full state dict is what is messing things up. The model should be saved as local state dicts and then assembled into a full state dict at the end if needed. A presharding and unsharding script may be useful. We have used an offload-to-CPU policy, so we hopefully won't run out of memory.
state_dict = model.state_dict()
optim_dict = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
torch.save(state_dict, CHECKPOINT_PATH)
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': 0}
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
if rank == 0:
checkpoint_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'ctr': 0}
torch.save(checkpoint_dict, CHECKPOINT_PATH) ## Save a model by default every save_every steps. This model will be saved with the same file name each time.
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".pure_model")
os.system("cp "+CHECKPOINT_PATH+".pure_model "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del checkpoint_dict
dist.barrier()
if args.use_fsdp: ## This is consuming CPU ram. Need an optimization here. We need to make a decision whether we are going to go for a full state dict or a local state dict. If we are going to go for a full state dict, we need to make sure that we are not going to run out of memory.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
reader = FileSystemReader(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)
del state_dict
reader = FileSystemReader(CHECKPOINT_PATH + "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
optim_dict = load_sharded_optimizer_state_dict(
model_state_dict=model.state_dict(),
optimizer_key="optim",
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(model, optimizer, optim_dict["optim"])
optimizer.load_state_dict(flattened_osd)
del flattened_osd
del optim_dict
scheduler_and_ctr = torch.load(CHECKPOINT_PATH + "_scheduler_and_ctr")
scheduler.load_state_dict(scheduler_and_ctr['scheduler'])
ctr = scheduler_and_ctr['ctr']
del scheduler_and_ctr
else:
reader = FileSystemReader(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)
full_optimizer = None
if rank == 0:
full_optimizer = torch.load(CHECKPOINT_PATH+ "_optim") ## We now load only the optimizer and scheduler.
sharded_optimizer = FSDP.scatter_full_optim_state_dict(full_optimizer, model)
optimizer.load_state_dict(sharded_optimizer)
scheduler_and_ctr = torch.load(CHECKPOINT_PATH + "_scheduler_and_ctr")
scheduler.load_state_dict(scheduler_and_ctr['scheduler'])
ctr = scheduler_and_ctr['ctr']
del scheduler_and_ctr
del sharded_optimizer
del full_optimizer
del state_dict
else:
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
checkpoint_dict = torch.load(CHECKPOINT_PATH, map_location=map_location)
model.load_state_dict(checkpoint_dict['model'])
optimizer.load_state_dict(checkpoint_dict['optimizer'])
scheduler.load_state_dict(checkpoint_dict['scheduler'])
ctr = checkpoint_dict['ctr']
del checkpoint_dict
torch.cuda.empty_cache()
dist.barrier()
model.train()
print("Using label smoothing of", args.label_smoothing)
print("Using gradient clipping norm of", args.max_gradient_clip_value)
print("Using softmax temperature of", args.softmax_temperature)
if args.max_ent_weight != -1:
print("Doing entropy maximization during loss computation.")
if args.multistep_optimizer_steps > 1:
print("Using a multistep optimizer where gradients will be accumulated over", args.multistep_optimizer_steps, "batches.")
if args.ewc_importance != 0: ## Set up elastic weight consolidation
print("Using Elastic Weight Consolidation with importance", args.ewc_importance)
print("Number of training batches to compute Fisher coefficients:", args.ewc_samples)
num_batches_tmp = args.num_batches
args.num_batches = args.ewc_samples
print("Learning Fisher coefficients.")
ewc_loss = EWC(model, generate_batches_monolingual_masked(tok, args, ewc_files, rank), gpu, args.label_smoothing, ignore_index=tok.pad_token_id)
args.num_batches = num_batches_tmp
print("Fisher coefficients learned.")
num_batches_this_optimizer_step = 0
losses = 0
batch_stats = torch.zeros(7, dtype=torch.long, device=gpu) # We want to keep track of batch statistics.
avg_memory_stats = torch.zeros(2, dtype=torch.float, device=gpu)
start = time.time()
for (input_ids, input_masks, decoder_input_ids, labels), is_bilingual in generate_batches_monolingual_masked_or_bilingual(tok, args, rank, files, train_files): #Batches are generated from here. The argument (0.30, 0.40) is a range which indicates the percentage of the source sentence to be masked in case we want masking during training just like we did during BART pretraining. The argument 3.5 is the lambda to the poisson length sampler which indicates the average length of a word sequence that will be masked. Since this is pretraining we do not do any evaluations even if we train on parallel corpora.
if num_batches_this_optimizer_step == 0: ## This is the first batch of this optimizer step.
optimizer.zero_grad(set_to_none=True) ## Empty the gradients before any computation.
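# Note (hedged): gradients are accumulated across args.multistep_optimizer_steps batches; they are only
# zeroed here at the start of an optimizer step, and optimizer.step() is presumably called further down
# once num_batches_this_optimizer_step reaches that count.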
if ctr % args.save_every == 0 and num_batches_this_optimizer_step == 0: ## We have to evaluate our model every save_every steps. Since there is no evaluation data during pretraining this means our model is saved every save_every steps.
CHECKPOINT_PATH = args.model_path
if args.use_fsdp: # For FSDP we will save the model params, optimizer, scheduler and ctr in separate files. This is because FSDP saving everything in a single file is too heavy.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
model_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
optim_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
save_state_dict(state_dict, model_shard_writer)
optim_dict = {"optim": FSDP.optim_state_dict(model, optimizer)}
save_state_dict(optim_dict, optim_shard_writer)
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
model_shard_writer = FileSystemWriter(CHECKPOINT_PATH +"."+str(ctr)+ "_sharded")
optim_shard_writer = FileSystemWriter(CHECKPOINT_PATH +"."+str(ctr)+ "_optim_sharded")
save_state_dict(state_dict, model_shard_writer)
save_state_dict(optim_dict, optim_shard_writer)
del state_dict
del optim_dict
## Also save the full state dict for the model and optimizer.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy, optimizer_fullstate_save_policy):
state_dict = model.state_dict()
optim_dict = {"optim": FSDP.full_optim_state_dict(model, optimizer)}
if rank == 0:
torch.save(state_dict, CHECKPOINT_PATH)
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': ctr}
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
torch.save(state_dict, CHECKPOINT_PATH +"."+str(ctr))
torch.save(optim_dict, CHECKPOINT_PATH +"."+str(ctr)+ "_optim")
torch.save(scheduler_and_ctr, CHECKPOINT_PATH +"."+str(ctr)+ "_scheduler_and_ctr")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
shard_writer = FileSystemWriter(CHECKPOINT_PATH +"."+str(ctr)+ "_sharded")
save_state_dict(state_dict, shard_writer)
shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
save_state_dict(state_dict, shard_writer)
del state_dict
state_dict = None
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy): ## This full state dict is what is messing things up. The model should be saved as local state dicts and then assembled into a full state dict at the end if needed. A presharding and unsharding script may be useful. We have used an offload-to-CPU policy, so we hopefully won't run out of memory.
state_dict = model.state_dict()
optim_dict = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
print("Saving the model")
torch.save(state_dict, CHECKPOINT_PATH)
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': ctr}
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
print("Saving an intermediate checkpoint")
torch.save(state_dict, CHECKPOINT_PATH+"."+str(ctr))
torch.save(optim_dict, CHECKPOINT_PATH + "."+str(ctr)+"_optim")
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "."+str(ctr)+"_scheduler_and_ctr")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
if rank == 0:
print("Saving the model")
checkpoint_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'ctr': ctr}
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
print("Saving an intermediate checkpoint")
torch.save(checkpoint_dict, CHECKPOINT_PATH+"."+str(ctr))
sys.stdout.flush()
torch.save(checkpoint_dict, CHECKPOINT_PATH) ## Save a model by default every save_every steps. This model will be saved with the same file name each time.
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".pure_model")
os.system("cp "+CHECKPOINT_PATH+".pure_model "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del checkpoint_dict
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# start = time.time() ## All eval and ckpt saving is done here so start counting from here.
if args.num_domains_for_domain_classifier > 1: ## The label will contain the label as well as the domain indicator
domain_classifier_labels=labels[1] ## This is not a tensor yet
# print(domain_classifier_labels)
domain_classifier_labels = torch.tensor(domain_classifier_labels, dtype=torch.int64).to(gpu) ## Move to gpu
labels=labels[0]
label_mask = labels.eq(tok.pad_token_id).unsqueeze(-1).to(gpu)
input_ids=input_ids.to(gpu) ## Move to gpu
input_masks=input_masks.to(gpu) ## Move to gpu
decoder_input_ids=decoder_input_ids.to(gpu) ## Move to gpu
labels=labels.to(gpu) ## Move to gpu
if args.mixed_wait_k:
model.module.config.wait_k = random.randint(1, args.wait_k)
if args.prompt_tuning:
input_shape = input_masks.size()
encoder_pad = torch.ones(input_shape[0], args.num_prompts, dtype=input_masks.dtype, device=input_masks.device) ## Create the prompt attention-mask entries with the same dtype and device as the input mask so the concatenation below does not fail.
input_masks = torch.cat([encoder_pad, input_masks], dim=1)
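## Shape sketch for prompt tuning (assuming batch size B, source length S and args.num_prompts P):
## input_masks goes from (B, S) to (B, P + S), i.e. P extra positions of 1s are prepended so that the
## learned prompt vectors (however prompt_tuning is realized inside the encoder) are always attended to.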
with torch.cuda.amp.autocast() if (args.fp16 and not args.use_fsdp) else nullcontext(): ## The difference between AMP and FP32 is the use of autocast. It is unclear whether autocast should be used with FSDP; some people do, some don't, and according to a PyTorch issue it shouldn't matter. https://github.com/pytorch/pytorch/issues/76607#issuecomment-1370053227
if is_bilingual and args.unify_encoder:
source_hidden_state_encoder = model.module.get_encoder()(input_ids=input_ids, attention_mask=input_masks).last_hidden_state ## Run the encoder for source sentence.
decoder_input_masks = (decoder_input_ids != tok.pad_token_id).int().to(gpu)
target_hidden_state_encoder = model.module.get_encoder()(input_ids=decoder_input_ids, attention_mask=decoder_input_masks).last_hidden_state ## Run the encoder for the target sentence.
del decoder_input_masks ## Delete to avoid retention.
pad_mask = input_ids.eq(tok.pad_token_id).unsqueeze(2)
source_hidden_state_encoder.masked_fill_(pad_mask, 0.0)
source_hidden_state_encoder = source_hidden_state_encoder.mean(dim=1)
pad_mask = decoder_input_ids.eq(tok.pad_token_id).unsqueeze(2)
target_hidden_state_encoder.masked_fill_(pad_mask, 0.0)
target_hidden_state_encoder = target_hidden_state_encoder.mean(dim=1)
loss = -cosine_similarity(source_hidden_state_encoder, target_hidden_state_encoder)
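## Encoder unification sketch: both sides are passed through the same encoder, padding positions are
## zeroed out, the hidden states are mean-pooled over the sequence dimension, and the loss is the
## negative cosine similarity of the two pooled representations, so minimizing it pulls the source and
## target sentence representations together.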
if rank == 0:
writer.add_scalar("encoder unification loss", loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"encoder unification loss": loss.detach().cpu().numpy()}, step=ctr)
else:
mod_compute = model(input_ids=input_ids, attention_mask=input_masks, decoder_input_ids=decoder_input_ids, output_hidden_states=args.distillation, output_attentions=args.distillation, label_mask=label_mask if args.num_domains_for_domain_classifier > 1 else None) ## Run the model and get logits.
logits = mod_compute.logits
lprobs = torch.nn.functional.log_softmax(logits, dim=-1) ## Softmax tempering of logits if needed.
loss = label_smoothed_nll_loss(
lprobs, labels, args.label_smoothing, ignore_index=tok.pad_token_id
) ## Label smoothed cross entropy loss.
loss = loss*args.softmax_temperature ## Upscale the loss in the case of non-unitary temperatures. Note that when self-calibrating temperature is used, the softmax temperature must be set to 1.
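## Label smoothing sketch: with smoothing factor eps, the per-token loss is roughly
##   (1 - eps) * (-log P(gold token)) + eps * (mean over the vocabulary of -log P(token)),
## with padding positions (tok.pad_token_id) excluded. The exact reduction is whatever
## label_smoothed_nll_loss in this repository implements.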
if rank == 0:
writer.add_scalar("pure cross entropy loss", loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"pure cross entropy loss": loss.detach().cpu().numpy()}, step=ctr)
if args.ewc_importance != 0: ## Update the model with the EWC loss.
ewc_loss_current = args.ewc_importance * ewc_loss.penalty(model)
if rank == 0:
writer.add_scalar("EWC loss", ewc_loss_current.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"EWC loss": ewc_loss_current.detach().cpu().numpy()}, step=ctr)
loss = loss + ewc_loss_current
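## EWC sketch: the penalty is typically args.ewc_importance * sum_i F_i * (theta_i - theta*_i)^2,
## where F_i is the Fisher information of parameter i and theta*_i its value after the previous task;
## the exact form (e.g. whether a 1/2 factor is included) depends on the EWC implementation used here.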
if args.temperature_calibration:
loss = loss*mod_compute.softmax_temperature
if rank == 0:
writer.add_scalar("calibrated temperature", mod_compute.softmax_temperature.detach().cpu().numpy(), ctr)
writer.add_scalar("calibrated temperature loss", loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"calibrated temperature": mod_compute.softmax_temperature.detach().cpu().numpy()}, step=ctr)
wandb.log({"calibrated temperature loss": loss.detach().cpu().numpy()}, step=ctr)
if args.num_domains_for_domain_classifier > 1: ## We augment the main loss with the domain classifier loss
domain_classifier_logits = mod_compute.domain_classifier_logits
domain_classifier_lprobs = torch.nn.functional.log_softmax(domain_classifier_logits, dim=-1) ## Softmax tempering of logits if needed.
domain_classifier_loss = label_smoothed_nll_loss(
domain_classifier_lprobs.view(-1,args.num_domains_for_domain_classifier), domain_classifier_labels.view(-1,1), args.label_smoothing
) ## Label smoothed cross entropy loss. We are not going to do any temperature related stuff to this.
loss = domain_classifier_loss*args.domain_classifier_loss_weight + loss * (1.0-args.domain_classifier_loss_weight)
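## The translation loss and the domain classification loss are mixed as a convex combination:
## loss = w * L_domain + (1 - w) * L_translation, with w = args.domain_classifier_loss_weight.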
if rank == 0:
writer.add_scalar("domain classifier loss", domain_classifier_loss.detach().cpu().numpy(), ctr)
writer.add_scalar("loss with domain classifier loss", loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"domain classifier loss": domain_classifier_loss.detach().cpu().numpy()}, step=ctr)
wandb.log({"loss with domain classifier loss": loss.detach().cpu().numpy()}, step=ctr)
## We will do multilayer softmaxing without any consideration for distillation or domain classification.
if mod_compute.additional_lm_logits is not None:
for additional_logits in mod_compute.additional_lm_logits:
lprobs = torch.nn.functional.log_softmax(additional_logits, dim=-1) ## Softmax tempering of logits if needed.
loss_extra = label_smoothed_nll_loss(
lprobs, labels, args.label_smoothing, ignore_index=tok.pad_token_id
) ## Label smoothed cross entropy loss.
loss_extra = loss_extra*args.softmax_temperature ## Upscale the loss in the case of non-unitary temperatures. Note that when self-calibrating temperature is used, the softmax temperature must be set to 1. TODO: Perhaps log this too.
if args.temperature_calibration:
loss_extra = loss_extra*mod_compute.softmax_temperature
loss += loss_extra ## Add the additional layer's loss to the main loss. TODO: Perhaps log this too.
if args.max_ent_weight != -1: ## This deals with softmax entropy maximization. We compute the softmax entropy of the predictions via -sum(P(Y|X)*log(P(Y|X))) and add it to the cross-entropy loss with a negative sign, since we wish to maximize entropy. This should penalize overconfident predictions.
assert (args.max_ent_weight >= 0 and args.max_ent_weight <= 1)
logits = logits*args.softmax_temperature ## We have to undo the logit tempering, otherwise our entropy estimate will be wrong.
if args.temperature_calibration:
logits = logits*mod_compute.softmax_temperature
lprobs = torch.nn.functional.log_softmax(logits, dim=-1) ## No tempering here
entropy = -(torch.exp(lprobs)*lprobs).mean()
if rank == 0:
writer.add_scalar("softmax entropy", entropy.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"softmax entropy": entropy.detach().cpu().numpy()}, step=ctr)
if mod_compute.additional_lm_logits is not None:
for additional_logits in mod_compute.additional_lm_logits: ## Compute entropy for each layer as well
additional_logits = additional_logits*args.softmax_temperature ## We have to undo the logit tempering, otherwise our entropy estimate will be wrong.
if args.temperature_calibration:
additional_logits = additional_logits*mod_compute.softmax_temperature
lprobs = torch.nn.functional.log_softmax(additional_logits, dim=-1) ## No tempering here
entropy_extra = -(torch.exp(lprobs)*lprobs).mean()
entropy += entropy_extra
loss = loss*(1-args.max_ent_weight) - entropy*args.max_ent_weight ## Maximize the entropy so a minus is needed. Weigh and add losses as required.
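## Entropy sketch: per position the entropy is H(p) = -sum_y p(y) log p(y); the .mean() above averages
## -p*log p over all positions and vocabulary entries (padding included), so it is a scaled estimate of
## the average entropy rather than the exact per-token value. The combined objective is
## (1 - w) * loss - w * entropy, with w = args.max_ent_weight.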
if rank == 0:
writer.add_scalar("loss with entropy loss", loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"loss with entropy loss": loss.detach().cpu().numpy()}, step=ctr)
if args.distillation: ## Time to distill.
with torch.no_grad(): ## No gradient to avoid memory allocation.
parent_mod_compute = parent_model(input_ids=input_ids, attention_mask=input_masks ,decoder_input_ids=decoder_input_ids, output_hidden_states=args.distillation, output_attentions=args.distillation)
distillation_loss = compute_distillation_losses(mod_compute, parent_mod_compute, labels, tok.pad_token_id, args) ## Get the parent model's computations.
loss = args.distillation_loss_weight*distillation_loss + (1.0 - args.distillation_loss_weight)*loss ## Update the main loss with weighing and adding.
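## Distillation sketch: the parent (teacher) forward pass runs without gradients, and
## compute_distillation_losses presumably compares its logits (and, since hidden states and attentions
## are requested, possibly those as well) against the student's; the result is then mixed with the main
## loss as a convex combination weighted by args.distillation_loss_weight.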
if rank == 0:
writer.add_scalar("distillation loss", distillation_loss.detach().cpu().numpy(), ctr)
writer.add_scalar("final loss", loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"distillation loss": distillation_loss.detach().cpu().numpy()}, step=ctr)
wandb.log({"final loss": loss.detach().cpu().numpy()}, step=ctr)
if args.use_moe or args.moe_adaptors: ## add MOE losses too.
moe_loss = torch.sum(torch.stack(mod_compute.encoder_moe_losses)) + torch.sum(torch.stack(mod_compute.decoder_moe_losses))
if rank == 0:
writer.add_scalar("moe loss", moe_loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"moe loss": moe_loss.detach().cpu().numpy()}, step=ctr)
loss += moe_loss
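## MoE sketch: each mixture-of-experts (or MoE adaptor) layer returns its own auxiliary loss, typically
## a load-balancing/router term depending on the implementation; the per-layer losses from the encoder
## and decoder are summed and added to the main loss unweighted.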
if args.sparsify_attention or args.sparsify_ffn: ## add sparsification losses too.
sparsification_loss = torch.sum(torch.stack(mod_compute.encoder_sparsification_l0_losses)) + torch.sum(torch.stack(mod_compute.decoder_sparsification_l0_losses))
if rank == 0:
writer.add_scalar("sparsification loss", sparsification_loss.detach().cpu().numpy(), ctr)
if args.wb:
wandb.log({"sparsification loss": sparsification_loss.detach().cpu().numpy()}, step=ctr)
loss += sparsification_loss * args.sparsification_lambda
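## Sparsification sketch: each layer contributes an L0-style penalty encouraging attention heads or FFN
## units to be gated off; the per-layer penalties are summed and added to the main loss scaled by
## args.sparsification_lambda. The exact relaxation (e.g. hard-concrete gates) depends on the
## sparsification implementation in this repository.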
if args.contrastive_decoder_training: ## Shuffle the decoder input and label batches and compute loss. This should be negated and added to the overall loss.
batch_size = decoder_input_ids.size()[0]
shuffle_indices = torch.randperm(batch_size)
decoder_input_ids = decoder_input_ids[shuffle_indices]
labels = labels[shuffle_indices]
mod_compute = model(input_ids=input_ids, attention_mask=input_masks, decoder_input_ids=decoder_input_ids) ## Run the model and get logits.
logits = mod_compute.logits
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
contrastive_loss = label_smoothed_nll_loss(
lprobs, labels, args.label_smoothing, ignore_index=tok.pad_token_id
) ## Label smoothed cross entropy loss.
loss -= contrastive_loss
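## Contrastive decoder training sketch: shuffling decoder inputs and labels within the batch creates
## mismatched source-target pairs; subtracting their label-smoothed loss pushes the model to assign
## lower probability to targets that do not belong to the given source.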
fwd_memory = torch.cuda.memory_allocated(gpu)/(1024**3)
padding_tokens_dec = (decoder_input_ids == tok.pad_token_id).sum().item()
padding_tokens_enc = (input_ids == tok.pad_token_id).sum().item()
total_tokens_dec = decoder_input_ids.numel()
total_tokens_enc = input_ids.numel()
non_padding_tokens_dec = total_tokens_dec - padding_tokens_dec
non_padding_tokens_enc = total_tokens_enc - padding_tokens_enc
num_sequences = input_ids.size()[0]
batch_stats += torch.tensor([non_padding_tokens_enc, padding_tokens_enc, total_tokens_enc, non_padding_tokens_dec, padding_tokens_dec, total_tokens_dec, num_sequences], dtype=torch.long, device=gpu)
## Optimization part of the model from this point forward.
if args.fp16: ## With FP16/AMP computation the gradient scaler must be used: scale the loss, unscale the gradients before clipping, then step the optimizer and update the scaler.
loss = loss/args.multistep_optimizer_steps
scaler.scale(loss).backward()
bwd_memory = torch.cuda.memory_allocated(gpu)/(1024**3)
avg_memory_stats += torch.tensor([fwd_memory, bwd_memory], dtype=torch.float, device=gpu)
num_batches_this_optimizer_step += 1
losses += loss.detach().cpu().numpy()
if num_batches_this_optimizer_step < args.multistep_optimizer_steps:
continue
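## Gradient accumulation sketch: the loss is divided by args.multistep_optimizer_steps and gradients
## are accumulated across that many micro-batches (the `continue` above skips the optimizer step until
## enough have been seen), so the effective batch size is the per-step batch size multiplied by
## args.multistep_optimizer_steps.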
if args.max_gradient_clip_value != 0.0:
scaler.unscale_(optimizer)
if args.use_fsdp:
model.clip_grad_norm_(args.max_gradient_clip_value)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_clip_value)
scaler.step(optimizer)
scaler.update()
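## GradScaler sketch: scaler.step skips the optimizer step if the unscaled gradients contain infs/NaNs,
## and scaler.update then lowers the loss scale; after a run of successful steps the scale is raised
## again. That is why a change in the scale value is worth printing below.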
current_scale_value = scaler.get_scale()
# If the scale value changed then print it.
if current_scale_value != scale_value:
print("Gradient scale value changed from {} to {}".format(scale_value, current_scale_value))
scale_value = current_scale_value
else: ## With FP32, we just do regular backpropagation, gradient clipping and then step the optimizer.
loss = loss/args.multistep_optimizer_steps
loss.backward()
bwd_memory = torch.cuda.memory_allocated(gpu)/(1024**3)
avg_memory_stats += torch.tensor([fwd_memory, bwd_memory], dtype=torch.float, device=gpu)
num_batches_this_optimizer_step += 1
losses += loss.detach().cpu().numpy()
if num_batches_this_optimizer_step < args.multistep_optimizer_steps:
continue
if args.max_gradient_clip_value != 0.0:
if args.use_fsdp:
model.clip_grad_norm_(args.max_gradient_clip_value)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_clip_value)
optimizer.step()
scheduler.step() ## Advance the scheduler to get to the next value of LR.
lv = losses ## Snapshot the accumulated (already detached) loss so it can be reported.
losses = 0
num_batches_this_optimizer_step = 0
if ctr % 100 == 0: ## Print the current loss every 100 batches but only for the master/prime process.
# All reduce the batch stats.
end = time.time()
torch.distributed.all_reduce(batch_stats)
torch.distributed.all_reduce(avg_memory_stats)
avg_memory_stats = avg_memory_stats/args.world_size/100/args.multistep_optimizer_steps
fwd_memory, bwd_memory = avg_memory_stats.tolist()
# Round the memory stats to 2 decimal places.
fwd_memory = round(fwd_memory, 2)
bwd_memory = round(bwd_memory, 2)
non_padding_tokens_enc, padding_tokens_enc, total_tokens_enc, non_padding_tokens_dec, padding_tokens_dec, total_tokens_dec, num_sequences = batch_stats.tolist()
non_padding_percent_enc = round(non_padding_tokens_enc/total_tokens_enc*100, 2)
non_padding_percent_dec = round(non_padding_tokens_dec/total_tokens_dec*100, 2)
if rank % args.world_size == 0:
print("Step:", ctr, "| Loss:", round(lv.item(),2), " | Time:", round(end-start, 2), "s/100 batches. | Fwd-Bwd (avg):", fwd_memory, ",", bwd_memory, "GB. | Enc Non-pad, Pad, Total tokens, Non pad percentage:", non_padding_tokens_enc, ",", padding_tokens_enc, ",", total_tokens_enc, ",", non_padding_percent_enc, "% | Dec Non-pad, Pad, Total tokens, Non pad percentage:", non_padding_tokens_dec, ",", padding_tokens_dec, ",", total_tokens_dec, ",", non_padding_percent_dec, "% | Num sequences:", num_sequences)
sys.stdout.flush()
batch_stats = torch.zeros(7, dtype=torch.long, device=gpu)
avg_memory_stats = torch.zeros(2, dtype=torch.float, device=gpu)
start = time.time()
del input_ids ## Delete to avoid retention.
del input_masks ## Delete to avoid retention.
del decoder_input_ids ## Delete to avoid retention.
del labels ## Delete to avoid retention.
if args.num_domains_for_domain_classifier > 1:
del domain_classifier_labels
del label_mask
if ctr % 1000 == 0 and rank == 0 and args.save_weights_and_gradeint_info: ## Save the model weight and gradient info every time this condition is triggered.
for param_name, param_value in model.named_parameters():