From 14732cf3203b143bdcd331e4b5fc0b286699c02f Mon Sep 17 00:00:00 2001 From: Nithin Rao Date: Fri, 9 Dec 2022 03:26:19 +0530 Subject: [PATCH] Update torchmetrics (#5566) * add task arg Signed-off-by: nithinraok * update state Signed-off-by: nithinraok Signed-off-by: nithinraok Co-authored-by: Taejin Park --- nemo/collections/asr/models/label_models.py | 4 ++-- requirements/requirements_lightning.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 94048e5ab0c4..62e4dd5e456f 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -165,7 +165,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder) self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder) - self._macro_accuracy = Accuracy(num_classes=num_classes, average='macro') + self._macro_accuracy = Accuracy(num_classes=num_classes, average='macro', task='multiclass') self.labels = None if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: @@ -365,7 +365,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = acc_top_k = self._accuracy(logits=logits, labels=labels) correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k self._macro_accuracy.update(preds=logits, target=labels) - stats = self._macro_accuracy._get_final_stats() + stats = self._macro_accuracy._final_state() return { f'{tag}_loss': loss_value, diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index 9876002b806d..faa23c758746 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -2,7 +2,7 @@ hydra-core>=1.2.0,<1.3 omegaconf>=2.2,<2.3 pytorch-lightning>=1.8.3 pyyaml<6 # Pinned until omegaconf works with pyyaml>=6 -torchmetrics>=0.4.1rc0,<=0.10.3 +torchmetrics transformers>=4.0.1,<=4.21.2 wandb webdataset>=0.1.48,<=0.1.62