aws[major]: release 0.2 (#182)
As part of upcoming langchain-core 0.3 release

---------

Co-authored-by: Bagatur <[email protected]>
ccurme and baskaryan authored Sep 13, 2024
1 parent fc72065 commit 4ca387a
Showing 25 changed files with 957 additions and 849 deletions.
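For readers upgrading, the user-facing changes in this diff follow two patterns: the deprecated aliases (BedrockChat, Bedrock) are removed in favor of ChatBedrock and BedrockLLM, and imports from langchain_core.pydantic_v1 move to pydantic proper. Below is a minimal before/after sketch of typical calling code; the usage and model id are assumed, not taken from this repository, and BedrockLLM as the replacement for Bedrock is inferred from the remaining exports.

# Before (langchain-aws 0.1.x):
#   from langchain_aws import BedrockChat, Bedrock
#   from langchain_core.pydantic_v1 import BaseModel

# After (langchain-aws 0.2):
from langchain_aws import BedrockLLM, ChatBedrock  # deprecated aliases are gone
from pydantic import BaseModel  # schemas are plain pydantic v2 models now


class AnswerWithJustification(BaseModel):
    answer: str
    justification: str


# Assumed model id; requires AWS credentials/region configured in the environment.
llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
structured_llm = llm.with_structured_output(AnswerWithJustification)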
3 changes: 3 additions & 0 deletions libs/aws/Makefile
@@ -10,6 +10,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/
 test tests integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)
 
+test_watch:
+	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
+
 ######################
 # LINTING AND FORMATTING
 ######################
6 changes: 2 additions & 4 deletions libs/aws/langchain_aws/__init__.py
@@ -1,18 +1,16 @@
-from langchain_aws.chat_models import BedrockChat, ChatBedrock, ChatBedrockConverse
+from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
 from langchain_aws.embeddings import BedrockEmbeddings
 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
-from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint
+from langchain_aws.llms import BedrockLLM, SagemakerEndpoint
 from langchain_aws.retrievers import (
     AmazonKendraRetriever,
     AmazonKnowledgeBasesRetriever,
 )
 from langchain_aws.vectorstores.inmemorydb import InMemoryVectorStore
 
 __all__ = [
-    "Bedrock",
     "BedrockEmbeddings",
     "BedrockLLM",
-    "BedrockChat",
     "ChatBedrock",
     "ChatBedrockConverse",
     "SagemakerEndpoint",
7 changes: 4 additions & 3 deletions libs/aws/langchain_aws/agents/base.py
@@ -14,9 +14,9 @@
 from langchain_core.callbacks import CallbackManager
 from langchain_core.load import dumpd
 from langchain_core.messages import AIMessage
-from langchain_core.pydantic_v1 import root_validator
 from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
 from langchain_core.tools import BaseTool
+from pydantic import model_validator
 
 _DEFAULT_ACTION_GROUP_NAME = "DEFAULT_AG_"
 _TEST_AGENT_ALIAS_ID = "TSTALIASID"
@@ -329,8 +329,9 @@ class BedrockAgentsRunnable(RunnableSerializable[Dict, OutputType]):
     endpoint_url: Optional[str] = None
     """Endpoint URL"""
 
-    @root_validator(skip_on_failure=True)
-    def validate_agent(cls, values: dict) -> dict:
+    @model_validator(mode="before")
+    @classmethod
+    def validate_agent(cls, values: dict) -> Any:
         if values.get("client") is not None:
             return values
 
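The root_validator(skip_on_failure=True) to model_validator(mode="before") change above is the standard pydantic v2 migration for validators that operate on the raw input. A self-contained sketch of the pattern, with illustrative names rather than code from this repository:

from typing import Any, Optional

from pydantic import BaseModel, model_validator


class AgentStub(BaseModel):
    client: Optional[Any] = None
    client_name: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate_agent(cls, values: dict) -> Any:
        # Runs on the raw input before field validation, like a v1 root_validator;
        # it must return the (possibly updated) mapping of field values.
        if values.get("client") is None:
            values["client"] = f"client-for-{values.get('client_name') or 'default'}"
        return values


print(AgentStub(client_name="bedrock").client)  # client-for-bedrock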
4 changes: 2 additions & 2 deletions libs/aws/langchain_aws/chat_models/__init__.py
@@ -1,4 +1,4 @@
-from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock
+from langchain_aws.chat_models.bedrock import ChatBedrock
 from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
 
-__all__ = ["BedrockChat", "ChatBedrock", "ChatBedrockConverse"]
+__all__ = ["ChatBedrock", "ChatBedrockConverse"]
21 changes: 7 additions & 14 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -15,7 +15,6 @@
     cast,
 )
 
-from langchain_core._api.deprecation import deprecated
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import (
     BaseChatModel,
@@ -34,10 +33,10 @@
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import ToolCall, ToolMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
+from pydantic import BaseModel, ConfigDict
 
 from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
 from langchain_aws.function_calling import (
@@ -214,7 +213,7 @@ def _merge_messages(
     """Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
     merged: list = []
     for curr in messages:
-        curr = curr.copy(deep=True)
+        curr = curr.model_copy(deep=True)
         if isinstance(curr, ToolMessage):
             if isinstance(curr.content, list) and all(
                 isinstance(block, dict) and block.get("type") == "tool_result"
@@ -417,10 +416,9 @@ def lc_attributes(self) -> Dict[str, Any]:
 
         return attributes
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = "forbid"
+    model_config = ConfigDict(
+        extra="forbid",
+    )
 
     def _get_ls_params(
         self, stop: Optional[List[str]] = None, **kwargs: Any
@@ -708,7 +706,7 @@ def with_structured_output(
         .. code-block:: python
             from langchain_aws.chat_models.bedrock import ChatBedrock
-            from langchain_core.pydantic_v1 import BaseModel
+            from pydantic import BaseModel
             class AnswerWithJustification(BaseModel):
                 '''An answer to the user question along with justification for the answer.'''
@@ -732,7 +730,7 @@ class AnswerWithJustification(BaseModel):
         .. code-block:: python
             from langchain_aws.chat_models.bedrock import ChatBedrock
-            from langchain_core.pydantic_v1 import BaseModel
+            from pydantic import BaseModel
             class AnswerWithJustification(BaseModel):
                 '''An answer to the user question along with justification for the answer.'''
@@ -829,8 +827,3 @@ def _as_converse(self) -> ChatBedrockConverse:
             guardrail_config=(self.guardrails if self._guardrails_enabled else None),  # type: ignore[call-arg]
             **kwargs,
         )
-
-
-@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
-class BedrockChat(ChatBedrock):
-    pass
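Two other pydantic v2 substitutions recur in this file: the inner class Config becomes model_config = ConfigDict(...), and message.copy(deep=True) becomes message.model_copy(deep=True). A small illustrative sketch with a hypothetical model, not repository code:

from pydantic import BaseModel, ConfigDict


class Message(BaseModel):
    # v2 replacement for the v1 `class Config: extra = "forbid"`
    model_config = ConfigDict(extra="forbid")

    content: str


original = Message(content="hello")
duplicate = original.model_copy(deep=True)  # v2 spelling of .copy(deep=True)
assert duplicate == original and duplicate is not original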
67 changes: 33 additions & 34 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -40,14 +40,15 @@
 from langchain_core.output_parsers import JsonOutputKeyToolsParser, PydanticToolsParser
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
 from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Self
 
 from langchain_aws.function_calling import ToolsOutputParser
 
@@ -152,7 +153,7 @@ class ChatBedrockConverse(BaseChatModel):
     Tool calling:
         .. code-block:: python
-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field
             class GetWeather(BaseModel):
                 '''Get the current weather in a given location'''
@@ -190,7 +191,7 @@ class GetPopulation(BaseModel):
             from typing import Optional
-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field
             class Joke(BaseModel):
                 '''Joke to tell user.'''
@@ -352,14 +353,14 @@ class Joke(BaseModel):
     model is used, ('auto', 'any') if a 'mistral-large' model is used, empty otherwise.
     """
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = "forbid"
-        allow_population_by_field_name = True
+    model_config = ConfigDict(
+        extra="forbid",
+        populate_by_name=True,
+    )
 
-    @root_validator(pre=True)
-    def set_disable_streaming(cls, values: Dict) -> Dict:
+    @model_validator(mode="before")
+    @classmethod
+    def set_disable_streaming(cls, values: Dict) -> Any:
         values["provider"] = (
             values.get("provider")
             or (values.get("model_id", values["model"])).split(".")[0]
@@ -372,15 +373,15 @@ def set_disable_streaming(cls, values: Dict) -> Dict:
         )
         return values
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that AWS credentials to and python package exists in environment."""
-        if values["client"] is not None:
-            return values
+        if self.client is not None:
+            return self
 
         try:
-            if values["credentials_profile_name"] is not None:
-                session = boto3.Session(profile_name=values["credentials_profile_name"])
+            if self.credentials_profile_name is not None:
+                session = boto3.Session(profile_name=self.credentials_profile_name)
             else:
                 session = boto3.Session()
         except ValueError as e:
@@ -392,22 +393,20 @@ def validate_environment(cls, values: Dict) -> Dict:
                 f"profile name are valid. Bedrock error: {e}"
             ) from e
 
-        values["region_name"] = (
-            values.get("region_name")
-            or os.getenv("AWS_DEFAULT_REGION")
-            or session.region_name
+        self.region_name = (
+            self.region_name or os.getenv("AWS_DEFAULT_REGION") or session.region_name
         )
 
         client_params = {}
-        if values["region_name"]:
-            client_params["region_name"] = values["region_name"]
-        if values["endpoint_url"]:
-            client_params["endpoint_url"] = values["endpoint_url"]
-        if values["config"]:
-            client_params["config"] = values["config"]
+        if self.region_name:
+            client_params["region_name"] = self.region_name
+        if self.endpoint_url:
+            client_params["endpoint_url"] = self.endpoint_url
+        if self.config:
+            client_params["config"] = self.config
 
         try:
-            values["client"] = session.client("bedrock-runtime", **client_params)
+            self.client = session.client("bedrock-runtime", **client_params)
         except ValueError as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
         except Exception as e:
@@ -419,15 +418,15 @@ def validate_environment(cls, values: Dict) -> Dict:
 
         # As of 08/05/24 only claude-3 and mistral-large models support tool choice:
        # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
-        if values["supports_tool_choice_values"] is None:
-            if "claude-3" in values["model_id"]:
-                values["supports_tool_choice_values"] = ("auto", "any", "tool")
-            elif "mistral-large" in values["model_id"]:
-                values["supports_tool_choice_values"] = ("auto", "any")
+        if self.supports_tool_choice_values is None:
+            if "claude-3" in self.model_id:
+                self.supports_tool_choice_values = ("auto", "any", "tool")
+            elif "mistral-large" in self.model_id:
+                self.supports_tool_choice_values = ("auto", "any")
             else:
-                values["supports_tool_choice_values"] = ()
+                self.supports_tool_choice_values = ()
 
-        return values
+        return self
 
     def _generate(
         self,
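The validate_environment rewrite above swaps a v1 root_validator that reads and writes a values dict for a v2 model_validator(mode="after") that runs on the constructed instance and returns it. A minimal sketch of that shape, with illustrative fields and no real AWS call:

import os
from typing import Any, Optional

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class ClientHolder(BaseModel):
    client: Optional[Any] = None
    region_name: Optional[str] = None

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        # Runs after field validation; attributes are read and assigned on self
        # instead of going through a values dict, and the instance is returned.
        if self.client is None:
            self.region_name = self.region_name or os.getenv("AWS_DEFAULT_REGION")
            self.client = object()  # stand-in for session.client("bedrock-runtime", ...)
        return self


assert ClientHolder(region_name="us-west-2").client is not None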
41 changes: 21 additions & 20 deletions libs/aws/langchain_aws/embeddings/bedrock.py
@@ -5,8 +5,9 @@
 
 import numpy as np
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, root_validator
 from langchain_core.runnables.config import run_in_executor
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Self
 
 
 class BedrockEmbeddings(BaseModel, Embeddings):
@@ -40,7 +41,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
             )
     """
 
-    client: Any  #: :meta private:
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
     """Bedrock client."""
     region_name: Optional[str] = None
     """The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
@@ -71,38 +72,38 @@ class BedrockEmbeddings(BaseModel, Embeddings):
     config: Any = None
     """An optional botocore.config.Config instance to pass to the client."""
 
-    class Config:
-        """Configuration for this pydantic object."""
+    model_config = ConfigDict(
+        extra="forbid",
+        protected_namespaces=(),
+    )
 
-        extra = "forbid"
-
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that AWS credentials to and python package exists in environment."""
 
-        if values["client"] is not None:
-            return values
+        if self.client is not None:
+            return self
 
         try:
             import boto3
 
-            if values["credentials_profile_name"] is not None:
-                session = boto3.Session(profile_name=values["credentials_profile_name"])
+            if self.credentials_profile_name is not None:
+                session = boto3.Session(profile_name=self.credentials_profile_name)
             else:
                 # use default credentials
                 session = boto3.Session()
 
             client_params = {}
-            if values["region_name"]:
-                client_params["region_name"] = values["region_name"]
+            if self.region_name:
+                client_params["region_name"] = self.region_name
 
-            if values["endpoint_url"]:
-                client_params["endpoint_url"] = values["endpoint_url"]
+            if self.endpoint_url:
+                client_params["endpoint_url"] = self.endpoint_url
 
-            if values["config"]:
-                client_params["config"] = values["config"]
+            if self.config:
+                client_params["config"] = self.config
 
-            values["client"] = session.client("bedrock-runtime", **client_params)
+            self.client = session.client("bedrock-runtime", **client_params)
 
         except ImportError:
             raise ModuleNotFoundError(
@@ -116,7 +117,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 f"profile name are valid. Bedrock error: {e}"
             ) from e
 
-        return values
+        return self
 
     def _embedding_func(self, text: str) -> List[float]:
         """Call out to Bedrock embedding endpoint."""
10 changes: 6 additions & 4 deletions libs/aws/langchain_aws/function_calling.py
@@ -3,6 +3,7 @@
 
 import json
 from typing import (
+    Annotated,
     Any,
     Callable,
     Dict,
@@ -17,10 +18,10 @@
 from langchain_core.output_parsers import BaseGenerationOutputParser
 from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.prompts.chat import AIMessage
-from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import TypeBaseModel
+from pydantic import BaseModel, ConfigDict, SkipValidation
 from typing_extensions import TypedDict
 
 PYTHON_TO_JSON_TYPES = {
@@ -160,10 +161,11 @@ class ToolDescription(TypedDict):
 class ToolsOutputParser(BaseGenerationOutputParser):
     first_tool_only: bool = False
     args_only: bool = False
-    pydantic_schemas: Optional[List[TypeBaseModel]] = None
+    pydantic_schemas: Optional[List[Annotated[TypeBaseModel, SkipValidation()]]] = None
 
-    class Config:
-        extra = "forbid"
+    model_config = ConfigDict(
+        extra="forbid",
+    )
 
     def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
         """Parse a list of candidate model Generations into a specific format.
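The Annotated[TypeBaseModel, SkipValidation()] annotation above keeps pydantic v2 from running its own validation over the stored schema classes. A brief illustrative sketch using a plain `type` field (hypothetical class, not the repository's parser):

from typing import Annotated, List, Optional

from pydantic import BaseModel, ConfigDict, SkipValidation


class SchemaRegistry(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # SkipValidation stores the entries as-is rather than validating them.
    schemas: Optional[List[Annotated[type, SkipValidation()]]] = None


class Weather(BaseModel):
    city: str


registry = SchemaRegistry(schemas=[Weather])
assert registry.schemas == [Weather]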
2 changes: 0 additions & 2 deletions libs/aws/langchain_aws/llms/__init__.py
@@ -1,6 +1,5 @@
 from langchain_aws.llms.bedrock import (
     ALTERNATION_ERROR,
-    Bedrock,
     BedrockBase,
     BedrockLLM,
     LLMInputOutputAdapter,
@@ -9,7 +8,6 @@
 
 __all__ = [
     "ALTERNATION_ERROR",
-    "Bedrock",
     "BedrockBase",
     "BedrockLLM",
     "LLMInputOutputAdapter",