Estim first token instead of average (#46)
* estim encoder based on first token
vince62s authored Jul 2, 2024
1 parent 0b70de6 commit b770072
Showing 3 changed files with 22 additions and 16 deletions.
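For context, the commit changes how the sentence-level estimator pools the encoder output: instead of averaging the encoder states over non-padded positions, it now feeds only the first token's state to the estimator head (COMET-style pooling). Below is a minimal standalone sketch of the two pooling strategies with toy tensors and a plain torch.nn.Linear standing in for the estimator head; the names estimator, pad_id and the shapes are illustrative assumptions, not the eole API.

import torch

# Toy shapes only; in eole the encoder output comes from the model's encoder
# and the estimator head exists when add_estimator is set.
batch, seq_len, hidden, pad_id = 2, 5, 8, 1

enc_out = torch.randn(batch, seq_len, hidden)            # [B, T, H] encoder states
src = torch.tensor([[5, 6, 7, 1, 1], [5, 6, 7, 8, 9]])   # token ids, 1 = padding
estimator = torch.nn.Linear(hidden, 1)                    # stand-in for the estimator head

# Old behaviour: average the encoder states over non-padded positions.
pad_mask = ~src.eq(pad_id)                                 # [B, T], True on real tokens
mean_pooled = (enc_out * pad_mask.unsqueeze(-1).float()).sum(dim=1) / pad_mask.sum(
    dim=1, keepdim=True
).float()
estim_avg = estimator(mean_pooled).squeeze(-1)             # [B]

# New behaviour: feed only the first token's state (COMET-style [CLS] pooling).
estim_first = estimator(enc_out[:, 0, :]).squeeze(-1)      # [B]

print(estim_avg.shape, estim_first.shape)  # torch.Size([2]) torch.Size([2])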
25 changes: 15 additions & 10 deletions eole/models/model.py
@@ -178,8 +178,9 @@ def maybe_lora(self, running_config):
             hasattr(running_config, "lora_layers")
             and len(running_config.lora_layers) > 0
         ):
-            if running_config.freeze_encoder or running_config.freeze_decoder:
-                raise ValueError("Cannot use LoRa with Enc/Dec-oder freezing")
+            # I think we need to allow encoder freezing while training LoRa
+            # if running_config.freeze_encoder or running_config.freeze_decoder:
+            # raise ValueError("Cannot use LoRa with Enc/Dec-oder freezing")
             for layer in running_config.lora_layers:
                 if (
                     hasattr(running_config, "quant_layers")
@@ -716,12 +717,12 @@ def load_safe_state_dict(
                 elif strict and (
                     "lora" not in param_name and "slopes" not in param_name
                 ):
-                    pass # TO FIX - patch for estimator
-                    # raise ValueError(
-                    #     "Missing key in safetensors checkpoint: %s" % name
-                    #     + "."
-                    #     + param_name
-                    # )
+                    # Let's warn instead of just passing
+                    logger.info(
+                        "Missing key in safetensors checkpoint: %s" % name
+                        + "."
+                        + param_name
+                    )
             if precision == torch.int8:
                 torch.quantization.quantize_dynamic(module, inplace=True)
             else:
@@ -927,13 +928,17 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False):

         mask = sequence_mask(src_len)
         enc_out, enc_final_hs = self.encoder(self.src_emb(src), mask=mask)
-        if self.add_estimator:  # we take the averages of enc_out and dec_out into account
+        if self.add_estimator:
+            # Version with average
+            """
             pad_mask1 = ~src.eq(1)
             in_estim1 = (enc_out * pad_mask1.unsqueeze(-1).float()).sum(
                 dim=1
             ) / pad_mask1.sum(dim=1, keepdim=True).float()
             estim = self.estimator(in_estim1.half()).squeeze(-1)
-            # estim = self.estimator(enc_out[:, 0, :]).squeeze(-1)
+            """
+            # Version with first token
+            estim = self.estimator(enc_out[:, 0, :]).squeeze(-1)
         else:
             estim = None

8 changes: 5 additions & 3 deletions eole/predict/encoder.py
@@ -113,14 +113,16 @@ def _predict_batch_with_strategy(self, batch, decode_strategy):
         )

         if self.add_estimator:
+            """
+            # Version with encoder out average
             pad_mask1 = ~src.eq(1)
             in_estim1 = (enc_out * pad_mask1.unsqueeze(-1).float()).sum(
                 dim=1
             ) / pad_mask1.sum(dim=1, keepdim=True).float()
             estim = self.model.estimator(in_estim1.half()).squeeze(-1)
-            # estim = self.model.estimator(
-            #     enc_out[:, 0, :]
-            # ).squeeze(-1)
+            """
+            # Version with first token embedding (same as COMET)
+            estim = self.model.estimator(enc_out[:, 0, :]).squeeze(-1)
         else:
             estim = torch.ones([enc_out.size(0)])
         estim = [[item] for item in estim.tolist()]
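As a side note, the unchanged context lines above show that the predict path keeps the same output packaging: whatever per-sentence score is produced, including the neutral 1.0 fallback when the model has no estimator head, is wrapped as one single-element list per sentence. A toy sketch of that packaging (not the eole predictor itself, and with a dummy stand-in for the estimator output):

import torch

# Assume a toy batch of 3 sentences and no estimator head.
enc_out = torch.randn(3, 7, 8)                # [batch, seq_len, hidden]
add_estimator = False

if add_estimator:
    estim = enc_out[:, 0, :].mean(dim=-1)     # dummy stand-in for the estimator output, shape [batch]
else:
    estim = torch.ones([enc_out.size(0)])     # neutral score per sentence
estim = [[item] for item in estim.tolist()]   # one single-element list per sentence
print(estim)                                  # [[1.0], [1.0], [1.0]]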
5 changes: 2 additions & 3 deletions eole/transforms/misc.py
@@ -58,9 +58,8 @@ def _parse_config(self):

     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Return None if too long else return as is."""
-        if (
-            len(example["src"]) > self.src_seq_length
-            or len(example["tgt"]) > self.tgt_seq_length - 2
+        if len(example["src"]) > self.src_seq_length or (
+            example["tgt"] is not None and len(example["tgt"]) > self.tgt_seq_length - 2
         ):
             if stats is not None:
                 stats.update(FilterTooLongStats())
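The misc.py tweak only matters when an example arrives without a target side (which can happen on the encoder/estimator prediction path): the old condition called len(example["tgt"]) unconditionally and would raise a TypeError when tgt is None. A minimal sketch of the guarded check, with hypothetical limits standing in for the transform's configured src_seq_length / tgt_seq_length:

# Hypothetical limits; the real transform reads these from its config.
src_seq_length, tgt_seq_length = 4, 4

def too_long(example):
    # Filter when src is too long, or when a tgt exists and is too long.
    return len(example["src"]) > src_seq_length or (
        example["tgt"] is not None and len(example["tgt"]) > tgt_seq_length - 2
    )

print(too_long({"src": [1, 2, 3], "tgt": [1, 2, 3]}))    # True: tgt exceeds tgt_seq_length - 2
print(too_long({"src": [1, 2, 3], "tgt": None}))         # False: no tgt, only src is checked
print(too_long({"src": [1, 2, 3, 4, 5], "tgt": None}))   # True: src itself is too long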
