Commit: working ewc and mas

kaushal-py committed May 22, 2024
commit 5bf5e6d, 1 parent: 761165a

Showing 2 changed files with 77 additions and 28 deletions.
@@ -92,9 +92,9 @@ def compute_ewc_params(model, dataloader,device):
         model.zero_grad()
         nb += 1
         # loss.detach().cpu()
-        # if e == 5:
-        #     del cuda_batch, model, loss
-        #     break
+        if e == 5:
+            del cuda_batch, model, loss
+            break
 
     for name in fisher:
         fisher[name] /= nb
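For context (the body of compute_ewc_params is mostly collapsed above), a minimal sketch of the kind of diagonal-Fisher loop it appears to run. The batch handling and the Lightning-style training_step interface are assumptions, not code from this repo; the now-active if e == 5: break caps the estimate at six batches:

import torch

def fisher_diagonal_sketch(model, dataloader, device, max_batches=5):
    """Hypothetical reconstruction, not the repo's compute_ewc_params:
    accumulate squared loss gradients over a few batches, then average."""
    model = model.to(device)
    fisher = {name: torch.zeros_like(p)
              for name, p in model.named_parameters() if p.requires_grad}
    params = {name: p.detach().clone()
              for name, p in model.named_parameters() if p.requires_grad}
    nb = 0
    for e, batch in enumerate(dataloader):
        cuda_batch = [x.to(device) for x in batch]          # assumed batch layout
        loss = model.training_step(cuda_batch, e)['loss']   # assumed interface
        model.zero_grad()
        loss.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
        model.zero_grad()
        nb += 1
        if e == max_batches:   # the early exit this commit enables
            break
    for name in fisher:
        fisher[name] /= nb
    return params, fisher

Truncating at six batches makes the Fisher estimate noisier but bounds the extra pass over the previous episode's data.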
@@ -124,9 +124,9 @@ def compute_mas_params(model, dataloader,device):
         model.zero_grad()
         nb += 1
         # loss.detach().cpu()
-        # if e == 5:
-        #     del cuda_batch, model, loss
-        #     break
+        if e == 5:
+            del cuda_batch, model, loss
+            break
 
     for name in importance:
         importance[name] /= nb
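MAS differs from EWC in what it accumulates: rather than squared gradients of the task loss, it takes the absolute gradient of the squared L2 norm of the model's output, so no labels are needed. A schematic version, assuming a plain model(inputs) forward (the real hybrid RNNT/CTC forward takes audio signals and lengths, so this is a sketch, not the repo's compute_mas_params):

import torch

def mas_importance_sketch(model, dataloader, device, max_batches=5):
    """Hypothetical reconstruction: accumulate the absolute gradient of
    ||f(x)||^2 with respect to each parameter, then average."""
    model = model.to(device)
    importance = {name: torch.zeros_like(p)
                  for name, p in model.named_parameters() if p.requires_grad}
    nb = 0
    for e, batch in enumerate(dataloader):
        inputs = batch[0].to(device)   # assumed batch layout
        outputs = model(inputs)        # assumed tensor-valued forward
        norm = outputs.pow(2).sum()    # squared L2 norm of the outputs
        model.zero_grad()
        norm.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                importance[name] += p.grad.detach().abs()
        model.zero_grad()
        nb += 1
        if e == max_batches:           # the early exit this commit enables
            break
    for name in importance:
        importance[name] /= nb
    return importance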
@@ -159,7 +159,7 @@ def main(cfg):
             # load the previous checkpoint with the old dataloader
             prev_cfg = OmegaConf.load(cl_config.ewc_params.old_config)
             trainer = pl.Trainer(**cfg.trainer)
-            # prev_cfg.model.train_ds.batch_size = 32
+            prev_cfg.model.train_ds.batch_size = 16
 
             ## the model owns its dataset, so constructing it here loads all the data of the previous episode
             asr_model_old = EncDecHybridRNNTCTCBPEModel(cfg=prev_cfg.model, trainer=trainer)
@@ -171,6 +171,27 @@ def main(cfg):

             ## computing the fisher and storing the old params
             params, fisher = compute_ewc_params(asr_model_old,asr_model_old._train_dl,device)
 
+            ## load the old params and blend them in (if applicable)
+            old_param_path = f"{os.path.abspath(os.path.join(log_dir,'..'))}/{os.path.split(cl_config.ewc_params.old_config)[-1].split('.')[0]}/ewc.pkl"
+            if not ('ep0' in old_param_path and 'full_finetune' in old_param_path):
+                assert os.path.exists(old_param_path), 'Old param path is required'
+
+                with open(old_param_path,'rb') as reader:
+                    saved = pickle.load(reader)
+                    old_fisher = saved['fisher']
+
+                # blend the new fisher with the previous episode's
+                for name in fisher:
+                    if name in old_fisher:
+                        old_importance = old_fisher[name]
+                        fisher[name] *= 1 - cl_config.ewc_params.alpha
+                        fisher[name] += (
+                            cl_config.ewc_params.alpha * old_importance
+                        )
+
             with open(f'{log_dir}/ewc.pkl','wb') as writer:
                 pickle.dump({'params':params,'fisher':fisher},writer)
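The scaling loop above is an exponential moving average across episodes: with blending weight alpha, the stored importance becomes (1 - alpha) * F_new + alpha * F_old. In isolation, assuming importances are dicts mapping parameter names to tensors:

def blend_importance(new, old, alpha):
    """Sketch of the scaling loop above: EMA of importance dicts,
    where alpha weights the previous episode's values."""
    for name in new:
        if name in old:
            new[name] = (1 - alpha) * new[name] + alpha * old[name]
    return new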

@@ -184,7 +205,7 @@ def main(cfg):
             with open(f'{log_dir}/ewc.pkl','rb') as reader:
                 saved = pickle.load(reader)
                 params, fisher = saved['params'],saved['fisher']
 
         cl_params = {
             'alpha': cl_config.ewc_params.alpha,
             'lamda': cl_config.ewc_params.lda,
@@ -193,8 +214,12 @@ def main(cfg):
         }
 
         trainer = pl.Trainer(**cfg.trainer)
-        asr_model = EncDecHybridRNNTCTCBPEModelEWC(cfg=cfg.model, trainer=trainer, cl_params=cl_params)
+        asr_model = EncDecHybridRNNTCTCBPEModelEWC(cfg=cfg.model, trainer=trainer)
         asr_model.maybe_init_from_pretrained_checkpoint(cfg)
 
+        # setting cl params
+        asr_model.set_cl_params(cl_params)
+
         trainer.fit(asr_model)
 
     elif cl_config.name == 'MAS':
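A note on the construction order this commit introduces above: the model is built from the config, restored from the pretrained checkpoint, and only then handed its continual-learning state through set_cl_params. Moving cl_params out of the constructor (see the hunks in hybrid_rnnt_ctc_bpe_models.py below) presumably keeps the class constructible from a config and trainer alone, which checkpoint restoration paths need; the earlier assert cl_params is not None in __init__ would have fired on any such restore.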
@@ -204,7 +229,7 @@ def main(cfg):
             # load the previous checkpoint with the old dataloader
             prev_cfg = OmegaConf.load(cl_config.mas_params.old_config)
             trainer = pl.Trainer(**cfg.trainer)
-            # prev_cfg.model.train_ds.batch_size = 32
+            prev_cfg.model.train_ds.batch_size = 16
 
             ## the model owns its dataset, so constructing it here loads all the data of the previous episode
             asr_model_old = EncDecHybridRNNTCTCBPEModel(cfg=prev_cfg.model, trainer=trainer)
@@ -214,32 +239,56 @@ def main(cfg):
             # asr_model.setup_optimization()
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-            ## computing the importance and storing the old params
-            params, importance = compute_mas_params(asr_model_old,asr_model_old._train_dl,device)
+            ## computing the mas_importance and storing the old params
+            params, mas_importance = compute_mas_params(asr_model_old,asr_model_old._train_dl,device)
 
+            ## load the old params and blend them in (if applicable)
+            old_param_path = f"{os.path.abspath(os.path.join(log_dir,'..'))}/{os.path.split(cl_config.mas_params.old_config)[-1].split('.')[0]}/mas.pkl"
+            if not ('ep0' in old_param_path and 'full_finetune' in old_param_path):
+                assert os.path.exists(old_param_path), 'Old param path is required'
+
+                with open(old_param_path,'rb') as reader:
+                    saved = pickle.load(reader)
+                    old_mas_importance = saved['mas_importance']
+
+                # blend the new importance with the previous episode's
+                for name in mas_importance:
+                    if name in old_mas_importance:
+                        old_importance = old_mas_importance[name]
+                        mas_importance[name] *= 1 - cl_config.mas_params.alpha
+                        mas_importance[name] += (
+                            cl_config.mas_params.alpha * old_importance
+                        )
+
             with open(f'{log_dir}/mas.pkl','wb') as writer:
-                pickle.dump({'params':params,'importance':importance},writer)
+                pickle.dump({'params':params,'mas_importance':mas_importance},writer)
 
             del asr_model_old, trainer
 
             gc.collect()
             torch.cuda.empty_cache()
 
         else:
-            ## load the param and importance
+            ## load the param and mas_importance
             with open(f'{log_dir}/mas.pkl','rb') as reader:
                 saved = pickle.load(reader)
-                params, importance = saved['params'],saved['importance']
-
+                params, mas_importance = saved['params'],saved['mas_importance']
         cl_params = {
-            # 'alpha': cl_config.mas_params.alpha,
+            'alpha': cl_config.mas_params.alpha,
             'lamda': cl_config.mas_params.lda,
             'params': params,
-            'importance': importance
+            'mas_importance': mas_importance
         }
 
         trainer = pl.Trainer(**cfg.trainer)
-        asr_model = EncDecHybridRNNTCTCBPEModelMAS(cfg=cfg.model, trainer=trainer, cl_params=cl_params)
+        asr_model = EncDecHybridRNNTCTCBPEModelMAS(cfg=cfg.model, trainer=trainer)
         asr_model.maybe_init_from_pretrained_checkpoint(cfg)
 
+        # setting cl params
+        asr_model.set_cl_params(cl_params)
+
         trainer.fit(asr_model)
     elif cl_config.name == 'LWF':
         pass
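Taken together, the script's per-episode flow for both methods is: build the previous episode's model together with its dataloader, estimate parameter importance on a few batches, blend it with the importance pickled by the episode before that, write the params and importance into this episode's log dir, and only then construct and fit the new model with the penalty attached via set_cl_params.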
nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py (18 changes: 9 additions & 9 deletions)
@@ -655,12 +655,12 @@ def list_available_models(cls) -> List[PretrainedModelInfo]:
         return results
 
 class EncDecHybridRNNTCTCBPEModelEWC(EncDecHybridRNNTCTCBPEModel):
-    def __init__(self, cfg: DictConfig, trainer: Trainer = None, cl_params=None):
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         cfg = model_utils.convert_model_config_to_dict_config(cfg)
         cfg = model_utils.maybe_update_config_version(cfg)
         super().__init__(cfg=cfg, trainer=trainer)
 
-        assert cl_params is not None,'Parameters are null'
+    def set_cl_params(self,cl_params):
         self.lda = cl_params['lamda']
         self.alpha = cl_params['alpha']
         self.fisher = cl_params['fisher']
@@ -791,7 +791,7 @@ def training_step(self, batch, batch_nb):

         # EWC related changes
         for name, param in self.named_parameters():
-            if not param.requires_grad or param.grad is None:
+            if not param.requires_grad or param.grad is None or name not in self.fisher:
                 continue
             loss_value += (
                 (
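The penalty expression is cut off in this hunk. The standard EWC term this loop computes adds, per parameter, (lamda / 2) * fisher * (param - old_param)^2 to the training loss. A standalone sketch under the attribute names set in set_cl_params, with self.old_params assumed for the saved weights (the collapsed lines presumably hold the real expression):

def ewc_penalty(model, fisher, old_params, lda):
    """Quadratic EWC penalty (hypothetical reconstruction of the
    truncated expression; skips params without a stored fisher entry)."""
    penalty = 0.0
    for name, param in model.named_parameters():
        if not param.requires_grad or name not in fisher:
            continue
        penalty = penalty + (fisher[name] * (param - old_params[name]) ** 2).sum()
    return (lda / 2.0) * penalty

# inside training_step:
# loss_value = loss_value + ewc_penalty(self, self.fisher, self.old_params, self.lda)

The added "name not in self.fisher" guard skips any parameter that has no stored importance, for example one created after the Fisher pass.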
@@ -820,15 +820,15 @@ def training_step(self, batch, batch_nb):
         return {'loss': loss_value}
 
 class EncDecHybridRNNTCTCBPEModelMAS(EncDecHybridRNNTCTCBPEModel):
-    def __init__(self, cfg: DictConfig, trainer: Trainer = None, cl_params=None):
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         cfg = model_utils.convert_model_config_to_dict_config(cfg)
         cfg = model_utils.maybe_update_config_version(cfg)
         super().__init__(cfg=cfg, trainer=trainer)
 
-        assert cl_params is not None,'Parameters are null'
+    def set_cl_params(self,cl_params):
         self.lda = cl_params['lamda']
         # self.alpha = cl_params['alpha']
-        self.importance = cl_params['importance']
+        self.importance = cl_params['mas_importance']
         self.old_params = cl_params['params']
 
     def training_step(self, batch, batch_nb):
@@ -956,7 +956,7 @@ def training_step(self, batch, batch_nb):

         # MAS related changes
         for name, param in self.named_parameters():
-            if not param.requires_grad or param.grad is None:
+            if not param.requires_grad or param.grad is None or name not in self.importance:
                 continue
             loss_value += (
                 (
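The MAS expression is truncated the same way. MAS uses the identical quadratic form with self.importance in place of the Fisher weights, so the ewc_penalty sketch above applies with importance substituted for fisher.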

0 comments on commit 5bf5e6d

Please sign in to comment.