-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Delete egs_new.py * Delete lr_scheduler_new.py * Delete trainer_new.py * Update trainer_online.py * Update runEcapaXvector_online.py * Update plda_base.py * Update plda_base.py * Update scoreSets.sh * Update extract_embeddings_new.py * Update README.md * Update README.md * Update components.py * Update make_voxceleb1_v2.pl * Update make_voxceleb1_v2.pl * Update make_voxceleb2.pl * runtime
- Loading branch information
Showing
73 changed files
with
4,624 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

project(subtools_jit VERSION 0.1 LANGUAGES CXX)

set(CMAKE_VERBOSE_MAKEFILE on)
option(CXX11_ABI "whether to use CXX11_ABI libtorch" OFF)

include(FetchContent)
include(ExternalProject)
set(FETCHCONTENT_QUIET off)
# All fetched third-party sources and builds land under <source>/fc_base.
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

# C++14 is what libtorch 1.10 requires.  Express the standard via
# CMAKE_CXX_STANDARD instead of baking -std=c++14 into CMAKE_CXX_FLAGS
# (portable across compilers, and does not clobber user-supplied flags).
# -fPIC so the static archives below can be linked into shared objects;
# -pthread for the threaded feature pipeline.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")

include_directories(
  ${CMAKE_CURRENT_SOURCE_DIR}
)

# Third-party dependencies; each include() pulls cmake/<name>.cmake.
include(libtorch)
include(yaml)
# include(gtest)
include(gflags)
include(glog)

# kaldifeat provides the kaldifeat_core target consumed by frontend.
add_subdirectory(kaldifeat/csrc)

# utils: string/option parsing and misc helpers.
add_library(utils STATIC
  utils/utils.cc utils/options.cc utils/string.cc
)
target_link_libraries(utils PUBLIC gflags glog yaml-cpp)

# frontend: wav reading and kaldi-style feature extraction pipeline.
add_library(frontend STATIC
  frontend/feature_pipeline.cc frontend/features.cc
)
target_link_libraries(frontend PUBLIC utils kaldifeat_core)

# extractor: runs the TorchScript ASV model to produce x-vectors.
add_library(extractor STATIC
  extractor/torch_asv_extractor.cc
  extractor/torch_asv_model.cc
)
target_link_libraries(extractor PUBLIC ${TORCH_LIBRARIES} utils)

