Skip to content

Commit

Permalink
Merge pull request #14 from ros-ai/dev-buffer-fix
Browse files Browse the repository at this point in the history
Fix executor
  • Loading branch information
mhubii authored Jul 1, 2024
2 parents db5543c + 373a549 commit d1741e6
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Changelog for package ROS 2 Whisper
1.3.1 (2024-07-01)
------------------
* `whisper_msgs`: Changed to `whisper_idl` package
* `whisper_bringup`: Changed executor to `MultiThreadedExecutor` so audio and inference can run in parallel on `whisper_server`

1.3.0 (2024-06-21)
------------------
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,3 @@ Action server under topic `inference` of type [Inference.action](whisper_idl/act

## Troubleshoot
- Encoder inference time: https://github.com/ggerganov/whisper.cpp/issues/10#issuecomment-1302462960
- Compile with GPU support (might differ between platforms): https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas WHISPER_CUBLAS=On
26 changes: 16 additions & 10 deletions whisper_bringup/launch/bringup.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from ament_index_python import get_package_share_directory
from launch import LaunchDescription
from launch_ros.actions import Node
from launch_ros.actions import ComposableNodeContainer, Node
from launch_ros.descriptions import ComposableNode


def generate_launch_description() -> LaunchDescription:
Expand All @@ -21,17 +22,22 @@ def generate_launch_description() -> LaunchDescription:
whisper_config = os.path.join(
get_package_share_directory("whisper_server"), "config", "whisper.yaml"
)
composable_node = ComposableNode(
package="whisper_server",
plugin="whisper::InferenceComponent",
name="inference",
namespace="whisper",
parameters=[whisper_config],
remappings=[("audio", "/audio_listener/audio")],
)
ld.add_action(
Node(
package="whisper_server",
executable="whisper",
ComposableNodeContainer(
name="whisper_container",
package="rclcpp_components",
namespace="",
executable="component_container_mt", # require multi-threaded executor so inference server can parallelize audio encueing and inference
output="screen",
namespace="whisper",
parameters=[whisper_config],
remappings=[
("/whisper/audio", "/audio_listener/audio"),
],
composable_node_descriptions=[composable_node],
)
)

return ld
2 changes: 1 addition & 1 deletion whisper_server/config/whisper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
wparams:
language: "en"
print_progress: false
n_threads: 1
n_threads: 4
cparams:
flash_attn: true
gpu_device: 0
Expand Down
1 change: 1 addition & 0 deletions whisper_util/include/whisper_util/audio_buffers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class BatchedBuffer {
std::uint16_t batch_idx_;

std::vector<float> audio_;
std::vector<float> carry_over_audio_;
RingBuffer<std::int16_t> audio_buffer_;
};
} // end of namespace whisper
Expand Down
7 changes: 4 additions & 3 deletions whisper_util/src/audio_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ BatchedBuffer::BatchedBuffer(const std::chrono::milliseconds &batch_capacity,
const std::chrono::milliseconds &carry_over_capacity)
: batch_capacity_(time_to_count(batch_capacity)),
carry_over_capacity_(time_to_count(carry_over_capacity)), batch_idx_(0),
audio_buffer_(time_to_count(buffer_capacity)){
carry_over_audio_(carry_over_capacity_), audio_buffer_(time_to_count(buffer_capacity)) {

};

Expand Down Expand Up @@ -77,9 +77,10 @@ bool BatchedBuffer::require_new_batch_() {
}

void BatchedBuffer::carry_over_() {
std::vector<float> carry_over(audio_.end() - carry_over_capacity_, audio_.end());
carry_over_audio_.insert(carry_over_audio_.begin(), audio_.end() - carry_over_capacity_,
audio_.end());
audio_.clear();
audio_.insert(audio_.end(), carry_over.begin(), carry_over.end());
audio_.insert(audio_.end(), carry_over_audio_.begin(), carry_over_audio_.end());
}

void BatchedBuffer::clear() {
Expand Down

0 comments on commit d1741e6

Please sign in to comment.