
Merge pull request #9 from ros-ai/dev-1.2.0
whisper.cpp 1.5.0
mhubii authored Nov 20, 2023
2 parents 30f3e78 + 0d50e53 commit c9d27b0
Showing 19 changed files with 65 additions and 29 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -1,7 +1,12 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Changelog for package ROS 2 Whisper
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+1.2.0 (2023-11-19)
+------------------
+* `whisper_util`: Upgrade to `whisper.cpp` 1.5.0 release https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0 (full CUDA backend)

1.1.0 (2023-09-01)
------------------
* `whisper_demos`: Improved terminal output
* `whisper_server`: Improved state machine

2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ colcon build --symlink-install --cmake-args -DWHISPER_CUBLAS=On
## Demos
Run the inference nodes (this will download models to `$HOME/.cache/whisper.cpp`):
```shell
-ros2 launch whisper_bringup bringup.launch.py n_thread:=8
+ros2 launch whisper_bringup bringup.launch.py n_thread:=4
```
Run a client node (activated on space bar press):
```shell
2 changes: 1 addition & 1 deletion audio_listener/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>audio_listener</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>Audio common replica.</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
2 changes: 1 addition & 1 deletion audio_listener/setup.py
@@ -4,7 +4,7 @@

setup(
name=package_name,
version="1.1.0",
version="1.2.0",
packages=find_packages(exclude=["test"]),
data_files=[
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
2 changes: 2 additions & 0 deletions whisper_bringup/launch/bringup.launch.py
@@ -19,6 +19,7 @@ def generate_launch_description() -> LaunchDescription:
ld.add_action(WhisperServerMixin.arg_model_name())
ld.add_action(WhisperServerMixin.arg_n_threads())
ld.add_action(WhisperServerMixin.arg_language())
+ld.add_action(WhisperServerMixin.arg_use_gpu())
ld.add_action(WhisperServerMixin.arg_batch_capacity())
ld.add_action(WhisperServerMixin.arg_buffer_capacity())
ld.add_action(WhisperServerMixin.arg_carry_over_capacity())
@@ -30,6 +31,7 @@
WhisperServerMixin.param_model_name(),
WhisperServerMixin.param_n_threads(),
WhisperServerMixin.param_language(),
+WhisperServerMixin.param_use_gpu(),
WhisperServerMixin.param_batch_capacity(),
WhisperServerMixin.param_buffer_capacity(),
WhisperServerMixin.param_carry_over_capacity(),
2 changes: 1 addition & 1 deletion whisper_bringup/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>whisper_bringup</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>TODO: Package description</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
15 changes: 13 additions & 2 deletions whisper_cpp_vendor/CMakeLists.txt
@@ -10,8 +10,8 @@ include(FetchContent)
find_package(ament_cmake REQUIRED)

set(WHISPER_VERSION_MAJOR 1 CACHE STRING "Major whisper.cpp version.")
-set(WHISPER_VERSION_MINOR 4 CACHE STRING "Minor whisper.cpp version.")
-set(WHISPER_VERSION_PATCH 2 CACHE STRING "Patch whisper.cpp version.")
+set(WHISPER_VERSION_MINOR 5 CACHE STRING "Minor whisper.cpp version.")
+set(WHISPER_VERSION_PATCH 0 CACHE STRING "Patch whisper.cpp version.")

FetchContent_Declare(
whisper
@@ -21,12 +21,23 @@ FetchContent_Declare(

FetchContent_MakeAvailable(whisper)

+#######################################################################
+# note that target properties need change as whisper.cpp CMake is buggy
+#######################################################################
set_target_properties(
whisper PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES
$<INSTALL_INTERFACE:.>
)

+# install ggml header
+install(
+FILES ${whisper_SOURCE_DIR}/ggml.h
+DESTINATION include
+)
+##############
+# end of fixes
+##############

ament_export_targets(export_whisper HAS_LIBRARY_TARGET)

2 changes: 1 addition & 1 deletion whisper_cpp_vendor/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>whisper_cpp_vendor</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>Vendor package for whisper.cpp.</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
2 changes: 1 addition & 1 deletion whisper_demos/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>whisper_demos</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>Demos for using the ROS 2 whisper package.</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
2 changes: 1 addition & 1 deletion whisper_demos/setup.py
@@ -4,7 +4,7 @@

setup(
name=package_name,
version="1.1.0",
version="1.2.0",
packages=find_packages(exclude=["test"]),
data_files=[
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
2 changes: 1 addition & 1 deletion whisper_msgs/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>whisper_msgs</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>Messages for the ROS 2 whisper package</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
3 changes: 2 additions & 1 deletion whisper_server/config/whisper.yaml
@@ -1,10 +1,11 @@
whisper:
ros__parameters:
# whisper
model_name: "tiny.en" # other models https://huggingface.co/ggerganov/whisper.cpp
model_name: "base.en" # other models https://huggingface.co/ggerganov/whisper.cpp
language: "en"
n_threads: 4
print_progress: false
+use_gpu: true

# buffer
batch_capacity: 6 # seconds
2 changes: 1 addition & 1 deletion whisper_server/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>whisper_server</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>ROS 2 whisper.cpp inference server.</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
17 changes: 10 additions & 7 deletions whisper_server/src/inference_node.cpp
@@ -42,12 +42,13 @@ void InferenceNode::declare_parameters_() {
node_ptr_->declare_parameter("carry_over_capacity", 200);

// whisper parameters
node_ptr_->declare_parameter("model_name", "tiny.en");
node_ptr_->declare_parameter("model_name", "base.en");
// consider other parameters:
// https://github.com/ggerganov/whisper.cpp/blob/a4bb2df36aeb4e6cfb0c1ca9fbcf749ef39cc852/whisper.h#L351
node_ptr_->declare_parameter("language", "en");
node_ptr_->declare_parameter("n_threads", 4);
node_ptr_->declare_parameter("print_progress", false);
node_ptr_->declare_parameter("use_gpu", true);
}

void InferenceNode::initialize_whisper_() {
@@ -71,19 +72,20 @@ void InferenceNode::initialize_whisper_() {
RCLCPP_INFO(node_ptr_->get_logger(), "Model %s initialized.", model_name.c_str());

language_ = node_ptr_->get_parameter("language").as_string();
-whisper_->params.language = language_.c_str();
-whisper_->params.n_threads = node_ptr_->get_parameter("n_threads").as_int();
-whisper_->params.print_progress = node_ptr_->get_parameter("print_progress").as_bool();
+whisper_->wparams.language = language_.c_str();
+whisper_->wparams.n_threads = node_ptr_->get_parameter("n_threads").as_int();
+whisper_->wparams.print_progress = node_ptr_->get_parameter("print_progress").as_bool();
+whisper_->cparams.use_gpu = node_ptr_->get_parameter("use_gpu").as_bool();
}

rcl_interfaces::msg::SetParametersResult
InferenceNode::on_parameter_set_(const std::vector<rclcpp::Parameter> &parameters) {
rcl_interfaces::msg::SetParametersResult result;
for (const auto &parameter : parameters) {
if (parameter.get_name() == "n_threads") {
-whisper_->params.n_threads = parameter.as_int();
+whisper_->wparams.n_threads = parameter.as_int();
RCLCPP_INFO(node_ptr_->get_logger(), "Parameter %s set to %d.", parameter.get_name().c_str(),
-whisper_->params.n_threads);
+whisper_->wparams.n_threads);
continue;
}
result.reason = "Parameter " + parameter.get_name() + " not handled.";
@@ -143,7 +145,8 @@ void InferenceNode::on_inference_accepted_(const std::shared_ptr<GoalHandleInfer
goal_handle->publish_feedback(feedback);

// update inference result
-if (result->transcriptions.size() == batched_buffer_->batch_idx() + 1) {
+if (result->transcriptions.size() ==
+static_cast<std::size_t>(batched_buffer_->batch_idx() + 1)) {
result->transcriptions[result->transcriptions.size() - 1] = feedback->transcription;
} else {
result->transcriptions.push_back(feedback->transcription);
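
Alongside the `params` → `wparams` rename, the node now forwards the new `use_gpu` ROS parameter into whisper.cpp's context parameters (`cparams`) rather than the per-inference parameters. The sketch below illustrates that flow outside the full node; the node name, the explicit `whisper_context_default_params()` call, and the model path are illustrative assumptions, not code from this commit.

```cpp
// Minimal sketch of the use_gpu parameter flow, assuming whisper.cpp >= 1.5.0
// and the whisper_util Whisper wrapper shown in this diff.
#include <rclcpp/rclcpp.hpp>
#include "whisper_util/whisper.hpp"

int main(int argc, char **argv) {
  rclcpp::init(argc, argv);
  auto node = rclcpp::Node::make_shared("use_gpu_sketch");  // hypothetical node name

  node->declare_parameter("use_gpu", true);
  node->declare_parameter("n_threads", 4);

  whisper::Whisper whisper;  // default constructor fills wparams with greedy defaults
  // Assumption: cparams is default-initialized here; the diff does not show
  // where the node obtains its context-parameter defaults.
  whisper.cparams = whisper_context_default_params();
  whisper.cparams.use_gpu = node->get_parameter("use_gpu").as_bool();
  whisper.wparams.n_threads = static_cast<int>(node->get_parameter("n_threads").as_int());

  // initialize() now calls whisper_init_from_file_with_params(path, cparams),
  // so the GPU flag takes effect when the context is created.
  whisper.initialize("/path/to/ggml-base.en.bin");  // hypothetical model path

  rclcpp::shutdown();
  return 0;
}
```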
17 changes: 15 additions & 2 deletions whisper_server/whisper_server_launch_mixin/whisper_server_mixin.py
@@ -11,13 +11,14 @@ class InferenceMixin:
def arg_model_name() -> DeclareLaunchArgument:
return DeclareLaunchArgument(
name="model_name",
default_value="tiny.en",
default_value="base.en",
description="Model name for whisper.cpp. Refer to https://huggingface.co/ggerganov/whisper.cpp.",
choices=[
"tiny.en",
"tiny",
"tiny.en",
"base",
"base.en",
"small.en",
"small",
"medium.en",
@@ -44,6 +45,14 @@ def arg_language() -> DeclareLaunchArgument:
choices=["en", "auto"],
)

+@staticmethod
+def arg_use_gpu() -> DeclareLaunchArgument:
+return DeclareLaunchArgument(
+name="use_gpu",
+default_value="true",
+description="Use GPU for inference.",
+)

@staticmethod
def arg_batch_capacity() -> DeclareLaunchArgument:
return DeclareLaunchArgument(
@@ -70,7 +79,7 @@ def arg_carry_over_capacity() -> DeclareLaunchArgument:

@staticmethod
def param_model_name() -> Dict[str, LaunchConfiguration]:
return {"model_name": LaunchConfiguration("model_name", default="tiny.en")}
return {"model_name": LaunchConfiguration("model_name", default="base.en")}

@staticmethod
def param_n_threads() -> Dict[str, LaunchConfiguration]:
@@ -80,6 +89,10 @@ def param_n_threads() -> Dict[str, LaunchConfiguration]:
def param_language() -> Dict[str, LaunchConfiguration]:
return {"language": LaunchConfiguration("language", default="en")}

+@staticmethod
+def param_use_gpu() -> Dict[str, LaunchConfiguration]:
+return {"use_gpu": LaunchConfiguration("use_gpu", default="true")}

@staticmethod
def param_batch_capacity() -> Dict[str, LaunchConfiguration]:
return {"batch_capacity": LaunchConfiguration("batch_capacity", default="6")}
6 changes: 3 additions & 3 deletions whisper_util/include/whisper_util/model_manager.hpp
@@ -15,9 +15,9 @@ class ModelManager {
const std::string &cache_path = std::string(std::getenv("HOME")) +
"/.cache/whisper.cpp");
void mkdir(const std::string &path);
-bool is_available(const std::string &model_name = "tiny.en");
-int make_available(const std::string &model_name = "tiny.en");
-std::string get_model_path(const std::string &model_name = "tiny.en");
+bool is_available(const std::string &model_name = "base.en");
+int make_available(const std::string &model_name = "base.en");
+std::string get_model_path(const std::string &model_name = "base.en");

protected:
std::string model_name_to_file_name_(const std::string &model_name);
3 changes: 2 additions & 1 deletion whisper_util/include/whisper_util/whisper.hpp
@@ -19,7 +19,8 @@ class Whisper {
std::vector<whisper_token> tokens();

whisper_context *ctx;
-whisper_full_params params;
+whisper_full_params wparams;
+whisper_context_params cparams;
};
} // end of namespace whisper
#endif // WHISPER_UTIL__WHISPER_HPP_
2 changes: 1 addition & 1 deletion whisper_util/package.xml
@@ -2,7 +2,7 @@
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>whisper_util</name>
-<version>1.1.0</version>
+<version>1.2.0</version>
<description>ROS 2 wrapper for whisper.cpp.</description>
<maintainer email="[email protected]">mhubii</maintainer>
<license>MIT</license>
6 changes: 3 additions & 3 deletions whisper_util/src/whisper.cpp
@@ -1,18 +1,18 @@
#include "whisper_util/whisper.hpp"

namespace whisper {
-Whisper::Whisper() { params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); }
+Whisper::Whisper() { wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); }

Whisper::Whisper(const std::string &model_path) { initialize(model_path); }

Whisper::~Whisper() { whisper_free(ctx); }

void Whisper::initialize(const std::string &model_path) {
-ctx = whisper_init_from_file(model_path.c_str());
+ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams);
}

std::string Whisper::forward(const std::vector<float> &input) {
-if (whisper_full(ctx, params, input.data(), input.size()) != 0) {
+if (whisper_full(ctx, wparams, input.data(), input.size()) != 0) {
return {};
}
std::vector<std::string> segments;
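
For reference, the underlying whisper.cpp API change that drives this rename: 1.4.x created a context with `whisper_init_from_file()`, while 1.5.0 takes a `whisper_context_params` struct carrying the new `use_gpu` flag, kept separate from the per-inference `whisper_full_params`. A minimal sketch against the raw C API follows; the model path and the silence input buffer are placeholders, not taken from this commit.

```cpp
// Sketch against whisper.cpp 1.5.0 directly: context params vs. full params.
#include <whisper.h>

#include <cstdio>
#include <vector>

int main() {
  // Context-level parameters (new in 1.5.0): control how the model is loaded.
  whisper_context_params cparams = whisper_context_default_params();
  cparams.use_gpu = true;  // runs on CPU if no GPU backend was compiled in

  whisper_context *ctx =
      whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams);  // placeholder path
  if (ctx == nullptr) {
    return 1;
  }

  // Per-inference parameters: unchanged API, now held as `wparams` in the wrapper.
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
  wparams.language = "en";
  wparams.n_threads = 4;

  std::vector<float> pcm(16000, 0.0f);  // one second of silence at 16 kHz, placeholder input
  if (whisper_full(ctx, wparams, pcm.data(), static_cast<int>(pcm.size())) == 0) {
    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
      std::printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }
  }

  whisper_free(ctx);
  return 0;
}
```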