Skip to content

Commit

Permalink
runtime (#58)
Browse files Browse the repository at this point in the history
* Delete egs_new.py

* Delete lr_scheduler_new.py

* Delete trainer_new.py

* Update trainer_online.py

* Update runEcapaXvector_online.py

* Update plda_base.py

* Update plda_base.py

* Update scoreSets.sh

* Update extract_embeddings_new.py

* Update README.md

* Update README.md

* Update components.py

* Update make_voxceleb1_v2.pl

* Update make_voxceleb1_v2.pl

* Update make_voxceleb2.pl

* runtime
  • Loading branch information
sssyousen authored Nov 13, 2022
1 parent 0e23ae0 commit ea4db02
Show file tree
Hide file tree
Showing 73 changed files with 4,624 additions and 6 deletions.
6 changes: 3 additions & 3 deletions recipe/voxceleb/prepare/make_voxceleb1_v2.pl
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@
}

if (system(
"utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
"subtools/kaldi/utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
system("env LC_COLLATE=C utils/fix_data_dir.sh $out_dir");
if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
system("env LC_COLLATE=C subtools/kaldi/utils/fix_data_dir.sh $out_dir");
if (system("env LC_COLLATE=C subtools/kaldi/utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}
6 changes: 3 additions & 3 deletions recipe/voxceleb/prepare/make_voxceleb2.pl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@
close(WAV) or die;

if (system(
"utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
"subtools/kaldi/utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
system("env LC_COLLATE=C utils/fix_data_dir.sh $out_dir");
if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
system("env LC_COLLATE=C subtools/kaldi/utils/fix_data_dir.sh $out_dir");
if (system("env LC_COLLATE=C subtools/kaldi/utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}
66 changes: 66 additions & 0 deletions runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

project(subtools_jit VERSION 0.1 LANGUAGES CXX)

# Echo full compiler command lines while this runtime build is young.
set(CMAKE_VERBOSE_MAKEFILE on)
option(CXX11_ABI "whether to use CXX11_ABI libtorch" OFF)

include(FetchContent)
include(ExternalProject)
set(FETCHCONTENT_QUIET off)
# Cache all fetched third-party sources under <source-root>/fc_base so
# re-configuring does not re-download them.
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})

# Third-party helper modules (libtorch.cmake, yaml.cmake, ...) live in ./cmake.
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

# C++14 without GNU extensions, position-independent code everywhere so the
# static libraries below can be linked into shared objects.  This replaces
# the previous raw append of "-std=c++14 -fPIC" to CMAKE_CXX_FLAGS.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Threading: use the imported Threads::Threads target instead of a bare
# "-pthread" flag string; THREADS_PREFER_PTHREAD_FLAG keeps -pthread on the
# compile line as before.
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

# Directory-scoped on purpose: kaldifeat (added below) also resolves its
# headers relative to this root.
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
)

include(libtorch)
include(yaml)
# include(gtest)
include(gflags)
include(glog)

add_subdirectory(kaldifeat/csrc)

# utils: option parsing, string helpers and logging glue shared by all targets.
add_library(utils STATIC
utils/utils.cc utils/options.cc utils/string.cc
)
target_link_libraries(utils PUBLIC gflags glog yaml-cpp Threads::Threads)

# frontend: wav reading and the kaldifeat-based feature pipeline.
add_library(frontend STATIC
frontend/feature_pipeline.cc frontend/features.cc
)
target_link_libraries(frontend PUBLIC utils kaldifeat_core)

# extractor: TorchScript model loading and x-vector extraction.
add_library(extractor STATIC
extractor/torch_asv_extractor.cc
extractor/torch_asv_model.cc
)
target_link_libraries(extractor PUBLIC ${TORCH_LIBRARIES} utils)

# binary: the command-line extraction tool (torch libs come in transitively
# through extractor's PUBLIC link interface).
add_executable(extractor_main bin/extractor_main.cc)
target_link_libraries(extractor_main PUBLIC utils frontend extractor)
81 changes: 81 additions & 0 deletions runtime/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# ASV Runtime
This runtime is based on LibTorch. Models trained with PyTorch can be converted to TorchScript via torch JIT and then deployed in C++ applications.
* Feature extraction: [Kaldifeat](https://github.com/csukuangfj/kaldifeat)
* Inference: LibTorch

<div align='center'>
<img src="../doc/runtime_deploy.png" width=50% height=50%/>
</div>

### Local build

Build `extractor_main`. Requires cmake >= 3.14 and gcc >= 5.

```bash
mkdir build && cd build && cmake .. && cmake --build .
```

```
# project
.
|-- CMakeLists.txt
|-- README.md
|-- cmake
|-- bin
| `-- extractor_main.cc
|-- build
| |-- extractor_main
| `
|-- fc_base
|-- test
| |-- wav
| |-- gen_jit.py
| |-- wav.scp
| |-- feat_conf.yaml
| `-- test.sh
`
```
### Evaluate RTF

The following scripts will generate JIT models and extract x-vectors for the wavs in `./test/wav`:
```bash
cd ./test
python3 gen_jit.py
./test.sh
```
### Construct your own directory

1. export your model.

Go to your project directory which contains subtools.

```shell
model_dir=exp/resnet34 # directory of your model
epoch=4 # model checkpoint, e.g. epoch=18

subtools/pytorch/pipeline/export_jit_model.sh --model-dir $model_dir --model-file $epoch.params \
--output-file-name $epoch.pt \
--output_quant_name ${epoch}_quant.pt
```
2. Go back to the `test` directory of this runtime, and move the model and feature config into your runtime directory:
```
# project
.
|-- model
| |--resnet34
| ` |--4.pt
| |--config
| ` |--feat_conf.yaml
|-- wav `
|-- gen_jit.py
|-- wav.scp
|-- feat_conf.yaml
|-- test.sh
`

3. execute.
```bash
./test.sh --model_path ./model/resnet34/4.pt \
--feat_conf ./model/resnet34/config/feat_conf.yaml
```
Thanks to the [Wenet](https://github.com/wenet-e2e/wenet/tree/main/runtime) project for their contributions to production deployment.
122 changes: 122 additions & 0 deletions runtime/bin/extractor_main.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// extractor_main: command-line x-vector extraction tool.
//
// Reads a single wave (--wav_path) or a Kaldi-style scp list of waves
// (--wav_scp, "<key> <path>" per line), runs the feature pipeline and the
// TorchScript ASV model over each, prints the resulting embedding tensor to
// stdout, and logs per-utterance and total real-time-factor (RTF) numbers.
#include <iomanip>
#include <utility>
#include <iostream>
// #include "torch/script.h"
#include "extractor/params.h"
#include "frontend/wav.h"
#include "utils/timer.h"
#include "utils/utils.h"
#include "utils/string.h"

DEFINE_string(wav_path, "", "single wave path");
DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_int32(warmup, 20, "num of warmup decode, 0 means no warmup");

int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
// All INFO-level logs are redirected to ./test.1.log in the working dir.
google::SetLogDestination(google::GLOG_INFO,"./test.1.log");

// Configs and model are built from command-line flags declared in
// extractor/params.h (feature/extractor/model flags live there).
auto feature_config = subtools::InitFeaturePiplineConfigFromFlags();
auto extractor_config = subtools::InitExtractOptionsFromFlags();
auto asv_model = subtools::InitTorchAsvModel();
if (FLAGS_wav_path.empty() && FLAGS_wav_scp.empty()) {
LOG(FATAL) << "Please provide the wave path or the wav scp.";
}
// (key, path) pairs to process; --wav_path wins over --wav_scp.
std::vector<std::pair<std::string, std::string>> waves;
if (!FLAGS_wav_path.empty()) {
waves.emplace_back(make_pair("test", FLAGS_wav_path));
} else {
std::ifstream wav_scp(FLAGS_wav_scp);
std::string line;
while (getline(wav_scp, line)) {
std::vector<std::string> strs;
subtools::SplitString(line, &strs);
// Each scp line must have at least a key and a path.
CHECK_GE(strs.size(), 2);
waves.emplace_back(make_pair(strs[0], strs[1]));
}
}

// Warmup: run the first utterance FLAGS_warmup times so later RTF
// measurements are not skewed by lazy initialization (JIT, allocator, ...).
if (FLAGS_warmup > 0) {
LOG(INFO) << "Warming up...";

auto wav = waves[0];
for (int i = 0; i < FLAGS_warmup; i++) {
subtools::WavReader wav_reader(wav.second);
// CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
auto feature_pipeline =
std::make_shared<subtools::FeaturePipeline>(*feature_config);
subtools::Timer timer;
feature_pipeline->AcceptWaveform(std::vector<float>(
wav_reader.data(), wav_reader.data() + wav_reader.num_sample()));
feature_pipeline->set_input_finished();
// Log feature-extraction latency once, on the first warmup pass only.
if(i==0){
LOG(INFO) << "make features for "<< feature_pipeline->num_frames()
<< " frames takes " <<timer.Elapsed() << "ms.";
}

subtools::TorchAsvExtractor extractor(feature_pipeline, asv_model,
*extractor_config);

// Utterance duration in milliseconds.
int wave_dur =
static_cast<int>(static_cast<float>(wav_reader.num_sample()) /
wav_reader.sample_rate() * 1000);

timer.Reset();
extractor.Extract();
int extract_time = timer.Elapsed();
// NOTE(review): label says "RTF" but the value printed here is
// extract_time / wave_dur (a ratio), while "ms." is appended — the
// unit suffix looks spurious.
LOG(INFO) << " Warmup RTF " << static_cast<float>(extract_time) / wave_dur
<< "ms.";
LOG(INFO) << " Warmup num " << i+1 << " Done! " << std::endl;
}


LOG(INFO) << "Warmup done.";
}

// Timed run over every utterance; totals accumulate for the final RTF.
int total_waves_dur = 0;
int total_extract_time = 0;
for (auto &wav : waves) {
LOG(INFO) << wav.first << " Start! " << std::endl;
subtools::WavReader wav_reader(wav.second);
// CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
auto feature_pipeline =
std::make_shared<subtools::FeaturePipeline>(*feature_config);
subtools::Timer timer;
feature_pipeline->AcceptWaveform(std::vector<float>(
wav_reader.data(), wav_reader.data() + wav_reader.num_sample()));
feature_pipeline->set_input_finished();

LOG(INFO) << "num frames " << feature_pipeline->num_frames();
LOG(INFO) << "make features for "<< feature_pipeline->num_frames()
<< " frames takes " <<timer.Elapsed() << "ms.";

subtools::TorchAsvExtractor extractor(feature_pipeline, asv_model,
*extractor_config);

int wave_dur =
static_cast<int>(static_cast<float>(wav_reader.num_sample()) /
wav_reader.sample_rate() * 1000);

timer.Reset();
extractor.Extract();
int extract_time = timer.Elapsed();

// The extracted embedding is written to stdout; logs go to the log file.
torch::Tensor result = extractor.result();
std::cout<<result;
LOG(INFO) << "extracted xvector of " << wave_dur << "ms audio taken " << extract_time
<< "ms.";
LOG(INFO) <<"RTF: "<< static_cast<float>(extract_time) / wave_dur;
LOG(INFO) << wav.first << " Done! " << std::endl;
total_waves_dur += wave_dur;
total_extract_time += extract_time;
}
// NOTE(review): if the scp is empty, total_waves_dur is 0 and the final
// RTF below divides by zero — confirm inputs are always non-empty.
LOG(INFO) << "Total: processed " << total_waves_dur << "ms audio taken "
<< total_extract_time << "ms.";
LOG(INFO) << "RTF: " << std::setprecision(4)
<< static_cast<float>(total_extract_time) / total_waves_dur;
google::ShutdownGoogleLogging();
return 0;

}
8 changes: 8 additions & 0 deletions runtime/cmake/boost.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# third_party: boost
# Fetch Boost 1.75.0 from the official release mirror.  The previous version
# of this file pointed URL at a machine-local absolute path
# (/work/kaldi/egs/ldx/voxceleb/require_run/...), which breaks the build on
# every other machine; the SHA256 pins the same archive either way.
FetchContent_Declare(boost
URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
)
FetchContent_MakeAvailable(boost)
# Boost is consumed header-only here; expose its root as an include path.
include_directories(${boost_SOURCE_DIR})
8 changes: 8 additions & 0 deletions runtime/cmake/gflags.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# third_party: gflags
# Fetched from source so the build has no system-level dependency; the
# FetchContent_MakeAvailable call defines the `gflags` target that
# CMakeLists.txt links against.
FetchContent_Declare(gflags
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
)
FetchContent_MakeAvailable(gflags)
# find_package(gflags REQUIRED)
# gflags generates its public headers at build time, so they live in the
# binary dir, not the source dir.
include_directories(${gflags_BINARY_DIR}/include)
7 changes: 7 additions & 0 deletions runtime/cmake/glog.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# third_party: glog
# Fetched from source; FetchContent_MakeAvailable defines the `glog` target
# linked by CMakeLists.txt.
FetchContent_Declare(glog
URL https://github.com/google/glog/archive/v0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
# glog 0.4 keeps headers in src/ and generates config headers into the
# binary dir, hence both paths.
include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
11 changes: 11 additions & 0 deletions runtime/cmake/gtest.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

# third_party: gtest
# Currently unused (the `include(gtest)` line in CMakeLists.txt is commented
# out); kept ready for when unit tests are added.
FetchContent_Declare(googletest
URL https://github.com/google/googletest/archive/release-1.10.0.zip
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
)

# On MSVC, force gtest to link the dynamic CRT so it matches the rest of
# the build; must be set before MakeAvailable processes gtest's CMake.
if(MSVC)
set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE)
endif()
FetchContent_MakeAvailable(googletest)
32 changes: 32 additions & 0 deletions runtime/cmake/libtorch.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# third_party: libtorch use FetchContent_Declare to download, and
# use find_package to find since libtorch is not a standard cmake project.
#
# Selects the CPU-only LibTorch binary package matching the host OS; on
# Linux the CXX11_ABI option (declared in CMakeLists.txt) picks between the
# old- and new-ABI builds so it can match how the rest of the toolchain was
# compiled.
set(PYTORCH_VERSION "1.10.0")

# NOTE: CMAKE_SYSTEM_NAME is referenced unquoted and undereferenced here —
# the previous `if(${CMAKE_SYSTEM_NAME} ...)` form breaks if the variable
# is ever empty and double-dereferences otherwise.
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip")
  # set(URL_HASH "SHA256=d7043b7d7bdb5463e5027c896ac21b83257c32c533427d4d0d7b251548db8f4b")
  set(CMAKE_BUILD_TYPE "Release")
elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  if(CXX11_ABI)
    set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip")
    # set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426")
  else()
    set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip")
    # set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426")
  endif()
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-${PYTORCH_VERSION}.zip")
  # set(URL_HASH "SHA256=07cac2c36c34f13065cb9559ad5270109ecbb468252fb0aeccfd89322322a2b5")
else()
  message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
endif()

FetchContent_Declare(libtorch
  URL ${LIBTORCH_URL}
  # URL_HASH ${URL_HASH}
)

FetchContent_MakeAvailable(libtorch)
# LibTorch ships its own TorchConfig.cmake; NO_DEFAULT_PATH keeps CMake from
# picking up some other system-installed Torch.
find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
# -DC10_USE_GLOG routes LibTorch's internal logging through glog.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
17 changes: 17 additions & 0 deletions runtime/cmake/yaml.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# third_party: yaml
# Fetches yaml-cpp from source; FetchContent_MakeAvailable defines the
# `yaml-cpp` target that CMakeLists.txt links against.  The previous
# commented-out FetchContent_Populate/add_subdirectory variant was dead code
# and has been removed.

# Don't build yaml-cpp's own test suite.
set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "build test")
FetchContent_Declare(yaml
URL https://github.com/jbeder/yaml-cpp/archive/refs/tags/yaml-cpp-0.6.3.zip
URL_HASH SHA256=7c0ddc08a99655508ae110ba48726c67e4a10b290c214aed866ce4bbcbe3e84c
)
FetchContent_MakeAvailable(yaml)
include_directories(${yaml_SOURCE_DIR}/include ${yaml_BINARY_DIR})
Loading

0 comments on commit ea4db02

Please sign in to comment.