
Merge pull request #70 from laithalsaadoon/feat/claude-3-tool-calling
Feat: claude 3 tool calling
baskaryan authored Jun 12, 2024
2 parents 1bdbe20 + 15c93ed commit 99409d8
Showing 8 changed files with 972 additions and 78 deletions.
336 changes: 295 additions & 41 deletions libs/aws/langchain_aws/chat_models/bedrock.py

Large diffs are not rendered by default.
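Since the chat-model diff is collapsed, here is a rough usage sketch of what this PR enables. It assumes `ChatBedrock.bind_tools` is the entry point added in the unrendered `chat_models/bedrock.py` diff; the model id, schema, and prompt below are illustrative, not taken from the diff:

from langchain_aws import ChatBedrock
from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="City name, e.g. 'Paris'")

llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
llm_with_tools = llm.bind_tools([GetWeather])

ai_msg = llm_with_tools.invoke("What is the weather like in Paris?")
# ai_msg.tool_calls -> [{"name": "GetWeather", "args": {"location": "Paris"}, "id": "toolu_..."}]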

83 changes: 83 additions & 0 deletions libs/aws/langchain_aws/function_calling.py
@@ -8,10 +8,16 @@
Dict,
List,
Literal,
Optional,
Type,
Union,
cast,
)

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -63,6 +69,35 @@ class AnthropicTool(TypedDict):
input_schema: Dict[str, Any]


def _tools_in_params(params: dict) -> bool:
    """Check whether tool definitions are present directly or via ``extra_body``."""
    return "tools" in params or bool(
        "extra_body" in params and params["extra_body"].get("tools")
    )


class _AnthropicToolUse(TypedDict):
type: Literal["tool_use"]
name: str
input: dict
id: str


def _lc_tool_calls_to_anthropic_tool_use_blocks(
tool_calls: List[ToolCall],
) -> List[_AnthropicToolUse]:
blocks = []
for tool_call in tool_calls:
blocks.append(
_AnthropicToolUse(
type="tool_use",
name=tool_call["name"],
input=tool_call["args"],
id=cast(str, tool_call["id"]),
)
)
return blocks
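# Illustrative example (not from the diff): a LangChain ToolCall such as
#     {"name": "get_weather", "args": {"location": "Paris"}, "id": "toolu_01"}
# is converted to the Anthropic content block
#     {"type": "tool_use", "name": "get_weather", "input": {"location": "Paris"}, "id": "toolu_01"}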


def _get_type(parameter: Dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
@@ -122,6 +157,54 @@ class ToolDescription(TypedDict):
function: FunctionDescription


class ToolsOutputParser(BaseGenerationOutputParser):
    """Parse tool calls from Bedrock Anthropic chat generations."""

    first_tool_only: bool = False
args_only: bool = False
pydantic_schemas: Optional[List[Type[BaseModel]]] = None

class Config:
extra = "forbid"

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
    """Parse a list of candidate model Generations into a specific format.

    Args:
        result: A list of Generations to be parsed. The Generations are
            assumed to be different candidate outputs for a single model
            input.

    Returns:
        Structured output.
    """
    if not result or not isinstance(result[0], ChatGeneration):
        return None if self.first_tool_only else []
    message = result[0].message
    if len(message.content) > 0:
        # Non-empty content means the model answered with text, not tools.
        tool_calls: List = []
    else:
        # Empty content: any tool invocations live on AIMessage.tool_calls.
        ai_message = cast(AIMessage, message)
        _tool_calls = [dict(tc) for tc in ai_message.tool_calls]
        # Map tool call id to index
        id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)}
        tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
    if self.pydantic_schemas:
        tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
    elif self.args_only:
        tool_calls = [tc["args"] for tc in tool_calls]

    if self.first_tool_only:
        return tool_calls[0] if tool_calls else None
    return tool_calls

def _pydantic_parse(self, tool_call: dict) -> BaseModel:
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
tool_call["name"]
]
return cls_(**tool_call["args"])


def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
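A sketch of how the new parser might be chained after a tool-bound model, continuing the hypothetical `GetWeather` example above (`BaseGenerationOutputParser` subclasses compose with `|` like any Runnable):

from langchain_aws.function_calling import ToolsOutputParser

parser = ToolsOutputParser(first_tool_only=True, pydantic_schemas=[GetWeather])
chain = llm_with_tools | parser
chain.invoke("What is the weather like in Paris?")
# -> GetWeather(location="Paris"), assuming the model elects to call the tool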
117 changes: 96 additions & 21 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -12,6 +12,7 @@
Mapping,
Optional,
Tuple,
TypedDict,
Union,
)

@@ -21,10 +22,12 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM, BaseLanguageModel
from langchain_core.messages import ToolCall
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_aws.function_calling import _tools_in_params
from langchain_aws.utils import (
enforce_stop_tokens,
get_num_tokens_anthropic,
@@ -81,7 +84,10 @@ def _human_assistant_format(input_text: str) -> str:


def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
provider: str,
output_key: str,
messages_api: bool,
) -> Union[GenerationChunk, None]:
"""Convert a stream response to a generation chunk."""
if messages_api:
@@ -174,6 +180,23 @@ def _combine_generation_info_for_llm_result(
return {"usage": total_usage_info, "stop_reason": stop_reason}


def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
    """Collect ToolCalls from Anthropic ``tool_use`` content blocks."""
    tool_calls = []
for block in content:
if block["type"] != "tool_use":
continue
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
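# Illustrative example (not from the diff): given Anthropic content blocks such as
#     [{"type": "tool_use", "name": "get_weather",
#       "input": {"location": "Paris"}, "id": "toolu_01"}]
# this returns
#     [ToolCall(name="get_weather", args={"location": "Paris"}, id="toolu_01")]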


class AnthropicTool(TypedDict):
name: str
description: str
input_schema: Dict[str, Any]


class LLMInputOutputAdapter:
"""Adapter class to prepare the inputs from Langchain to a format
that LLM model expects.
@@ -197,10 +220,13 @@ def prepare_input(
prompt: Optional[str] = None,
system: Optional[str] = None,
messages: Optional[List[Dict]] = None,
tools: Optional[List[AnthropicTool]] = None,
) -> Dict[str, Any]:
input_body = {**model_kwargs}
if provider == "anthropic":
if messages:
if tools:
input_body["tools"] = tools
input_body["anthropic_version"] = "bedrock-2023-05-31"
input_body["messages"] = messages
if system:
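For reference, a sketch of the Anthropic messages-API body that `prepare_input` assembles when tools are supplied (every field value below is illustrative, not taken from the diff):

{
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 1024,
    "system": "You are a helpful assistant.",
    "messages": [{"role": "user", "content": "What is the weather like in Paris?"}],
    "tools": [
        {
            "name": "get_weather",
            "description": "Get the current weather in a given location.",
            "input_schema": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ],
}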
@@ -225,16 +251,20 @@
@classmethod
def prepare_output(cls, provider: str, response: Any) -> dict:
text = ""
tool_calls = []
response_body = json.loads(response.get("body").read().decode())

if provider == "anthropic":
response_body = json.loads(response.get("body").read().decode())
if "completion" in response_body:
text = response_body.get("completion")
elif "content" in response_body:
content = response_body.get("content")
text = content[0].get("text")
else:
response_body = json.loads(response.get("body").read())
if len(content) == 1 and content[0]["type"] == "text":
text = content[0]["text"]
elif any(block["type"] == "tool_use" for block in content):
tool_calls = extract_tool_calls(content)

else:
if provider == "ai21":
text = response_body.get("completions")[0].get("data").get("text")
elif provider == "cohere":
@@ -251,6 +281,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
return {
"text": text,
"tool_calls": tool_calls,
"body": response_body,
"usage": {
"prompt_tokens": prompt_tokens,
@@ -584,19 +615,32 @@ def _prepare_input_and_invoke(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Tuple[
str,
List[dict],
Dict[str, Any],
]:
_model_kwargs = self.model_kwargs or {}

provider = self._get_provider()
params = {**_model_kwargs, **kwargs}

input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
)
if "claude-3" in self._get_model():
if _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
)
body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"
@@ -621,9 +665,13 @@
try:
response = self.client.invoke_model(**request_options)

(
text,
tool_calls,
body,
usage_info,
stop_reason,
) = LLMInputOutputAdapter.prepare_output(provider, response).values()

except Exception as e:
raise ValueError(f"Error raised by bedrock service: {e}")
@@ -646,7 +694,7 @@ def _prepare_input_and_invoke(
**services_trace,
)

return text, tool_calls, llm_output

def _get_bedrock_services_signal(self, body: dict) -> dict:
"""
@@ -711,6 +759,16 @@ def _prepare_input_and_invoke_stream(
messages=messages,
model_kwargs=params,
)
if "claude-3" in self._get_model():
if _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
)
body = json.dumps(input_body)

request_options = {
@@ -737,7 +795,10 @@ def _prepare_input_and_invoke_stream(
raise ValueError(f"Error raised by bedrock service: {e}")

for chunk in LLMInputOutputAdapter.prepare_output_stream(
provider,
response,
stop,
True if messages else False,
):
yield chunk
# verify and raise callback error if any middleware intervened
@@ -770,13 +831,24 @@ async def _aprepare_input_and_invoke_stream(
_model_kwargs["stream"] = True

params = {**_model_kwargs, **kwargs}
if "claude-3" in self._get_model() and _tools_in_params(params):
    # Claude 3 with tools: include the tool definitions in the request body.
    input_body = LLMInputOutputAdapter.prepare_input(
        provider=provider,
        model_kwargs=params,
        prompt=prompt,
        system=system,
        messages=messages,
        tools=params["tools"],
    )
else:
    input_body = LLMInputOutputAdapter.prepare_input(
        provider=provider,
        prompt=prompt,
        system=system,
        messages=messages,
        model_kwargs=params,
    )
body = json.dumps(input_body)

response = await asyncio.get_running_loop().run_in_executor(
@@ -790,7 +862,10 @@ async def _aprepare_input_and_invoke_stream(
)

async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
provider,
response,
stop,
True if messages else False,
):
yield chunk
if run_manager is not None and asyncio.iscoroutinefunction(
Expand Down Expand Up @@ -951,7 +1026,7 @@ def _call(

return completion

text, tool_calls, llm_output = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
if run_manager is not None:
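Finally, a hedged sketch of a full tool-call round trip, which is what `_lc_tool_calls_to_anthropic_tool_use_blocks` exists to support: the assistant's tool calls are re-serialized as `tool_use` blocks when the history is sent back. This assumes the unrendered chat-model diff also formats `ToolMessage` results, as Anthropic's messages API requires; all values are illustrative:

from langchain_core.messages import HumanMessage, ToolMessage

messages = [HumanMessage("What is the weather like in Paris?")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)  # tool calls become "tool_use" content blocks on the wire
for tc in ai_msg.tool_calls:
    messages.append(ToolMessage(content="22C and sunny", tool_call_id=tc["id"]))
final = llm_with_tools.invoke(messages)  # model answers using the tool result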
