From c38be2f02c2f86c6aacc78d9c9fb8e05627be311 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=95=E5=BD=A5=E5=8D=97=20=28yen-nan=20ho=29?= <46644821+aaron1aaron2@users.noreply.github.com>
Date: Tue, 2 Apr 2024 17:43:57 +0800
Subject: [PATCH 1/3] Update train.py - fix memory problem

---
 finetune/train.py | 39 ++++++++++++++++++++++++++++-----------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/finetune/train.py b/finetune/train.py
index 1fc535c..3440013 100755
--- a/finetune/train.py
+++ b/finetune/train.py
@@ -184,16 +184,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         )
 
 """
-Manually calculate the accuracy, f1, matthews_correlation, precision, recall with sklearn.
-"""
-def calculate_metric_with_sklearn(logits: np.ndarray, labels: np.ndarray):
+Compute metrics used for huggingface trainer.
+"""
+def compute_metrics(eval_pred):
+    logits = eval_pred.predictions[0]  # predictions[1] is the labels passed through by preprocess_logits_for_metrics()
+    labels = eval_pred.label_ids
+
     if logits.ndim == 3:
         # Reshape logits to 2D if needed
         logits = logits.reshape(-1, logits.shape[-1])
+
     predictions = np.argmax(logits, axis=-1)
     valid_mask = labels != -100  # Exclude padding tokens (assuming -100 is the padding token ID)
     valid_predictions = predictions[valid_mask]
     valid_labels = labels[valid_mask]
+
+    pred_prob = softmax(logits, axis=1)
+    pred_prob = pred_prob[:,[1]].flatten()
+
     return {
         "accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions),
         "f1": sklearn.metrics.f1_score(
@@ -208,18 +216,27 @@ def calculate_metric_with_sklearn(logits: np.ndarray, labels: np.ndarray):
         "recall": sklearn.metrics.recall_score(
             valid_labels, valid_predictions, average="macro", zero_division=0
         ),
+        "auPRCs": sklearn.metrics.average_precision_score(
+            y_true=valid_labels, y_score=pred_prob
+        ),
+        "auROC": sklearn.metrics.roc_auc_score(
+            y_true=valid_labels, y_score=pred_prob
+        )
     }
 
-
 """
-Compute metrics used for huggingface trainer.
+Fix the memory problem in eval_step when using compute_metrics.
 """
-def compute_metrics(eval_pred):
-    logits, labels = eval_pred
-    if isinstance(logits, tuple):  # Unpack logits if it's a tuple
-        logits = logits[0]
-    return calculate_metric_with_sklearn(logits, labels)
-
+def preprocess_logits_for_metrics(logits, labels):
+    """
+    The original Trainer may have a memory leak.
+    This is a workaround to avoid storing too many tensors that are not needed.
+    """
+    # logits = (model output logits, dnabert output)
+    # model output logits : (bs, 2)
+    # dnabert output : (bs, 452, 768)
+
+    return logits[0], labels
 
 def train():

From 7ebf6ae1d90c61a39423dfe72dbaeb90b8a83b2f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=95=E5=BD=A5=E5=8D=97=20=28yen-nan=20ho=29?= <46644821+aaron1aaron2@users.noreply.github.com>
Date: Tue, 2 Apr 2024 17:45:13 +0800
Subject: [PATCH 2/3] Update train.py

---
 finetune/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/finetune/train.py b/finetune/train.py
index 3440013..e6ff2bb 100755
--- a/finetune/train.py
+++ b/finetune/train.py
@@ -298,7 +298,8 @@ def train():
                                    compute_metrics=compute_metrics,
                                    train_dataset=train_dataset,
                                    eval_dataset=val_dataset,
-                                   data_collator=data_collator)
+                                   data_collator=data_collator,
+                                   preprocess_logits_for_metrics=preprocess_logits_for_metrics)
     trainer.train()
 
     if training_args.save_model:

From f3eb93094b9c2a9035283db90dceafa6b00b8edb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=95=E5=BD=A5=E5=8D=97=20=28yen-nan=20ho=29?= <46644821+aaron1aaron2@users.noreply.github.com>
Date: Tue, 2 Apr 2024 23:53:08 +0800
Subject: [PATCH 3/3] Update train.py

---
 finetune/train.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/finetune/train.py b/finetune/train.py
index e6ff2bb..34408dd 100755
--- a/finetune/train.py
+++ b/finetune/train.py
@@ -199,9 +199,6 @@ def compute_metrics(eval_pred):
     valid_predictions = predictions[valid_mask]
     valid_labels = labels[valid_mask]
 
-    pred_prob = softmax(logits, axis=1)
-    pred_prob = pred_prob[:,[1]].flatten()
-
     return {
         "accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions),
         "f1": sklearn.metrics.f1_score(
@@ -215,12 +212,6 @@ def compute_metrics(eval_pred):
         ),
         "recall": sklearn.metrics.recall_score(
             valid_labels, valid_predictions, average="macro", zero_division=0
-        ),
-        "auPRCs": sklearn.metrics.average_precision_score(
-            y_true=valid_labels, y_score=pred_prob
-        ),
-        "auROC": sklearn.metrics.roc_auc_score(
-            y_true=valid_labels, y_score=pred_prob
         )
     }
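
After the three patches, compute_metrics only ever reads eval_pred.predictions[0], because preprocess_logits_for_metrics is applied by the Trainer on every evaluation step and drops the large DNABERT hidden-state tensor before outputs are accumulated. Below is a minimal, self-contained sketch of what the hook keeps versus discards; it is illustrative only, and assumes (as the comments in PATCH 1/3 state) that the model's forward pass returns a (classification logits, dnabert output) tuple with the shapes shown.

    # Illustrative sketch only -- shapes follow the comments in PATCH 1/3.
    import torch

    def preprocess_logits_for_metrics(logits, labels):
        # The Trainer calls this once per evaluation step; keeping only the
        # small (bs, 2) classification logits means the (bs, 452, 768) hidden
        # states are never accumulated across steps.
        return logits[0], labels

    bs = 8  # illustrative eval batch size
    step_outputs = (torch.randn(bs, 2),           # model output logits
                    torch.randn(bs, 452, 768))    # dnabert output (the memory hog)
    labels = torch.randint(0, 2, (bs,))

    kept_logits, kept_labels = preprocess_logits_for_metrics(step_outputs, labels)
    print(kept_logits.shape, kept_labels.shape)   # torch.Size([8, 2]) torch.Size([8])

Because the hook returns a (logits, labels) tuple, eval_pred.predictions ends up as a tuple of arrays, which is why compute_metrics indexes predictions[0] rather than using predictions directly, as noted in its first comment.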