Skip to content

Commit

Permalink
Merge pull request #1139 from pipecat-ai/aleix/task-start-metadata
Browse files Browse the repository at this point in the history
pipeline task start metadata and unit test improvements
  • Loading branch information
aconchillo authored Feb 5, 2025
2 parents 3be6990 + a363d12 commit ba31546
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 76 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added a new `start_metadata` field to `PipelineParams`. The provided metadata
will be set to the initial `StartFrame` being pushed from the `PipelineTask`.

- Added new fields to `PipelineParams` to control audio input and output sample
rates for the whole pipeline. This allows controlling sample rates from a
single place instead of having to specify sample rates in each
Expand Down Expand Up @@ -107,6 +110,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Other

- Improved Unit Test `run_test()` to use `PipelineTask` and
`PipelineRunner`. There's now also some control around `StartFrame` and
`EndFrame`. The `EndTaskFrame` has been removed since it doesn't seem
necessary with this new approach.

- Updated `twilio-chatbot` with a few new features: use 8000 sample rate and
avoid resampling, a new client useful for stress testing and testing locally
without the need to make phone calls. Also, added audio recording on both the
Expand Down
4 changes: 2 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
build~=1.2.2
grpcio-tools~=1.69.0
grpcio-tools~=1.67.1
pip-tools~=7.4.1
pre-commit~=4.0.1
pyright~=1.1.392
pytest~=8.3.4
pytest-asyncio~=0.25.2
ruff~=0.9.1
setuptools~=75.8.0
setuptools~=70.0.0
setuptools_scm~=8.1.0
python-dotenv~=1.0.1
19 changes: 15 additions & 4 deletions src/pipecat/frames/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Literal, Mapping, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Tuple,
)

from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.clocks.base_clock import BaseClock
Expand Down Expand Up @@ -48,13 +59,13 @@ class Frame:
id: int = field(init=False)
name: str = field(init=False)
pts: Optional[int] = field(init=False)
metadata: dict = field(init=False)
metadata: Dict[str, Any] = field(init=False)

def __post_init__(self):
self.id: int = obj_id()
self.name: str = f"{self.__class__.__name__}#{obj_count(self)}"
self.pts: Optional[int] = None
self.metadata: dict = {}
self.metadata: Dict[str, Any] = {}

def __str__(self):
return self.name
Expand Down Expand Up @@ -433,8 +444,8 @@ class StartFrame(SystemFrame):
allow_interruptions: bool = False
enable_metrics: bool = False
enable_usage_metrics: bool = False
report_only_initial_ttfb: bool = False
observer: Optional["BaseObserver"] = None
report_only_initial_ttfb: bool = False


@dataclass
Expand Down
18 changes: 10 additions & 8 deletions src/pipecat/pipeline/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#

import asyncio
from typing import AsyncIterable, Iterable, List
from typing import Any, AsyncIterable, Dict, Iterable, List

from loguru import logger
from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -40,16 +40,17 @@
class PipelineParams(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

allow_interruptions: bool = False
audio_in_sample_rate: int = 16000
audio_out_sample_rate: int = 24000
allow_interruptions: bool = False
enable_heartbeats: bool = False
enable_metrics: bool = False
enable_usage_metrics: bool = False
send_initial_empty_metrics: bool = True
report_only_initial_ttfb: bool = False
observers: List[BaseObserver] = []
heartbeats_period_secs: float = HEARTBEAT_SECONDS
observers: List[BaseObserver] = []
report_only_initial_ttfb: bool = False
send_initial_empty_metrics: bool = True
start_metadata: Dict[str, Any] = {}


class PipelineTaskSource(FrameProcessor):
Expand Down Expand Up @@ -278,13 +279,14 @@ async def _process_push_queue(self):
clock=self._clock,
task_manager=self._task_manager,
allow_interruptions=self._params.allow_interruptions,
audio_in_sample_rate=self._params.audio_in_sample_rate,
audio_out_sample_rate=self._params.audio_out_sample_rate,
enable_metrics=self._params.enable_metrics,
enable_usage_metrics=self._params.enable_usage_metrics,
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
observer=self._observer,
audio_in_sample_rate=self._params.audio_in_sample_rate,
audio_out_sample_rate=self._params.audio_out_sample_rate,
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
)
start_frame.metadata = self._params.start_metadata
await self._source.queue_frame(start_frame, FrameDirection.DOWNSTREAM)

