diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 52153e39..4914c45b 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -1,5 +1,7 @@ +import json import re from collections import defaultdict +from operator import itemgetter from typing import ( Any, Callable, @@ -29,12 +31,19 @@ HumanMessage, SystemMessage, ) +from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Extra -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message +from langchain_aws.function_calling import ( + ToolsOutputParser, + _lc_tool_calls_to_anthropic_tool_use_blocks, + _tools_in_params, + convert_to_anthropic_tool, + get_system_message, +) from langchain_aws.llms.bedrock import ( BedrockBase, _combine_generation_info_for_llm_result, @@ -197,23 +206,54 @@ def _format_image(image_url: str) -> Dict: } +def _merge_messages( + messages: Sequence[BaseMessage], +) -> List[Union[SystemMessage, AIMessage, HumanMessage]]: + """Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501 + merged: list = [] + for curr in messages: + curr = curr.copy(deep=True) + if isinstance(curr, ToolMessage): + if isinstance(curr.content, list) and all( + isinstance(block, dict) and block.get("type") == "tool_result" + for block in curr.content + ): + curr = HumanMessage(curr.content) # type: ignore[misc] + else: + curr = HumanMessage( # type: ignore[misc] + [ + { + "type": "tool_result", + "content": curr.content, + "tool_use_id": curr.tool_call_id, + } + ] + ) + last = merged[-1] if merged else None + if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage): + if isinstance(last.content, str): + new_content: List = [{"type": "text", "text": last.content}] + else: + new_content = last.content + if isinstance(curr.content, str): + new_content.append({"type": "text", "text": curr.content}) + else: + new_content.extend(curr.content) + last.content = new_content + else: + merged.append(curr) + return merged + + def _format_anthropic_messages( messages: List[BaseMessage], ) -> Tuple[Optional[str], List[Dict]]: """Format messages for anthropic.""" - - """ - [ - { - "role": _message_type_lookups[m.type], - "content": [_AnthropicMessageContent(text=m.content).dict()], - } - for m in messages - ] - """ system: Optional[str] = None formatted_messages: List[Dict] = [] - for i, message in enumerate(messages): + + merged_messages = _merge_messages(messages) + for i, message in enumerate(merged_messages): if message.type == "system": if i != 0: raise ValueError("System message must be at beginning of message list.") @@ -226,7 +266,7 @@ def _format_anthropic_messages( continue role = _message_type_lookups[message.type] - content: Union[str, List[Dict]] + content: Union[str, List] if not isinstance(message.content, str): # parse as dict @@ -238,39 +278,58 @@ def _format_anthropic_messages( content = [] for item in message.content: if isinstance(item, str): - content.append( - { - "type": "text", - "text": item, - } - ) + content.append({"type": "text", "text": item}) elif isinstance(item, dict): if "type" not in item: raise ValueError("Dict content 
item must have a type key") - if item["type"] == "image_url": + elif item["type"] == "image_url": # convert format source = _format_image(item["image_url"]["url"]) - content.append( - { - "type": "image", - "source": source, - } - ) + content.append({"type": "image", "source": source}) + elif item["type"] == "tool_use": + # If a tool_call with the same id as a tool_use content block + # exists, the tool_call is preferred. + if isinstance(message, AIMessage) and item["id"] in [ + tc["id"] for tc in message.tool_calls + ]: + overlapping = [ + tc + for tc in message.tool_calls + if tc["id"] == item["id"] + ] + content.extend( + _lc_tool_calls_to_anthropic_tool_use_blocks(overlapping) + ) + else: + item.pop("text", None) + content.append(item) + elif item["type"] == "text": + text = item.get("text", "") + # Only add non-empty strings for now as empty ones are not + # accepted. + # https://github.com/anthropics/anthropic-sdk-python/issues/461 + if text.strip(): + content.append({"type": "text", "text": text}) else: content.append(item) else: raise ValueError( f"Content items must be str or dict, instead was: {type(item)}" ) + elif isinstance(message, AIMessage) and message.tool_calls: + content = ( + [] + if not message.content + else [{"type": "text", "text": message.content}] + ) + # Note: Anthropic can't have invalid tool calls as presently defined, + # since the model already returns dicts args not JSON strings, and invalid + # tool calls are those with invalid JSON for args. + content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls) else: content = message.content - formatted_messages.append( - { - "role": role, - "content": content, - } - ) + formatted_messages.append({"role": role, "content": content}) return system, formatted_messages @@ -316,7 +375,12 @@ def format_messages( ) -_message_type_lookups = {"human": "user", "ai": "assistant"} +_message_type_lookups = { + "human": "user", + "ai": "assistant", + "AIMessageChunk": "assistant", + "HumanMessageChunk": "user", +} class ChatBedrock(BaseChatModel, BedrockBase): @@ -363,6 +427,31 @@ def _stream( provider = self._get_provider() prompt, system, formatted_messages = None, None, None + if "claude-3" in self._get_model(): + if _tools_in_params({**kwargs}): + result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + message = result.generations[0].message + if isinstance(message, AIMessage) and message.tool_calls is not None: + tool_call_chunks = [ + { + "name": tool_call["name"], + "args": json.dumps(tool_call["args"]), + "id": tool_call["id"], + "index": idx, + } + for idx, tool_call in enumerate(message.tool_calls) + ] + message_chunk = AIMessageChunk( + content=message.content, + tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + usage_metadata=message.usage_metadata, + ) + yield ChatGenerationChunk(message=message_chunk) + else: + yield cast(ChatGenerationChunk, result.generations[0]) + return if provider == "anthropic": system, formatted_messages = ChatPromptAdapter.format_messages( provider, messages @@ -403,6 +492,7 @@ def _generate( ) -> ChatResult: completion = "" llm_output: Dict[str, Any] = {} + tool_calls: List[Dict[str, Any]] = [] provider_stop_reason_code = self.provider_stop_reason_key_map.get( self._get_provider(), "stop_reason" ) @@ -411,6 +501,8 @@ def _generate( for chunk in self._stream(messages, stop, run_manager, **kwargs): completion += chunk.text response_metadata.append(chunk.message.response_metadata) + if "tool_calls" in 
chunk.message.additional_kwargs.keys(): + tool_calls = chunk.message.additional_kwargs["tool_calls"] llm_output = _combine_generation_info_for_llm_result( response_metadata, provider_stop_reason_code ) @@ -423,6 +515,7 @@ def _generate( system, formatted_messages = ChatPromptAdapter.format_messages( provider, messages ) + # use tools the new way with claude 3 if self.system_prompt_with_tools: if system: system = self.system_prompt_with_tools + f"\n{system}" @@ -436,7 +529,7 @@ def _generate( if stop: params["stop_sequences"] = stop - completion, llm_output = self._prepare_input_and_invoke( + completion, tool_calls, llm_output = self._prepare_input_and_invoke( prompt=prompt, stop=stop, run_manager=run_manager, @@ -446,10 +539,18 @@ def _generate( ) llm_output["model_id"] = self.model_id + if len(tool_calls) > 0: + msg = AIMessage( + content=completion, + additional_kwargs=llm_output, + tool_calls=cast(List[ToolCall], tool_calls), + ) + else: + msg = AIMessage(content=completion, additional_kwargs=llm_output) return ChatResult( generations=[ ChatGeneration( - message=AIMessage(content=completion, additional_kwargs=llm_output) + message=msg, ) ], llm_output=llm_output, @@ -507,14 +608,167 @@ def bind_tools( **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. """ - provider = self._get_provider() - - if provider == "anthropic": + if self._get_provider() == "anthropic": formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] - system_formatted_tools = get_system_message(formatted_tools) - self.set_system_prompt_with_tools(system_formatted_tools) + + # true if the model is a claude 3 model + if "claude-3" in self._get_model(): + if not tool_choice: + pass + elif isinstance(tool_choice, dict): + kwargs["tool_choice"] = tool_choice + elif isinstance(tool_choice, str) and tool_choice in ("any", "auto"): + kwargs["tool_choice"] = {"type": tool_choice} + elif isinstance(tool_choice, str): + kwargs["tool_choice"] = {"type": "tool", "name": tool_choice} + else: + raise ValueError( + f"Unrecognized 'tool_choice' type {tool_choice=}." + f"Expected dict, str, or None." + ) + return self.bind(tools=formatted_tools, **kwargs) + else: + # add tools to the system prompt, the old way + system_formatted_tools = get_system_message(formatted_tools) + self.set_system_prompt_with_tools(system_formatted_tools) return self + def with_structured_output( + self, + schema: Union[Dict, Type[BaseModel]], + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes any ChatModel input. The output type depends on + include_raw and schema. 
+ + If include_raw is True then output is a dict with keys: + raw: BaseMessage, + parsed: Optional[_DictOrPydantic], + parsing_error: Optional[BaseException], + + If include_raw is False and schema is a Dict then the runnable outputs a Dict. + If include_raw is False and schema is a Type[BaseModel] then the runnable + outputs a BaseModel. + + Example: Pydantic schema (include_raw=False): + .. code-block:: python + + from langchain_aws.chat_models.bedrock import ChatBedrock + from langchain_core.pydantic_v1 import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm =ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = llm.with_structured_output(AnswerWithJustification) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + + # -> AnswerWithJustification( + # answer='They weigh the same', + # justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.' + # ) + + Example: Pydantic schema (include_raw=True): + .. code-block:: python + + from langchain_aws.chat_models.bedrock import ChatBedrock + from langchain_core.pydantic_v1 import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm =ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), + # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), + # 'parsing_error': None + # } + + Example: Dict schema (include_raw=False): + .. code-block:: python + + from langchain_aws.chat_models.bedrock import ChatBedrock + + schema = { + "name": "AnswerWithJustification", + "description": "An answer to the user question along with justification for the answer.", + "input_schema": { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "justification": {"type": "string"}, + }, + "required": ["answer", "justification"] + } + } + llm =ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = llm.with_structured_output(schema) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'answer': 'They weigh the same', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' 
+            #     }
+
+        """  # noqa: E501
+        if "claude-3" not in self._get_model():
+            raise ValueError(
+                f"Structured output is not supported for model {self._get_model()}"
+            )
+        llm = self.bind_tools([schema], tool_choice="any")
+        if isinstance(schema, type) and issubclass(schema, BaseModel):
+            output_parser = ToolsOutputParser(
+                first_tool_only=True, pydantic_schemas=[schema]
+            )
+        else:
+            output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)
+
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser
+
 
 @deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
 class BedrockChat(ChatBedrock):
diff --git a/libs/aws/langchain_aws/function_calling.py b/libs/aws/langchain_aws/function_calling.py
index 765332e2..1e2c53e1 100644
--- a/libs/aws/langchain_aws/function_calling.py
+++ b/libs/aws/langchain_aws/function_calling.py
@@ -8,10 +8,16 @@
     Dict,
     List,
     Literal,
+    Optional,
     Type,
     Union,
+    cast,
 )
 
+from langchain_core.messages import ToolCall
+from langchain_core.output_parsers import BaseGenerationOutputParser
+from langchain_core.outputs import ChatGeneration, Generation
+from langchain_core.prompts.chat import AIMessage
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -63,6 +69,35 @@ class AnthropicTool(TypedDict):
     input_schema: Dict[str, Any]
 
 
+def _tools_in_params(params: dict) -> bool:
+    return "tools" in params or (
+        "extra_body" in params and params["extra_body"].get("tools")
+    )
+
+
+class _AnthropicToolUse(TypedDict):
+    type: Literal["tool_use"]
+    name: str
+    input: dict
+    id: str
+
+
+def _lc_tool_calls_to_anthropic_tool_use_blocks(
+    tool_calls: List[ToolCall],
+) -> List[_AnthropicToolUse]:
+    blocks = []
+    for tool_call in tool_calls:
+        blocks.append(
+            _AnthropicToolUse(
+                type="tool_use",
+                name=tool_call["name"],
+                input=tool_call["args"],
+                id=cast(str, tool_call["id"]),
+            )
+        )
+    return blocks
+
+
 def _get_type(parameter: Dict[str, Any]) -> str:
     if "type" in parameter:
         return parameter["type"]
@@ -122,6 +157,54 @@ class ToolDescription(TypedDict):
     function: FunctionDescription
 
 
+class ToolsOutputParser(BaseGenerationOutputParser):
+    first_tool_only: bool = False
+    args_only: bool = False
+    pydantic_schemas: Optional[List[Type[BaseModel]]] = None
+
+    class Config:
+        extra = "forbid"
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        """Parse a list of candidate model Generations into a specific format.
+
+        Args:
+            result: A list of Generations to be parsed. The Generations are assumed
+                to be different candidate outputs for a single model input.
+
+        Returns:
+            Structured output.
+ """ + if not result or not isinstance(result[0], ChatGeneration): + return None if self.first_tool_only else [] + message = result[0].message + if len(message.content) > 0: + tool_calls: List = [] + else: + content = cast(AIMessage, message) + _tool_calls = [dict(tc) for tc in content.tool_calls] + # Map tool call id to index + id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)} + tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls] + if self.pydantic_schemas: + tool_calls = [self._pydantic_parse(tc) for tc in tool_calls] + elif self.args_only: + tool_calls = [tc["args"] for tc in tool_calls] + else: + pass + + if self.first_tool_only: + return tool_calls[0] if tool_calls else None + else: + return [tool_call for tool_call in tool_calls] + + def _pydantic_parse(self, tool_call: dict) -> BaseModel: + cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[ + tool_call["name"] + ] + return cls_(**tool_call["args"]) + + def convert_to_anthropic_tool( tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], ) -> AnthropicTool: diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index c20e3365..24416eac 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -12,6 +12,7 @@ Mapping, Optional, Tuple, + TypedDict, Union, ) @@ -21,10 +22,12 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models import LLM, BaseLanguageModel +from langchain_core.messages import ToolCall from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain_core.utils import get_from_dict_or_env +from langchain_aws.function_calling import _tools_in_params from langchain_aws.utils import ( enforce_stop_tokens, get_num_tokens_anthropic, @@ -81,7 +84,10 @@ def _human_assistant_format(input_text: str) -> str: def _stream_response_to_generation_chunk( - stream_response: Dict[str, Any], provider: str, output_key: str, messages_api: bool + stream_response: Dict[str, Any], + provider: str, + output_key: str, + messages_api: bool, ) -> Union[GenerationChunk, None]: """Convert a stream response to a generation chunk.""" if messages_api: @@ -174,6 +180,23 @@ def _combine_generation_info_for_llm_result( return {"usage": total_usage_info, "stop_reason": stop_reason} +def extract_tool_calls(content: List[dict]) -> List[ToolCall]: + tool_calls = [] + for block in content: + if block["type"] != "tool_use": + continue + tool_calls.append( + ToolCall(name=block["name"], args=block["input"], id=block["id"]) + ) + return tool_calls + + +class AnthropicTool(TypedDict): + name: str + description: str + input_schema: Dict[str, Any] + + class LLMInputOutputAdapter: """Adapter class to prepare the inputs from Langchain to a format that LLM model expects. 
@@ -197,10 +220,13 @@ def prepare_input( prompt: Optional[str] = None, system: Optional[str] = None, messages: Optional[List[Dict]] = None, + tools: Optional[List[AnthropicTool]] = None, ) -> Dict[str, Any]: input_body = {**model_kwargs} if provider == "anthropic": if messages: + if tools: + input_body["tools"] = tools input_body["anthropic_version"] = "bedrock-2023-05-31" input_body["messages"] = messages if system: @@ -225,16 +251,20 @@ def prepare_input( @classmethod def prepare_output(cls, provider: str, response: Any) -> dict: text = "" + tool_calls = [] + response_body = json.loads(response.get("body").read().decode()) + if provider == "anthropic": - response_body = json.loads(response.get("body").read().decode()) if "completion" in response_body: text = response_body.get("completion") elif "content" in response_body: content = response_body.get("content") - text = content[0].get("text") - else: - response_body = json.loads(response.get("body").read()) + if len(content) == 1 and content[0]["type"] == "text": + text = content[0]["text"] + elif any(block["type"] == "tool_use" for block in content): + tool_calls = extract_tool_calls(content) + else: if provider == "ai21": text = response_body.get("completions")[0].get("data").get("text") elif provider == "cohere": @@ -251,6 +281,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict: completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) return { "text": text, + "tool_calls": tool_calls, "body": response_body, "usage": { "prompt_tokens": prompt_tokens, @@ -584,12 +615,15 @@ def _prepare_input_and_invoke( stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> Tuple[str, Dict[str, Any]]: + ) -> Tuple[ + str, + List[dict], + Dict[str, Any], + ]: _model_kwargs = self.model_kwargs or {} provider = self._get_provider() params = {**_model_kwargs, **kwargs} - input_body = LLMInputOutputAdapter.prepare_input( provider=provider, model_kwargs=params, @@ -597,6 +631,16 @@ def _prepare_input_and_invoke( system=system, messages=messages, ) + if "claude-3" in self._get_model(): + if _tools_in_params(params): + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, + model_kwargs=params, + prompt=prompt, + system=system, + messages=messages, + tools=params["tools"], + ) body = json.dumps(input_body) accept = "application/json" contentType = "application/json" @@ -621,9 +665,13 @@ def _prepare_input_and_invoke( try: response = self.client.invoke_model(**request_options) - text, body, usage_info, stop_reason = LLMInputOutputAdapter.prepare_output( - provider, response - ).values() + ( + text, + tool_calls, + body, + usage_info, + stop_reason, + ) = LLMInputOutputAdapter.prepare_output(provider, response).values() except Exception as e: raise ValueError(f"Error raised by bedrock service: {e}") @@ -646,7 +694,7 @@ def _prepare_input_and_invoke( **services_trace, ) - return text, llm_output + return text, tool_calls, llm_output def _get_bedrock_services_signal(self, body: dict) -> dict: """ @@ -711,6 +759,16 @@ def _prepare_input_and_invoke_stream( messages=messages, model_kwargs=params, ) + if "claude-3" in self._get_model(): + if _tools_in_params(params): + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, + model_kwargs=params, + prompt=prompt, + system=system, + messages=messages, + tools=params["tools"], + ) body = json.dumps(input_body) request_options = { @@ -737,7 +795,10 @@ def _prepare_input_and_invoke_stream( raise 
ValueError(f"Error raised by bedrock service: {e}") for chunk in LLMInputOutputAdapter.prepare_output_stream( - provider, response, stop, True if messages else False + provider, + response, + stop, + True if messages else False, ): yield chunk # verify and raise callback error if any middleware intervened @@ -770,13 +831,24 @@ async def _aprepare_input_and_invoke_stream( _model_kwargs["stream"] = True params = {**_model_kwargs, **kwargs} - input_body = LLMInputOutputAdapter.prepare_input( - provider=provider, - prompt=prompt, - system=system, - messages=messages, - model_kwargs=params, - ) + if "claude-3" in self._get_model(): + if _tools_in_params(params): + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, + model_kwargs=params, + prompt=prompt, + system=system, + messages=messages, + tools=params["tools"], + ) + else: + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, + prompt=prompt, + system=system, + messages=messages, + model_kwargs=params, + ) body = json.dumps(input_body) response = await asyncio.get_running_loop().run_in_executor( @@ -790,7 +862,10 @@ async def _aprepare_input_and_invoke_stream( ) async for chunk in LLMInputOutputAdapter.aprepare_output_stream( - provider, response, stop, True if messages else False + provider, + response, + stop, + True if messages else False, ): yield chunk if run_manager is not None and asyncio.iscoroutinefunction( @@ -951,7 +1026,7 @@ def _call( return completion - text, llm_output = self._prepare_input_and_invoke( + text, tool_calls, llm_output = self._prepare_input_and_invoke( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ) if run_manager is not None: diff --git a/libs/aws/poetry.lock b/libs/aws/poetry.lock index 8e14b7b5..d0f665c2 100644 --- a/libs/aws/poetry.lock +++ b/libs/aws/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -334,7 +334,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.0rc1" +version = "0.2.5" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -343,20 +343,17 @@ develop = false [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.0" +langsmith = "^0.1.75" packaging = "^23.2" pydantic = ">=1,<3" PyYAML = ">=5.3" tenacity = "^8.1.0" -[package.extras] -extended-testing = ["jinja2 (>=3,<4)"] - [package.source] type = "git" url = "https://github.com/langchain-ai/langchain.git" reference = "HEAD" -resolved_reference = "06110e20b96f7b01a47c00477eaa9808149c28c0" +resolved_reference = "00ad19750255008e6f7a86b4c0e89530a4b2a0cc" subdirectory = "libs/core" [[package]] @@ -381,13 +378,13 @@ subdirectory = "libs/standard-tests" [[package]] name = "langsmith" -version = "0.1.59" +version = "0.1.77" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.59-py3-none-any.whl", hash = "sha256:445e3bc1d3baa1e5340cd979907a19483b9763a2ed37b863a01113d406f69345"}, - {file = "langsmith-0.1.59.tar.gz", hash = "sha256:e748a89f4dd6aa441349143e49e546c03b5dfb43376a25bfef6a5ca792fe1437"}, + {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"}, + {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"}, ] [package.dependencies] @@ -1014,4 +1011,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "2debfc6d2c6506b7b760adcc5bcc0242ee8fe01f0e73d95965406ead23962461" +content-hash = "fd4ce90ec2f2c93efaf779201bdef7e4ac8ae76b74b15693ddc80311f37f5f71" diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index 9b6fe5b5..ebd345de 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -12,7 +12,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = ">=0.1.45,<0.3" +langchain-core = ">=0.2.2,<0.3" boto3 = ">=1.34.51,<1.35.0" numpy = "^1" diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py index 6d1fb57a..1d1d8e96 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -1,9 +1,10 @@ """Test Bedrock chat model.""" - +import json from typing import Any, cast import pytest from langchain_core.messages import ( + AIMessage, AIMessageChunk, BaseMessage, HumanMessage, @@ -19,7 +20,10 @@ @pytest.fixture def chat() -> ChatBedrock: - return ChatBedrock(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0}) # type: ignore[call-arg] + return ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0}, + ) # type: ignore[call-arg] @pytest.mark.scheduled @@ -55,7 +59,7 @@ def test_chat_bedrock_generate_with_token_usage(chat: ChatBedrock) -> None: assert isinstance(response.llm_output, dict) usage = response.llm_output["usage"] - assert usage["prompt_tokens"] == 20 + assert usage["prompt_tokens"] == 16 assert usage["completion_tokens"] > 0 assert usage["total_tokens"] > 0 @@ -177,7 +181,88 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None: result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) assert "usage" in result.additional_kwargs - assert result.additional_kwargs["usage"]["prompt_tokens"] == 13 + assert result.additional_kwargs["usage"]["prompt_tokens"] == 12 + + +class GetWeather(BaseModel): + """Useful for getting the weather in a location.""" + + location: str = Field(..., description="The city and state") + + +class AnswerWithJustification(BaseModel): + """An answer to the user question along with justification for the answer.""" + + answer: str + justification: str + + +@pytest.mark.scheduled +def test_structured_output() -> None: + chat = ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = chat.with_structured_output(AnswerWithJustification) + + response = structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + + assert isinstance(response, AnswerWithJustification) + + +@pytest.mark.scheduled +def 
test_tool_use_call_invoke() -> None: + chat = ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + + llm_with_tools = chat.bind_tools([GetWeather]) + + messages = [HumanMessage(content="what is the weather like in San Francisco CA")] + + response = llm_with_tools.invoke(messages) + assert isinstance(response, AIMessage) + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "GetWeather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] + + # Test streaming + first = True + for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"): + if first: + gathered = chunk + first = False + else: + gathered = gathered + chunk # type: ignore + assert isinstance(gathered, AIMessageChunk) + assert isinstance(gathered.tool_call_chunks, list) + assert len(gathered.tool_call_chunks) == 1 + tool_call_chunk = gathered.tool_call_chunks[0] + assert tool_call_chunk["name"] == "GetWeather" + assert isinstance(tool_call_chunk["args"], str) + assert "location" in json.loads(tool_call_chunk["args"]) + + +@pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"]) +def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None: + chat = ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + chat_model_with_tools = chat.bind_tools([GetWeather], tool_choice=tool_choice) + response = chat_model_with_tools.invoke("what's the weather in ny and la") + assert isinstance(response, AIMessage) + assert response.tool_calls + tool_call = response.tool_calls[0] + assert tool_call["name"] == "GetWeather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] @pytest.mark.scheduled diff --git a/libs/aws/tests/unit_tests/chat_models/__init__.py b/libs/aws/tests/unit_tests/chat_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py new file mode 100644 index 00000000..fd2b7cf6 --- /dev/null +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py @@ -0,0 +1,400 @@ +"""Test chat model integration.""" + +from typing import Any, Callable, Dict, Literal, Type, cast + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.runnables import RunnableBinding +from langchain_core.tools import BaseTool + +from langchain_aws import ChatBedrock +from langchain_aws.chat_models.bedrock import ( + _format_anthropic_messages, + _merge_messages, +) +from langchain_aws.function_calling import convert_to_anthropic_tool + + +def test__merge_messages() -> None: + messages = [ + SystemMessage("foo"), # type: ignore[misc] + HumanMessage("bar"), # type: ignore[misc] + AIMessage( # type: ignore[misc] + [ + {"text": "baz", "type": "text"}, + { + "tool_input": {"a": "b"}, + "type": "tool_use", + "id": "1", + "text": None, + "name": "buz", + }, + {"text": "baz", "type": "text"}, + { + "tool_input": {"a": "c"}, + "type": "tool_use", + "id": "2", + "text": None, + "name": "blah", + }, + ] + ), + ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] + ToolMessage("blah output", tool_call_id="2"), # type: ignore[misc] + 
HumanMessage("next thing"), # type: ignore[misc] + ] + expected = [ + SystemMessage("foo"), # type: ignore[misc] + HumanMessage("bar"), # type: ignore[misc] + AIMessage( # type: ignore[misc] + [ + {"text": "baz", "type": "text"}, + { + "tool_input": {"a": "b"}, + "type": "tool_use", + "id": "1", + "text": None, + "name": "buz", + }, + {"text": "baz", "type": "text"}, + { + "tool_input": {"a": "c"}, + "type": "tool_use", + "id": "2", + "text": None, + "name": "blah", + }, + ] + ), + HumanMessage( # type: ignore[misc] + [ + {"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, + {"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, + {"type": "text", "text": "next thing"}, + ] + ), + ] + actual = _merge_messages(messages) + assert expected == actual + + +def test__merge_messages_mutation() -> None: + original_messages = [ + HumanMessage([{"type": "text", "text": "bar"}]), # type: ignore[misc] + HumanMessage("next thing"), # type: ignore[misc] + ] + messages = [ + HumanMessage([{"type": "text", "text": "bar"}]), # type: ignore[misc] + HumanMessage("next thing"), # type: ignore[misc] + ] + expected = [ + HumanMessage( # type: ignore[misc] + [{"type": "text", "text": "bar"}, {"type": "text", "text": "next thing"}] + ), + ] + actual = _merge_messages(messages) + assert expected == actual + assert messages == original_messages + + +def test__format_anthropic_messages_with_tool_calls() -> None: + system = SystemMessage("fuzz") # type: ignore[misc] + human = HumanMessage("foo") # type: ignore[misc] + ai = AIMessage( # type: ignore[misc] + "", + tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}], + ) + tool = ToolMessage( # type: ignore[misc] + "blurb", + tool_call_id="1", + ) + messages = [system, human, ai, tool] + expected = ( + "fuzz", + [ + {"role": "user", "content": "foo"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "name": "bar", + "id": "1", + "input": {"baz": "buzz"}, + } + ], + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + ], + }, + ], + ) + actual = _format_anthropic_messages(messages) + assert expected == actual + + +def test__format_anthropic_messages_with_str_content_and_tool_calls() -> None: + system = SystemMessage("fuzz") # type: ignore[misc] + human = HumanMessage("foo") # type: ignore[misc] + # If content and tool_calls are specified and content is a string, then both are + # included with content first. + ai = AIMessage( # type: ignore[misc] + "thought", + tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}], + ) + tool = ToolMessage("blurb", tool_call_id="1") # type: ignore[misc] + messages = [system, human, ai, tool] + expected = ( + "fuzz", + [ + {"role": "user", "content": "foo"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "thought"}, + { + "type": "tool_use", + "name": "bar", + "id": "1", + "input": {"baz": "buzz"}, + }, + ], + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + ], + }, + ], + ) + actual = _format_anthropic_messages(messages) + assert expected == actual + + +def test__format_anthropic_messages_with_list_content_and_tool_calls() -> None: + system = SystemMessage("fuzz") # type: ignore[misc] + human = HumanMessage("foo") # type: ignore[misc] + # If content and tool_calls are specified and content is a list, then content is + # preferred. 
+ ai = AIMessage( # type: ignore[misc] + [{"type": "text", "text": "thought"}], + tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}], + ) + tool = ToolMessage( # type: ignore[misc] + "blurb", + tool_call_id="1", + ) + messages = [system, human, ai, tool] + expected = ( + "fuzz", + [ + {"role": "user", "content": "foo"}, + { + "role": "assistant", + "content": [{"type": "text", "text": "thought"}], + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + ], + }, + ], + ) + actual = _format_anthropic_messages(messages) + assert expected == actual + + +def test__format_anthropic_messages_with_tool_use_blocks_and_tool_calls() -> None: + """Show that tool_calls are preferred to tool_use blocks when both have same id.""" + system = SystemMessage("fuzz") # type: ignore[misc] + human = HumanMessage("foo") # type: ignore[misc] + # NOTE: tool_use block in contents and tool_calls have different arguments. + ai = AIMessage( # type: ignore[misc] + [ + {"type": "text", "text": "thought"}, + { + "type": "tool_use", + "name": "bar", + "id": "1", + "input": {"baz": "NOT_BUZZ"}, + }, + ], + tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "BUZZ"}}], + ) + tool = ToolMessage("blurb", tool_call_id="1") # type: ignore[misc] + messages = [system, human, ai, tool] + expected = ( + "fuzz", + [ + {"role": "user", "content": "foo"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "thought"}, + { + "type": "tool_use", + "name": "bar", + "id": "1", + "input": {"baz": "BUZZ"}, # tool_calls value preferred. + }, + ], + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + ], + }, + ], + ) + actual = _format_anthropic_messages(messages) + assert expected == actual + + +@pytest.fixture() +def pydantic() -> Type[BaseModel]: + class dummy_function(BaseModel): + """dummy function""" + + arg1: int = Field(..., description="foo") + arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") + + return dummy_function + + +@pytest.fixture() +def function() -> Callable: + def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: + """dummy function + + Args: + arg1: foo + arg2: one of 'bar', 'baz' + """ + pass + + return dummy_function + + +@pytest.fixture() +def dummy_tool() -> BaseTool: + class Schema(BaseModel): + arg1: int = Field(..., description="foo") + arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") + + class DummyFunction(BaseTool): + args_schema: Type[BaseModel] = Schema + name: str = "dummy_function" + description: str = "dummy function" + + def _run(self, *args: Any, **kwargs: Any) -> Any: + pass + + return DummyFunction() + + +@pytest.fixture() +def json_schema() -> Dict: + return { + "title": "dummy_function", + "description": "dummy function", + "type": "object", + "properties": { + "arg1": {"description": "foo", "type": "integer"}, + "arg2": { + "description": "one of 'bar', 'baz'", + "enum": ["bar", "baz"], + "type": "string", + }, + }, + "required": ["arg1", "arg2"], + } + + +@pytest.fixture() +def openai_function() -> Dict: + return { + "name": "dummy_function", + "description": "dummy function", + "parameters": { + "type": "object", + "properties": { + "arg1": {"description": "foo", "type": "integer"}, + "arg2": { + "description": "one of 'bar', 'baz'", + "enum": ["bar", "baz"], + "type": "string", + }, + }, + "required": ["arg1", "arg2"], + }, + } + + +def test_convert_to_anthropic_tool( + pydantic: 
Type[BaseModel], + function: Callable, + dummy_tool: BaseTool, + json_schema: Dict, + openai_function: Dict, +) -> None: + expected = { + "name": "dummy_function", + "description": "dummy function", + "input_schema": { + "type": "object", + "properties": { + "arg1": {"description": "foo", "type": "integer"}, + "arg2": { + "description": "one of 'bar', 'baz'", + "enum": ["bar", "baz"], + "type": "string", + }, + }, + "required": ["arg1", "arg2"], + }, + } + + for fn in (pydantic, function, dummy_tool, json_schema, expected, openai_function): + actual = convert_to_anthropic_tool(fn) # type: ignore[arg-type] + assert actual == expected + + +class GetWeather(BaseModel): + """Get the current weather in a given location""" + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + + +def test_anthropic_bind_tools_tool_choice() -> None: + chat_model = ChatBedrock( + model_id="anthropic.claude-3-opus-20240229", region_name="us-west-2" + ) # type: ignore[call-arg] + chat_model_with_tools = chat_model.bind_tools( + [GetWeather], tool_choice={"type": "tool", "name": "GetWeather"} + ) + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "type": "tool", + "name": "GetWeather", + } + chat_model_with_tools = chat_model.bind_tools( + [GetWeather], tool_choice="GetWeather" + ) + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "type": "tool", + "name": "GetWeather", + } + chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice="auto") + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "type": "auto" + } + chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice="any") + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "type": "any" + }