From c19d2b43afccae95158d0761063c966c23c3b23f Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 9 Dec 2024 20:04:20 +0800 Subject: [PATCH] Update train.py --- egs/ljspeech/TTS/matcha/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 31135f623b..853042413c 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -488,9 +488,10 @@ def save_bad_model(suffix: str = ""): loss = sum(losses.values()) - optimizer.zero_grad() scaler.scale(loss).backward() scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() loss_info = MetricsTracker() loss_info["samples"] = batch_size