From 1b89c6dac4ed4ad472917e0cd873733ba1ed4fdd Mon Sep 17 00:00:00 2001 From: Your Name <> Date: Mon, 28 Oct 2024 22:58:40 -0700 Subject: [PATCH] skipping batch counts hurts performance --- egs/librispeech/SSL/hubert/finetune.py | 2 +- egs/librispeech/SSL/hubert/finetune_ce.py | 2 +- egs/librispeech/SSL/zipformer/finetune.py | 2 +- egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 582771dee8..05b942f632 100755 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -99,7 +99,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index cec42ea123..1081313f1d 100755 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -99,7 +99,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index 336c358136..2e521f1772 100755 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -99,7 +99,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: diff --git a/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py b/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py index 572e040bd9..d5dd8d71f4 100755 --- a/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py +++ b/egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py @@ -93,7 +93,7 @@ def get_adjusted_batch_count(params: AttributeDict) -> float: * params.accum_grad * (params.max_duration * params.world_size) / params.ref_duration - ) + 100000 + ) def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: