
docs: "TorchMetrics in PyTorch Lightning" enhancement #2925

Open
adosar opened this issue Jan 30, 2025 · 1 comment
Labels: documentation


adosar commented Jan 30, 2025

The first snippet in the "TorchMetrics in PyTorch Lightning" documentation page is the following:

class MyModel(LightningModule):

    def __init__(self, num_classes):
        ...
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy)
        ...

    def on_train_epoch_end(self):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy)

However, in the Common Pitfalls section it is stated that:

Mixing the two logging methods by calling self.log("val", self.metric) in {training|validation|test}_step method and then calling self.log("val", self.metric.compute()) in the corresponding on_{train|validation|test}_epoch_end method.
Because the object is logged in the first case, Lightning will reset the metric before calling the second line leading to errors or nonsense results.

Therefore, isn't the above snippet a "bad" practice, since the metric is reset before the second line is executed, i.e. before calling:

self.log('train_acc_epoch', self.accuracy)
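
For reference, my reading of the Common Pitfalls section is that the fully-manual alternative would look like the sketch below (imports added for context; the LightningModule import path may differ depending on the installed Lightning package):

import torchmetrics
from lightning.pytorch import LightningModule  # or: from pytorch_lightning import LightningModule

class MyModel(LightningModule):

    def __init__(self, num_classes):
        ...
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # Only accumulate the metric state; the metric object is never
        # passed to self.log, so Lightning does not reset it early.
        self.accuracy.update(preds, y)
        ...

    def on_train_epoch_end(self):
        # Log the computed value (a tensor, not the metric object) and
        # reset the state manually for the next epoch.
        self.log('train_acc_epoch', self.accuracy.compute())
        self.accuracy.reset()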

Moreover, since the snippet shows how to use a metric within a LightningModule, wouldn't it be better to adhere to Lightning's automatic logging?

class MyModel(LightningModule):

    def __init__(self, num_classes):
        ...
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step + epoch metric
        self.accuracy(preds, y)
        self.log('train_acc', self.accuracy, on_epoch=True)
        # Lightning automatically logs at the end of each epoch,
        # producing two keys: 'train_acc_step' and 'train_acc_epoch'.
        ...

    # def on_train_epoch_end(self):
    #     # log epoch metric
    #     self.log('train_acc_epoch', self.accuracy)
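
(If I understand the docs correctly, when the metric object itself is passed to self.log with on_epoch=True, Lightning takes care of calling compute() and reset() at the end of the epoch, so the manual hook above becomes redundant.)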
adosar added the documentation label on Jan 30, 2025

Hi! Thanks for your contribution, great first issue!