if self._params.enable_metrics and self._params.send_initial_empty_metrics:
Expand Down
70 changes: 36 additions & 34 deletions src/pipecat/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,25 @@
#

import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Sequence, Tuple
import sys
from typing import Any, Awaitable, Callable, Dict, Sequence, Tuple

from loguru import logger

from pipecat.clocks.system_clock import SystemClock
from pipecat.frames.frames import (
ControlFrame,
EndFrame,
Frame,
HeartbeatFrame,
StartFrame,
)
from pipecat.observers.base_observer import BaseObserver
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.asyncio import TaskManager


@dataclass
class EndTestFrame(ControlFrame):
pass
logger.remove(0)
logger.add(sys.stderr, level="TRACE")


class HeartbeatsObserver(BaseObserver):
Expand All @@ -48,54 +49,58 @@ async def on_push_frame(


class QueuedFrameProcessor(FrameProcessor):
def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
def __init__(
self, queue: asyncio.Queue, queue_direction: FrameDirection, ignore_start: bool = True
):
super().__init__()
self._queue = queue
self._queue_direction = queue_direction
self._ignore_start = ignore_start

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if self._ignore_start and isinstance(frame, StartFrame):
await self.push_frame(frame, direction)
else:
await self._queue.put(frame)
await self.push_frame(frame, direction)
if direction == self._queue_direction:
if not isinstance(frame, StartFrame) or not self._ignore_start:
await self._queue.put(frame)
await self.push_frame(frame, direction)


async def run_test(
processor: FrameProcessor,
*,
frames_to_send: Sequence[Frame],
expected_down_frames: Sequence[type],
expected_up_frames: Sequence[type] = [],
ignore_start: bool = True,
start_metadata: Dict[str, Any] = {},
send_end_frame: bool = True,
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
received_up = asyncio.Queue()
received_down = asyncio.Queue()
source = QueuedFrameProcessor(received_up)
sink = QueuedFrameProcessor(received_down)
source = QueuedFrameProcessor(received_up, FrameDirection.UPSTREAM, ignore_start)
sink = QueuedFrameProcessor(received_down, FrameDirection.DOWNSTREAM, ignore_start)

source.link(processor)
processor.link(sink)
pipeline = Pipeline([source, processor, sink])

task_manager = TaskManager()
task_manager.set_event_loop(asyncio.get_event_loop())
await source.queue_frame(StartFrame(clock=SystemClock(), task_manager=task_manager))
task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata))

for frame in frames_to_send:
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
await task.queue_frame(frame)

if send_end_frame:
await task.queue_frame(EndFrame())

await processor.queue_frame(EndTestFrame())
await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM)
runner = PipelineRunner()
await runner.run(task)

#
# Down frames
#
received_down_frames: Sequence[Frame] = []
running = True
while running:
while not received_down.empty():
frame = await received_down.get()
running = not isinstance(frame, EndTestFrame)
if running:
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)

print("received DOWN frames =", received_down_frames)
Expand All @@ -109,12 +114,9 @@ async def run_test(
# Up frames
#
received_up_frames: Sequence[Frame] = []
running = True
while running:
while not received_up.empty():
frame = await received_up.get()
running = not isinstance(frame, EndTestFrame)
if running:
received_up_frames.append(frame)
received_up_frames.append(frame)

print("received UP frames =", received_up_frames)

Expand Down
10 changes: 8 additions & 2 deletions tests/test_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ async def test_sentence_aggregator(self):

expected_returned_frames = [TextFrame, TextFrame, TextFrame]

(received_down, _) = await run_test(aggregator, frames_to_send, expected_returned_frames)
(received_down, _) = await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
assert received_down[-3].text == "Hello, world. "
assert received_down[-2].text == "How are you? "
assert received_down[-1].text == "I am fine! "
Expand Down Expand Up @@ -66,5 +70,7 @@ async def test_gated_aggregator(self):
]

(received_down, _) = await run_test(
gated_aggregator, frames_to_send, expected_returned_frames
gated_aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
Loading

0 comments on commit ba31546

Please sign in to comment.