Skip to content

Commit

Permalink
Add CTC HLG decoding for zipformer (#1287)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 2, 2023
1 parent f14b673 commit 109354b
Show file tree
Hide file tree
Showing 10 changed files with 1,545 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,57 @@ log() {

pushd egs/librispeech/ASR

# repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav

log "CTC greedy search"

./zipformer/onnx_pretrained_ctc.py \
--nn-model $repo/model.onnx \
--tokens $repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav

log "CTC H decoding"

./zipformer/onnx_pretrained_ctc_H.py \
--nn-model $repo/model.onnx \
--tokens $repo/tokens.txt \
--H $repo/H.fst \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav

log "CTC HL decoding"

./zipformer/onnx_pretrained_ctc_HL.py \
--nn-model $repo/model.onnx \
--words $repo/words.txt \
--HL $repo/HL.fst \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav

log "CTC HLG decoding"

./zipformer/onnx_pretrained_ctc_HLG.py \
--nn-model $repo/model.onnx \
--words $repo/words.txt \
--HLG $repo/HLG.fst \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav

rm -rf $repo

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09
log "Downloading pre-trained model from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
Expand Down Expand Up @@ -128,7 +178,9 @@ repo=$(basename $repo_url)
pushd $repo

git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lm/G_3_gram_char.fst.txt"
git lfs pull --include "data/lang_char/H.fst"
git lfs pull --include "data/lang_char/HL.fst"
git lfs pull --include "data/lang_char/HLG.fst"

popd

Expand All @@ -153,10 +205,6 @@ popd

ls -lh $repo/exp

log "Generating H.fst, HL.fst"

./local/prepare_lang_fst.py --lang-dir $repo/data/lang_char --ngram-G $repo/data/lm/G_3_gram_char.fst.txt

ls -lh $repo/data/lang_char

log "Decoding with H on CPU with OpenFst"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-pre-trained-conformer-ctc
name: run-pre-trained-ctc

on:
push:
Expand All @@ -31,12 +31,12 @@ on:
default: 'y'

concurrency:
group: run_pre_trained_conformer_ctc-${{ github.ref }}
group: run_pre_trained_ctc-${{ github.ref }}
cancel-in-progress: true

jobs:
run_pre_trained_conformer_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y'
run_pre_trained_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc'
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand Down Expand Up @@ -84,4 +84,4 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-pre-trained-conformer-ctc.sh
.github/scripts/run-pre-trained-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def decode(
decoder.decode(decodable)

if not decoder.reached_final():
print(f"failed to decode {filename}")
logging.info(f"failed to decode {filename}")
return [""]

ok, best_path = decoder.get_best_path()
Expand All @@ -157,7 +157,7 @@ def decode(
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]

# tokens are incremented during graph construction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def decode(
contains output from log_softmax.
HL:
The HL graph.
word2token:
A map mapping token ID to word string.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
Expand All @@ -145,7 +145,7 @@ def decode(
decoder.decode(decodable)

if not decoder.reached_final():
print(f"failed to decode {filename}")
logging.info(f"failed to decode {filename}")
return [""]

ok, best_path = decoder.get_best_path()
Expand All @@ -157,7 +157,7 @@ def decode(
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]

# are shifted by 1 during graph construction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def decode(
contains output from log_softmax.
HLG:
The HLG graph.
word2token:
A map mapping token ID to word string.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
Expand All @@ -144,7 +144,7 @@ def decode(
decoder.decode(decodable)

if not decoder.reached_final():
print(f"failed to decode {filename}")
logging.info(f"failed to decode {filename}")
return [""]

ok, best_path = decoder.get_best_path()
Expand All @@ -156,7 +156,7 @@ def decode(
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]

# are shifted by 1 during graph construction
Expand Down
Loading

0 comments on commit 109354b

Please sign in to comment.