Skip to content

Commit

Permalink
langgraph: expose tags in the metadata for streamed message chunks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Jan 29, 2025
1 parent 82148e9 commit 39e65a1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
4 changes: 4 additions & 0 deletions libs/langgraph/langgraph/pregel/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,15 @@ def on_llm_new_token(
chunk: Optional[ChatGenerationChunk] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> Any:
if not isinstance(chunk, ChatGenerationChunk):
return
if meta := self.metadata.get(run_id):
filtered_tags = [t for t in (tags or []) if not t.startswith("seq:step")]
if filtered_tags:
meta[1]["tags"] = filtered_tags
self._emit(meta, chunk.message)

def on_llm_end(
Expand Down
37 changes: 37 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import httpx
import pytest
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.runnables import (
RunnableConfig,
RunnableLambda,
Expand Down Expand Up @@ -80,6 +81,7 @@
from tests.memory_assert import MemorySaverAssertCheckpointMetadata
from tests.messages import (
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
_AnyIdToolMessage,
)
Expand Down Expand Up @@ -6288,3 +6290,38 @@ def workflow(inputs: dict) -> dict:
{"qux": "foo|bar|baz|custom_baz|qux"},
{"workflow": "foo|bar|baz|custom_baz|qux"},
]


def test_tags_stream_mode_messages() -> None:
    """Verify that stream_mode="messages" surfaces the model's tags in chunk metadata."""
    model = GenericFakeChatModel(messages=iter(["foo"]), tags=["meow"])

    # Node that invokes the tagged model on the accumulated messages.
    def call_model(state):
        return {"messages": model.invoke(state["messages"])}

    builder = StateGraph(MessagesState)
    builder.add_node("call_model", call_model)
    builder.add_edge(START, "call_model")
    graph = builder.compile()

    streamed = list(graph.stream({"messages": "hi"}, stream_mode="messages"))

    expected_metadata = {
        "langgraph_step": 1,
        "langgraph_node": "call_model",
        "langgraph_triggers": ["start:call_model"],
        "langgraph_path": ("__pregel_pull", "call_model"),
        "langgraph_checkpoint_ns": AnyStr("call_model:"),
        "checkpoint_ns": AnyStr("call_model:"),
        "ls_provider": "genericfakechatmodel",
        "ls_model_type": "chat",
        # The model-level tag must be propagated into the streamed metadata.
        "tags": ["meow"],
    }
    assert streamed == [(_AnyIdAIMessageChunk(content="foo"), expected_metadata)]
40 changes: 40 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import httpx
import pytest
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.runnables import (
RunnableConfig,
RunnableLambda,
Expand Down Expand Up @@ -82,6 +83,7 @@
)
from tests.messages import (
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
_AnyIdToolMessage,
)
Expand Down Expand Up @@ -7466,3 +7468,41 @@ async def main(inputs, store: BaseStore) -> str:
return "OK"

assert await main.ainvoke({}) == "OK"


async def test_tags_stream_mode_messages() -> None:
    """Verify that async stream_mode="messages" surfaces the model's tags in chunk metadata."""
    model = GenericFakeChatModel(messages=iter(["foo"]), tags=["meow"])

    # Node that invokes the tagged model; config is forwarded so callbacks propagate.
    async def call_model(state, config):
        return {"messages": await model.ainvoke(state["messages"], config)}

    builder = StateGraph(MessagesState)
    builder.add_node(call_model)
    builder.add_edge(START, "call_model")
    graph = builder.compile()

    streamed = []
    async for chunk in graph.astream({"messages": "hi"}, stream_mode="messages"):
        streamed.append(chunk)

    expected_metadata = {
        "langgraph_step": 1,
        "langgraph_node": "call_model",
        "langgraph_triggers": ["start:call_model"],
        "langgraph_path": ("__pregel_pull", "call_model"),
        "langgraph_checkpoint_ns": AnyStr("call_model:"),
        "checkpoint_ns": AnyStr("call_model:"),
        "ls_provider": "genericfakechatmodel",
        "ls_model_type": "chat",
        # The model-level tag must be propagated into the streamed metadata.
        "tags": ["meow"],
    }
    assert streamed == [(_AnyIdAIMessageChunk(content="foo"), expected_metadata)]

0 comments on commit 39e65a1

Please sign in to comment.