diff --git a/eole/models/model.py b/eole/models/model.py
index b4a7c112..d2d8b896 100644
--- a/eole/models/model.py
+++ b/eole/models/model.py
@@ -178,8 +178,9 @@ def maybe_lora(self, running_config):
             hasattr(running_config, "lora_layers")
             and len(running_config.lora_layers) > 0
         ):
-            if running_config.freeze_encoder or running_config.freeze_decoder:
-                raise ValueError("Cannot use LoRa with Enc/Dec-oder freezing")
+            # I think we need to allow encoder freezing while training LoRa
+            # if running_config.freeze_encoder or running_config.freeze_decoder:
+            #     raise ValueError("Cannot use LoRa with Enc/Dec-oder freezing")
             for layer in running_config.lora_layers:
                 if (
                     hasattr(running_config, "quant_layers")
@@ -716,12 +717,12 @@ def load_safe_state_dict(
                 elif strict and (
                     "lora" not in param_name and "slopes" not in param_name
                 ):
-                    pass  # TO FIX - patch for estimator
-                    # raise ValueError(
-                    #     "Missing key in safetensors checkpoint: %s" % name
-                    #     + "."
-                    #     + param_name
-                    # )
+                    # Let's warn instead of just passing
+                    logger.info(
+                        "Missing key in safetensors checkpoint: %s" % name
+                        + "."
+                        + param_name
+                    )
             if precision == torch.int8:
                 torch.quantization.quantize_dynamic(module, inplace=True)
         else:
@@ -927,13 +928,17 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False):
         mask = sequence_mask(src_len)
         enc_out, enc_final_hs = self.encoder(self.src_emb(src), mask=mask)
-        if self.add_estimator:  # on prend en compte les average de enc_out et dec_out
+        if self.add_estimator:
+            # Version with average
+            """
             pad_mask1 = ~src.eq(1)
             in_estim1 = (enc_out * pad_mask1.unsqueeze(-1).float()).sum(
                 dim=1
             ) / pad_mask1.sum(dim=1, keepdim=True).float()
             estim = self.estimator(in_estim1.half()).squeeze(-1)
-            # estim = self.estimator(enc_out[:, 0, :]).squeeze(-1)
+            """
+            # Version with first token
+            estim = self.estimator(enc_out[:, 0, :]).squeeze(-1)
         else:
             estim = None
diff --git a/eole/predict/encoder.py b/eole/predict/encoder.py
index f917e55c..232af03b 100644
--- a/eole/predict/encoder.py
+++ b/eole/predict/encoder.py
@@ -113,14 +113,16 @@ def _predict_batch_with_strategy(self, batch, decode_strategy):
         )

         if self.add_estimator:
+            """
+            # Version with encoder out average
             pad_mask1 = ~src.eq(1)
             in_estim1 = (enc_out * pad_mask1.unsqueeze(-1).float()).sum(
                 dim=1
             ) / pad_mask1.sum(dim=1, keepdim=True).float()
             estim = self.model.estimator(in_estim1.half()).squeeze(-1)
-            # estim = self.model.estimator(
-            #     enc_out[:, 0, :]
-            # ).squeeze(-1)
+            """
+            # Version with first token embedding (same as COMET)
+            estim = self.model.estimator(enc_out[:, 0, :]).squeeze(-1)
         else:
             estim = torch.ones([enc_out.size(0)])
         estim = [[item] for item in estim.tolist()]
diff --git a/eole/transforms/misc.py b/eole/transforms/misc.py
index aabb51f1..c65cab20 100644
--- a/eole/transforms/misc.py
+++ b/eole/transforms/misc.py
@@ -58,9 +58,8 @@ def _parse_config(self):

     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Return None if too long else return as is."""
-        if (
-            len(example["src"]) > self.src_seq_length
-            or len(example["tgt"]) > self.tgt_seq_length - 2
+        if len(example["src"]) > self.src_seq_length or (
+            example["tgt"] is not None and len(example["tgt"]) > self.tgt_seq_length - 2
         ):
             if stats is not None:
                 stats.update(FilterTooLongStats())
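
For reviewers, a minimal standalone sketch of the two pooling strategies the estimator hunks switch between: mean-pooling the encoder states over non-padding positions (the old path, now kept only as a docstring block) versus taking the first token's state, as COMET does (the new path). The helper names, toy shapes, and the default pad index of 1 are illustrative assumptions, not part of the eole API; the `.half()` cast from the original code is omitted to keep the sketch device-agnostic.

```python
import torch


def mean_pool(enc_out: torch.Tensor, src: torch.Tensor, pad_idx: int = 1) -> torch.Tensor:
    """Average encoder states over non-padding positions (old behaviour)."""
    pad_mask = ~src.eq(pad_idx)                                # (batch, seq_len)
    summed = (enc_out * pad_mask.unsqueeze(-1).float()).sum(dim=1)
    return summed / pad_mask.sum(dim=1, keepdim=True).float()  # (batch, hidden)


def first_token_pool(enc_out: torch.Tensor) -> torch.Tensor:
    """Take the first token's encoder state (new behaviour, as in COMET)."""
    return enc_out[:, 0, :]                                    # (batch, hidden)


# Toy example: batch of 2, seq_len 4, hidden size 8, pad index 1.
enc_out = torch.randn(2, 4, 8)
src = torch.tensor([[5, 6, 7, 1], [5, 6, 1, 1]])
assert mean_pool(enc_out, src).shape == (2, 8)
assert first_token_pool(enc_out).shape == (2, 8)
```

Either way the estimator head receives a (batch, hidden) tensor, so only the pooling changes; the head itself is untouched by this diff.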