diff --git a/finetune/train.py b/finetune/train.py
index 1fc535c..34408dd 100755
--- a/finetune/train.py
+++ b/finetune/train.py
@@ -184,16 +184,21 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         )
 
 """
-Manually calculate the accuracy, f1, matthews_correlation, precision, recall with sklearn.
-"""
-def calculate_metric_with_sklearn(logits: np.ndarray, labels: np.ndarray):
+Compute metrics used for huggingface trainer.
+"""
+def compute_metrics(eval_pred):
+    logits = eval_pred.predictions[0]  # predictions[1] is the labels passed through by preprocess_logits_for_metrics()
+    labels = eval_pred.label_ids
+
     if logits.ndim == 3:
         # Reshape logits to 2D if needed
         logits = logits.reshape(-1, logits.shape[-1])
+
     predictions = np.argmax(logits, axis=-1)
     valid_mask = labels != -100  # Exclude padding tokens (assuming -100 is the padding token ID)
     valid_predictions = predictions[valid_mask]
     valid_labels = labels[valid_mask]
+
     return {
         "accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions),
         "f1": sklearn.metrics.f1_score(
@@ -207,19 +212,22 @@ def calculate_metric_with_sklearn(logits: np.ndarray, labels: np.ndarray):
         ),
         "recall": sklearn.metrics.recall_score(
             valid_labels, valid_predictions, average="macro", zero_division=0
-        ),
+        )
     }
 
 
-
 """
-Compute metrics used for huggingface trainer.
+Fix the memory problem in the evaluation step when using compute_metrics.
 """
-def compute_metrics(eval_pred):
-    logits, labels = eval_pred
-    if isinstance(logits, tuple):  # Unpack logits if it's a tuple
-        logits = logits[0]
-    return calculate_metric_with_sklearn(logits, labels)
+def preprocess_logits_for_metrics(logits, labels):
+    """
+    Original Trainer may have a memory leak.
+    This is a workaround to avoid storing too many tensors that are not needed.
+    """
+    # logits = (model output logits, dnabert output)
+    # model output logits: (bs, 2)
+    # dnabert output: (bs, 452, 768)
+    return logits[0], labels
 
 
 def train():
@@ -281,7 +289,8 @@ def train():
                                    compute_metrics=compute_metrics,
                                    train_dataset=train_dataset,
                                    eval_dataset=val_dataset,
-                                   data_collator=data_collator)
+                                   data_collator=data_collator,
+                                   preprocess_logits_for_metrics=preprocess_logits_for_metrics)
     trainer.train()
 
     if training_args.save_model:
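
Why discarding the second model output matters: the Hugging Face Trainer accumulates everything compute_metrics will see across the whole evaluation set before calling it, so without preprocess_logits_for_metrics the (bs, 452, 768) DNABERT hidden states pile up in host memory. A back-of-envelope sketch, assuming float32 and a hypothetical evaluation-set size (the tensor shapes come from the patch's own comments):

# Rough memory estimate for accumulating the full DNABERT output during
# evaluation. Shapes are taken from the comments in the patch above;
# float32 and an eval set of 10,000 sequences are assumptions.
bytes_per_example = 452 * 768 * 4                    # ~1.39 MB of hidden states per sequence
eval_examples = 10_000                               # hypothetical evaluation-set size
print(bytes_per_example * eval_examples / 1024**3)   # ~12.9 GiB of accumulated hidden states

# Keeping only the (bs, 2) classification logits instead:
print(2 * 4 * eval_examples / 1024**2)               # ~0.08 MiB

Because preprocess_logits_for_metrics runs once per batch inside the evaluation loop, only the small (bs, 2) logits tensor (plus the passed-through labels) is ever accumulated, which is why compute_metrics reads eval_pred.predictions[0].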