Add request_stream to InstrumentedModel (#922)
alexmojaki authored Feb 18, 2025
1 parent 8fcf8c9 commit 8d5b47a
Showing 3 changed files with 289 additions and 10 deletions.
46 changes: 39 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal
@@ -20,7 +22,7 @@
)
from ..settings import ModelSettings
from ..usage import Usage
from . import ModelRequestParameters
from . import ModelRequestParameters, StreamedResponse
from .wrapper import WrapperModel

MODEL_SETTING_ATTRIBUTES: tuple[
@@ -60,6 +62,35 @@ async def request(
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> tuple[ModelResponse, Usage]:
with self._instrument(messages, model_settings) as finish:
response, usage = await super().request(messages, model_settings, model_request_parameters)
finish(response, usage)
return response, usage

@asynccontextmanager
async def request_stream(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[StreamedResponse]:
with self._instrument(messages, model_settings) as finish:
response_stream: StreamedResponse | None = None
try:
async with super().request_stream(
messages, model_settings, model_request_parameters
) as response_stream:
yield response_stream
finally:
if response_stream:
finish(response_stream.get(), response_stream.usage())

@contextmanager
def _instrument(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
):
operation = 'chat'
model_name = self.model_name
span_name = f'{operation} {model_name}'
@@ -95,17 +126,18 @@ async def request(
for body in _response_bodies(message):
emit_event('gen_ai.assistant.message', body)

response, usage = await super().request(messages, model_settings, model_request_parameters)
def finish(response: ModelResponse, usage: Usage):
if not span.is_recording():
return

if span.is_recording():
for body in _response_bodies(response):
if body:
for response_body in _response_bodies(response):
if response_body:
emit_event(
'gen_ai.choice',
{
# TODO finish_reason
'index': 0,
'message': body,
'message': response_body,
},
)
span.set_attributes(
Expand All @@ -122,7 +154,7 @@ async def request(
}
)

return response, usage
yield finish

def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
self.logfire_instance.info(event_name, **{'gen_ai.system': system}, **body)
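
For orientation, here is a minimal consumer-side sketch of the new method, mirroring the tests below. `SomeModel` is a placeholder for any concrete `Model` implementation (the tests use a stub called `MyModel`); it is not part of this diff, and a configured logfire instance is assumed.

    import asyncio

    from pydantic_ai.messages import ModelMessage, ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.instrumented import InstrumentedModel

    async def main() -> None:
        model = InstrumentedModel(SomeModel())  # SomeModel: any concrete Model (placeholder)
        messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart('user_prompt')])]
        async with model.request_stream(
            messages,
            model_settings=None,
            model_request_parameters=ModelRequestParameters(
                function_tools=[], allow_text_result=True, result_tools=[]
            ),
        ) as response_stream:
            async for event in response_stream:
                print(event)  # PartStartEvent / PartDeltaEvent, as in the tests below
        # On exit, the finish() callback from _instrument() records the streamed
        # response and usage on the 'chat {model_name}' span, even if iteration
        # stopped early.

    asyncio.run(main())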
17 changes: 15 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/wrapper.py
@@ -1,11 +1,14 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any

from ..messages import ModelResponse
from ..messages import ModelMessage, ModelResponse
from ..settings import ModelSettings
from ..usage import Usage
from . import Model
from . import Model, ModelRequestParameters, StreamedResponse


@dataclass
@@ -17,6 +20,16 @@ class WrapperModel(Model):
async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
return await self.wrapped.request(*args, **kwargs)

@asynccontextmanager
async def request_stream(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[StreamedResponse]:
async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
yield response_stream

@property
def model_name(self) -> str:
return self.wrapped.model_name
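
Unlike `request`, which delegates with plain `*args`/`**kwargs`, `request_stream` is an async context manager, so the wrapper must re-enter the wrapped context and re-yield the stream. Subclasses such as `InstrumentedModel` above override it to wrap work around that delegation; a minimal sketch of the pattern (`NoisyModel` and its prints are illustrative only, not part of this diff):

    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager

    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models import ModelRequestParameters, StreamedResponse
    from pydantic_ai.models.wrapper import WrapperModel
    from pydantic_ai.settings import ModelSettings

    class NoisyModel(WrapperModel):  # hypothetical subclass, for illustration
        @asynccontextmanager
        async def request_stream(
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> AsyncIterator[StreamedResponse]:
            print('opening stream')  # work before the wrapped stream is opened
            async with super().request_stream(
                messages, model_settings, model_request_parameters
            ) as response_stream:
                yield response_stream
            # Runs only on a clean exit; cleanup that must also run when the
            # consumer raises belongs in try/finally, as in instrumented.py.
            print('stream closed cleanly')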
236 changes: 235 additions & 1 deletion tests/models/test_instrumented.py
@@ -1,5 +1,9 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime

import pytest
from dirty_equals import IsJson
from inline_snapshot import snapshot
@@ -8,14 +12,18 @@
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponseStreamEvent,
PartDeltaEvent,
PartStartEvent,
RetryPromptPart,
SystemPromptPart,
TextPart,
TextPartDelta,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models.instrumented import InstrumentedModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage
@@ -62,6 +70,30 @@ async def request(
Usage(request_tokens=100, response_tokens=200),
)

@asynccontextmanager
async def request_stream(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[StreamedResponse]:
yield MyResponseStream()


class MyResponseStream(StreamedResponse):
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
self._usage = Usage(request_tokens=300, response_tokens=400)
yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1')
yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2')

@property
def model_name(self) -> str:
return 'my_model_123'

@property
def timestamp(self) -> datetime:
return datetime(2022, 1, 1)


@pytest.mark.anyio
async def test_instrumented_model(capfire: CaptureLogfire):
@@ -322,3 +354,205 @@ async def test_instrumented_model_not_recording(capfire: CaptureLogfire):
)

assert capfire.exporter.exported_spans_as_dict() == snapshot([])


@pytest.mark.anyio
async def test_instrumented_model_stream(capfire: CaptureLogfire):
model = InstrumentedModel(MyModel())

messages: list[ModelMessage] = [
ModelRequest(
parts=[
UserPromptPart('user_prompt'),
]
),
]
async with model.request_stream(
messages,
model_settings=ModelSettings(temperature=1),
model_request_parameters=ModelRequestParameters(
function_tools=[],
allow_text_result=True,
result_tools=[],
),
) as response_stream:
assert [event async for event in response_stream] == snapshot(
[
PartStartEvent(index=0, part=TextPart(content='text1')),
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='text2')),
]
)

assert capfire.exporter.exported_spans_as_dict() == snapshot(
[
{
'name': 'gen_ai.user.message',
'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'start_time': 2000000000,
'end_time': 2000000000,
'attributes': {
'logfire.span_type': 'log',
'logfire.level_num': 9,
'logfire.msg_template': 'gen_ai.user.message',
'logfire.msg': 'gen_ai.user.message',
'code.filepath': 'test_instrumented.py',
'code.function': 'test_instrumented_model_stream',
'code.lineno': 123,
'gen_ai.system': 'my_system',
'content': 'user_prompt',
'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"content":{}}}',
},
},
{
'name': 'gen_ai.choice',
'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False},
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'start_time': 3000000000,
'end_time': 3000000000,
'attributes': {
'logfire.span_type': 'log',
'logfire.level_num': 9,
'logfire.msg_template': 'gen_ai.choice',
'logfire.msg': 'gen_ai.choice',
'code.filepath': 'test_instrumented.py',
'code.function': 'test_instrumented_model_stream',
'code.lineno': 123,
'gen_ai.system': 'my_system',
'index': 0,
'message': '{"content":"text1text2"}',
'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"index":{},"message":{"type":"object"}}}',
},
},
{
'name': 'chat my_model',
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'parent': None,
'start_time': 1000000000,
'end_time': 4000000000,
'attributes': {
'code.filepath': 'test_instrumented.py',
'code.function': 'test_instrumented_model_stream',
'code.lineno': 123,
'gen_ai.operation.name': 'chat',
'gen_ai.system': 'my_system',
'gen_ai.request.model': 'my_model',
'gen_ai.request.temperature': 1,
'logfire.msg_template': 'chat my_model',
'logfire.msg': 'chat my_model',
'logfire.span_type': 'span',
'gen_ai.response.model': 'my_model_123',
'gen_ai.usage.input_tokens': 300,
'gen_ai.usage.output_tokens': 400,
'logfire.json_schema': '{"type":"object","properties":{"gen_ai.operation.name":{},"gen_ai.system":{},"gen_ai.request.model":{},"gen_ai.request.temperature":{},"gen_ai.response.model":{},"gen_ai.usage.input_tokens":{},"gen_ai.usage.output_tokens":{}}}',
},
},
]
)


@pytest.mark.anyio
async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
model = InstrumentedModel(MyModel())

messages: list[ModelMessage] = [
ModelRequest(
parts=[
UserPromptPart('user_prompt'),
]
),
]

with pytest.raises(RuntimeError):
async with model.request_stream(
messages,
model_settings=ModelSettings(temperature=1),
model_request_parameters=ModelRequestParameters(
function_tools=[],
allow_text_result=True,
result_tools=[],
),
) as response_stream:
async for event in response_stream:
assert event == PartStartEvent(index=0, part=TextPart(content='text1'))
raise RuntimeError

assert capfire.exporter.exported_spans_as_dict() == snapshot(
[
{
'name': 'gen_ai.user.message',
'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'start_time': 2000000000,
'end_time': 2000000000,
'attributes': {
'logfire.span_type': 'log',
'logfire.level_num': 9,
'logfire.msg_template': 'gen_ai.user.message',
'logfire.msg': 'gen_ai.user.message',
'code.filepath': 'test_instrumented.py',
'code.function': 'test_instrumented_model_stream_break',
'code.lineno': 123,
'gen_ai.system': 'my_system',
'content': 'user_prompt',
'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"content":{}}}',
},
},
{
'name': 'gen_ai.choice',
'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False},
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'start_time': 3000000000,
'end_time': 3000000000,
'attributes': {
'logfire.span_type': 'log',
'logfire.level_num': 9,
'logfire.msg_template': 'gen_ai.choice',
'logfire.msg': 'gen_ai.choice',
'code.filepath': 'test_instrumented.py',
'code.function': 'test_instrumented_model_stream_break',
'code.lineno': 123,
'gen_ai.system': 'my_system',
'index': 0,
'message': '{"content":"text1"}',
'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"index":{},"message":{"type":"object"}}}',
},
},
{
'name': 'chat my_model',
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'parent': None,
'start_time': 1000000000,
'end_time': 5000000000,
'attributes': {
'code.filepath': 'test_instrumented.py',
'code.function': 'test_instrumented_model_stream_break',
'code.lineno': 123,
'gen_ai.operation.name': 'chat',
'gen_ai.system': 'my_system',
'gen_ai.request.model': 'my_model',
'gen_ai.request.temperature': 1,
'logfire.msg_template': 'chat my_model',
'logfire.msg': 'chat my_model',
'logfire.span_type': 'span',
'gen_ai.response.model': 'my_model_123',
'gen_ai.usage.input_tokens': 300,
'gen_ai.usage.output_tokens': 400,
'logfire.level_num': 17,
'logfire.json_schema': '{"type":"object","properties":{"gen_ai.operation.name":{},"gen_ai.system":{},"gen_ai.request.model":{},"gen_ai.request.temperature":{},"gen_ai.response.model":{},"gen_ai.usage.input_tokens":{},"gen_ai.usage.output_tokens":{}}}',
},
'events': [
{
'name': 'exception',
'timestamp': 4000000000,
'attributes': {
'exception.type': 'RuntimeError',
'exception.message': '',
'exception.stacktrace': 'RuntimeError',
'exception.escaped': 'True',
},
}
],
},
]
)
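
`test_instrumented_model_stream_break` pins down the key property of the `try`/`finally` in `InstrumentedModel.request_stream`: when the consumer raises mid-stream, `finish()` still records the partial response (`"text1"` only) and the usage before the span closes, and the span carries the exception event. The same pattern in isolation, as a self-contained sketch independent of pydantic-ai:

    import asyncio
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def instrumented_stream():
        parts: list[str] = []
        try:
            yield parts  # the consumer may raise while iterating
        finally:
            # Runs on clean exit *and* on consumer error, mirroring how
            # finish() flushes whatever was streamed before the break.
            print('finish:', ''.join(parts))

    async def main() -> None:
        try:
            async with instrumented_stream() as parts:
                parts.append('text1')
                raise RuntimeError('consumer broke out early')
        except RuntimeError:
            pass  # 'finish: text1' was still printed

    asyncio.run(main())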