# binary
add_executable(extractor_main bin/extractor_main.cc)
target_link_libraries(extractor_main PUBLIC utils frontend extractor)
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# ASV Runtime | ||
This runtime is based on LibTorch. Models trained with PyTorch can be converted to TorchScript by torch JIT, then employed in C++ applications. | ||
* Feature extraction: [Kaldifeat](https://github.com/csukuangfj/kaldifeat) | ||
* Inference: LibTorch | ||
|
||
<div align='center'> | ||
<img src="../doc/runtime_deploy.png" width=50% height=50%/> | ||
</div> | ||
|
||
### Local build | ||
|
||
Build extractor_main. Requires CMake > 3.14 and GCC > 5. | ||
|
||
```bash | ||
mkdir build && cd build && cmake .. && cmake --build . | ||
``` | ||
|
||
``` | ||
# project | ||
. | ||
|-- CMakeLists.txt | ||
|-- README.md | ||
|-- cmake | ||
|-- bin | ||
| `-- extractor_main.cc | ||
|-- build | ||
| |-- extractor_main | ||
| ` | ||
|-- fc_base | ||
|-- test | ||
| |-- wav | ||
| |-- gen_jit.py | ||
| |-- wav.scp | ||
| |-- feat_conf.yaml | ||
| `-- test.sh | ||
` | ||
``` | ||
### Evaluate RTF | ||
|
||
The following script will generate JIT models and extract xvectors of the wavs in `./test/wav` | ||
```bash | ||
cd ./test | ||
python3 gen_jit.py | ||
./test.sh | ||
``` | ||
### Construct your own directory | ||
|
||
1. export your model. | ||
|
||
Go to your project directory which contains subtools. | ||
|
||
```shell | ||
model_dir=exp/resnet34 # directory of your model | ||
epoch=4 # model checkpoint, e.g. epoch=18 | ||
|
||
subtools/pytorch/pipeline/export_jit_model.sh --model-dir $model_dir --model-file $epoch.params \ | ||
--output-file-name $epoch.pt \ | ||
--output_quant_name ${epoch}_quant.pt | ||
``` | ||
2. Back to `test` of this directory, and move the model and feats config to your runtime directory, | ||
``` | ||
# project | ||
. | ||
|-- model | ||
| |--resnet34 | ||
| ` |--4.pt | ||
| |--config | ||
| ` |--feat_conf.yaml | ||
|-- wav ` | ||
|-- gen_jit.py | ||
|-- wav.scp | ||
|-- feat_conf.yaml | ||
|-- test.sh | ||
` | ||
|
||
3. execute. | ||
```bash | ||
./test.sh --model_path ./model/resnet34/4.pt \ | ||
--feat_conf ./model/resnet34/config/feat_conf.yaml | ||
``` | ||
Thanks to the [Wenet](https://github.com/wenet-e2e/wenet/tree/main/runtime) project for their contribution of production works. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
// #include "torch/script.h"
#include "extractor/params.h"
#include "frontend/wav.h"
#include "utils/timer.h"
#include "utils/utils.h"
#include "utils/string.h"

DEFINE_string(wav_path, "", "single wave path");
DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_int32(warmup, 20, "num of warmup decode, 0 means no warmup");

// extractor_main: extracts x-vectors from wav files with a TorchScript ASV
// model and reports per-utterance and aggregate real-time factors (RTF).
// Input is either a single wav (--wav_path) or a kaldi-style wav.scp
// (--wav_scp, lines of "<utt-id> <path>").
int main(int argc, char *argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  google::SetLogDestination(google::GLOG_INFO, "./test.1.log");

  auto feature_config = subtools::InitFeaturePiplineConfigFromFlags();
  auto extractor_config = subtools::InitExtractOptionsFromFlags();
  auto asv_model = subtools::InitTorchAsvModel();
  if (FLAGS_wav_path.empty() && FLAGS_wav_scp.empty()) {
    LOG(FATAL) << "Please provide the wave path or the wav scp.";
  }

  // (utterance-id, wav-path) pairs to process.
  std::vector<std::pair<std::string, std::string>> waves;
  if (!FLAGS_wav_path.empty()) {
    waves.emplace_back("test", FLAGS_wav_path);
  } else {
    std::ifstream wav_scp(FLAGS_wav_scp);
    std::string line;
    while (getline(wav_scp, line)) {
      std::vector<std::string> strs;
      subtools::SplitString(line, &strs);
      CHECK_GE(strs.size(), 2);
      waves.emplace_back(strs[0], strs[1]);
    }
  }

  // Warmup: run the full pipeline on the first wav FLAGS_warmup times so
  // that lazy initialization (JIT optimization, allocator pools) does not
  // pollute the timed measurements below.
  if (FLAGS_warmup > 0) {
    LOG(INFO) << "Warming up...";

    const auto &wav = waves[0];  // avoid copying the pair each iteration
    for (int i = 0; i < FLAGS_warmup; i++) {
      subtools::WavReader wav_reader(wav.second);
      // CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
      auto feature_pipeline =
          std::make_shared<subtools::FeaturePipeline>(*feature_config);
      subtools::Timer timer;
      feature_pipeline->AcceptWaveform(std::vector<float>(
          wav_reader.data(), wav_reader.data() + wav_reader.num_sample()));
      feature_pipeline->set_input_finished();
      if (i == 0) {
        LOG(INFO) << "make features for " << feature_pipeline->num_frames()
                  << " frames takes " << timer.Elapsed() << "ms.";
      }

      subtools::TorchAsvExtractor extractor(feature_pipeline, asv_model,
                                            *extractor_config);

      // Audio duration in milliseconds.
      int wave_dur =
          static_cast<int>(static_cast<float>(wav_reader.num_sample()) /
                           wav_reader.sample_rate() * 1000);

      timer.Reset();
      extractor.Extract();
      int extract_time = timer.Elapsed();
      // RTF = processing time / audio duration; it is a unitless ratio
      // (the original log wrongly appended "ms.").
      if (wave_dur > 0) {
        LOG(INFO) << " Warmup RTF "
                  << static_cast<float>(extract_time) / wave_dur;
      }
      LOG(INFO) << " Warmup num " << i + 1 << " Done! ";
    }

    LOG(INFO) << "Warmup done.";
  }

  int total_waves_dur = 0;      // total audio duration in ms
  int total_extract_time = 0;   // total extraction time in ms
  for (const auto &wav : waves) {
    LOG(INFO) << wav.first << " Start! ";
    subtools::WavReader wav_reader(wav.second);
    // CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
    auto feature_pipeline =
        std::make_shared<subtools::FeaturePipeline>(*feature_config);
    subtools::Timer timer;
    feature_pipeline->AcceptWaveform(std::vector<float>(
        wav_reader.data(), wav_reader.data() + wav_reader.num_sample()));
    feature_pipeline->set_input_finished();

    LOG(INFO) << "num frames " << feature_pipeline->num_frames();
    LOG(INFO) << "make features for " << feature_pipeline->num_frames()
              << " frames takes " << timer.Elapsed() << "ms.";

    subtools::TorchAsvExtractor extractor(feature_pipeline, asv_model,
                                          *extractor_config);

    int wave_dur =
        static_cast<int>(static_cast<float>(wav_reader.num_sample()) /
                         wav_reader.sample_rate() * 1000);

    timer.Reset();
    extractor.Extract();
    int extract_time = timer.Elapsed();

    // Dump the extracted x-vector to stdout.
    torch::Tensor result = extractor.result();
    std::cout << result;
    LOG(INFO) << "extracted xvector of " << wave_dur << "ms audio taken "
              << extract_time << "ms.";
    // Guard against divide-by-zero for degenerate (sub-millisecond) wavs.
    if (wave_dur > 0) {
      LOG(INFO) << "RTF: " << static_cast<float>(extract_time) / wave_dur;
    }
    LOG(INFO) << wav.first << " Done! ";
    total_waves_dur += wave_dur;
    total_extract_time += extract_time;
  }
  LOG(INFO) << "Total: processed " << total_waves_dur << "ms audio taken "
            << total_extract_time << "ms.";
  if (total_waves_dur > 0) {
    LOG(INFO) << "RTF: " << std::setprecision(4)
              << static_cast<float>(total_extract_time) / total_waves_dur;
  }
  google::ShutdownGoogleLogging();
  return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# third_party: boost (header-only usage; only the include dir is wired up).
#
# The original hardcoded a machine-specific absolute path
# (/work/kaldi/.../boost_1_75_0.tar.gz), which breaks on every other host.
# Default to the official release archive; BOOST_URL may be pre-set (e.g.
# to a local tarball) for offline builds.  The SHA256 pin guarantees the
# same content either way.
if(NOT DEFINED BOOST_URL)
  set(BOOST_URL "https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz")
endif()
FetchContent_Declare(boost
  URL ${BOOST_URL}
  URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
)
FetchContent_MakeAvailable(boost)
include_directories(${boost_SOURCE_DIR})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# third_party: gflags — command-line flag parsing, pinned to the v2.2.1
# release archive and integrity-checked via SHA256.
FetchContent_Declare(gflags
  URL https://github.com/gflags/gflags/archive/v2.2.1.zip
  URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
)
FetchContent_MakeAvailable(gflags)
# gflags generates its public headers into the build tree, so the binary
# dir (not the source dir) is what must be on the include path.
include_directories(${gflags_BINARY_DIR}/include)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# third_party: glog — logging library, pinned to the v0.4.0 release
# archive and integrity-checked via SHA256.
FetchContent_Declare(glog
  URL https://github.com/google/glog/archive/v0.4.0.zip
  URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
# Headers are split between the source tree (src/) and the configured
# headers generated into the build tree.
include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
|
||
# third_party: googletest — pinned to the release-1.10.0 archive and
# integrity-checked via SHA256.
FetchContent_Declare(googletest
  URL https://github.com/google/googletest/archive/release-1.10.0.zip
  URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
)

# On MSVC, force gtest onto the dynamic CRT so it matches the rest of the
# build.  This cache entry must be set BEFORE FetchContent_MakeAvailable,
# which is what actually configures the gtest subproject.
if(MSVC)
  set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE)
endif()
FetchContent_MakeAvailable(googletest)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# third_party: libtorch — use FetchContent_Declare to download the prebuilt
# archive, then find_package to locate it, since libtorch is not a standard
# CMake project.  The platform branches below pick the matching CPU build.
# NOTE: variable expansions in if(... STREQUAL ...) are quoted; unquoted
# ${CMAKE_SYSTEM_NAME} would be re-dereferenced if its value happened to
# name another variable (classic CMP0054 footgun).
set(PYTORCH_VERSION "1.10.0")
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows")
  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip")
  # set(URL_HASH "SHA256=d7043b7d7bdb5463e5027c896ac21b83257c32c533427d4d0d7b251548db8f4b")
  # The Windows libtorch archive ships Release binaries only.
  set(CMAKE_BUILD_TYPE "Release")
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
  # Two Linux ABIs are published; CXX11_ABI is the top-level option.
  if(CXX11_ABI)
    set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip")
    # set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426")
  else()
    set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip")
    # set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426")
  endif()
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-${PYTORCH_VERSION}.zip")
  # set(URL_HASH "SHA256=07cac2c36c34f13065cb9559ad5270109ecbb468252fb0aeccfd89322322a2b5")
else()
  message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
endif()

FetchContent_Declare(libtorch
  URL ${LIBTORCH_URL}
  # URL_HASH ${URL_HASH}
)

FetchContent_MakeAvailable(libtorch)
# Restrict the search to the downloaded tree so a system-installed torch
# cannot shadow the pinned version.
find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
# TORCH_CXX_FLAGS carries the ABI define; -DC10_USE_GLOG routes c10
# logging through glog.
string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# third_party: yaml-cpp — YAML parsing (used for the feature config),
# pinned to the yaml-cpp-0.6.3 release archive with a SHA256 check.
# Building yaml-cpp's own tests is disabled up front.
set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "build test")
FetchContent_Declare(yaml
  URL https://github.com/jbeder/yaml-cpp/archive/refs/tags/yaml-cpp-0.6.3.zip
  URL_HASH SHA256=7c0ddc08a99655508ae110ba48726c67e4a10b290c214aed866ce4bbcbe3e84c
)
FetchContent_MakeAvailable(yaml)
# Public headers live in the source tree; any generated headers land in
# the build tree, so both go on the include path.
include_directories(${yaml_SOURCE_DIR}/include ${yaml_BINARY_DIR})
Oops, something went wrong.