Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torchscript export to use tokens.txt instead of lang_dir #1475

Merged
merged 10 commits
Jan 26, 2024
25 changes: 16 additions & 9 deletions egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
Expand All @@ -20,7 +21,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 29 \
--avg 19

Expand All @@ -45,12 +46,13 @@
import logging
from pathlib import Path

import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -85,10 +87,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt.",
)

parser.add_argument(
Expand Down Expand Up @@ -122,10 +124,14 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand All @@ -152,6 +158,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
24 changes: 14 additions & 10 deletions egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10

Expand All @@ -47,12 +47,13 @@
import logging
from pathlib import Path

import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -98,10 +99,10 @@ def get_parser():
)

parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)

parser.add_argument(
Expand Down Expand Up @@ -135,12 +136,14 @@ def main():

logging.info(f"device: {device}")

sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)

# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand Down Expand Up @@ -183,6 +186,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
1 change: 1 addition & 0 deletions egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
7 changes: 3 additions & 4 deletions egs/librispeech/ASR/lstm_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])

traced_model = torch.jit.trace(decoder_model, (y, need_pad))
# TODO(fangjun): Change the function name since we are actually using
# torch.jit.script instead of torch.jit.trace
traced_model = torch.jit.script(decoder_model)
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def main():

# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)
Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embedding_out = self.embedding(y)
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
Expand All @@ -45,7 +45,7 @@

./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10

Expand Down Expand Up @@ -87,7 +87,7 @@
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
Expand All @@ -113,7 +113,7 @@
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
Expand Down
31 changes: 15 additions & 16 deletions egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--tokens ./data/lang_char/tokens.txt \
--epoch 30 \
--avg 24 \
--use-averaged-model True
Expand All @@ -50,8 +50,9 @@
import logging
from pathlib import Path

import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
Expand All @@ -60,8 +61,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -118,13 +118,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt.",
)

parser.add_argument(
Expand Down Expand Up @@ -160,13 +157,14 @@ def main():

logging.info(f"device: {device}")

bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)

lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand Down Expand Up @@ -256,6 +254,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
1 change: 1 addition & 0 deletions egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
23 changes: 11 additions & 12 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2 \
--jit 1
Expand All @@ -47,7 +47,7 @@

./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2 \
--jit-trace 1
Expand All @@ -63,7 +63,7 @@

./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2

Expand Down Expand Up @@ -91,14 +91,14 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -133,10 +133,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -313,10 +313,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
Loading
Loading