Skip to content

Commit

Permalink
Merge pull request #129 from filip-michalsky/add_bedrock
Browse files Browse the repository at this point in the history
add bedrock
  • Loading branch information
filip-michalsky authored Mar 26, 2024
2 parents bbd9b0f + 18c57c2 commit 4c2b621
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 38 deletions.
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ PRODUCT_CATALOG=examples/sample_product_catalog.txt
PRODUCT_PRICE_MAPPING=examples/example_product_price_id_mapping.json
GPT_MODEL=gpt-3.5-turbo-0613
USE_TOOLS_IN_API=True

AWS_ACCESS_KEY_ID=xx
AWS_SECRET_ACCESS_KEY=xx
AWS_REGION_NAME=xx
8 changes: 8 additions & 0 deletions .github/workflows/poetry_unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ jobs:
- name: Run Unit Tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
run: |
export OPENAI_API_KEY=$OPENAI_API_KEY
export AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID
export AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY
export AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION
export AWS_REGION_NAME=$AWS_REGION_NAME
make test # Executing tests with Poetry
71 changes: 70 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pytest-cov = "^4.1.0"
pytest-asyncio = "^0.23.1"
langchain-openai = "0.0.2"
tokenizers = "^0.15.2"
boto3 = "^1.34.70"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ tiktoken>=0.5.2
pydantic>=2.5.2
litellm>=1.10.2
ipykernel>=6.27.1
langchain-openai==0.0.2
langchain-openai==0.0.2
boto3
4 changes: 3 additions & 1 deletion run_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ class MessageList(BaseModel):

@app.get("/botname")
async def get_bot_name():
load_dotenv()
sales_api = SalesGPTAPI(
config_path=os.getenv("CONFIG_PATH", "examples/example_agent_setup.json"),
product_catalog=os.getenv(
"PRODUCT_CATALOG", "examples/sample_product_catalog.txt"
),
verbose=True,
model_name=os.getenv("GPT_MODEL", "gpt-3.5-turbo-0613"),
)
name = sales_api.sales_agent.salesperson_name
return {"name": name}
return {"name": name, "model": sales_api.sales_agent.model_name}


@app.post("/chat")
Expand Down
13 changes: 9 additions & 4 deletions salesgpt/agents.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from copy import deepcopy
from typing import Any, Callable, Dict, List, Union

from langchain.agents import (AgentExecutor, LLMSingleActionAgent,
create_openai_tools_agent)
from langchain.agents import (
AgentExecutor,
LLMSingleActionAgent,
create_openai_tools_agent,
)
from langchain.chains import LLMChain, RetrievalQA
from langchain.chains.base import Chain
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.agents import (_convert_agent_action_to_messages,
_convert_agent_observation_to_messages)
from langchain_core.agents import (
_convert_agent_action_to_messages,
_convert_agent_observation_to_messages,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from litellm import acompletion
from pydantic import Field
Expand Down
6 changes: 4 additions & 2 deletions salesgpt/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from langchain_community.chat_models import ChatLiteLLM

from salesgpt.logger import time_logger
from salesgpt.prompts import (SALES_AGENT_INCEPTION_PROMPT,
STAGE_ANALYZER_INCEPTION_PROMPT)
from salesgpt.prompts import (
SALES_AGENT_INCEPTION_PROMPT,
STAGE_ANALYZER_INCEPTION_PROMPT,
)


class StageAnalyzerChain(LLMChain):
Expand Down
73 changes: 73 additions & 0 deletions salesgpt/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor
from langchain_openai import ChatOpenAI

from salesgpt.tools import completion_bedrock


class BedrockCustomModel(ChatOpenAI):
"""A custom chat model that echoes the first `n` characters of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = CustomChatModel(n=2)
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""

model: str
system_prompt: str
"""The number of characters from the last message of the prompt to be echoed."""

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.
This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]

print(messages)
response = completion_bedrock(
model_id=self.model,
system_prompt=self.system_prompt,
messages=[{"content": last_message.content, "role": "user"}],
max_tokens=1000,
)
print("output", response)
content = response["content"][0]["text"]
message = AIMessage(content=content)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
15 changes: 13 additions & 2 deletions salesgpt/salesgptapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import json
import re

from langchain_community.chat_models import ChatLiteLLM
from langchain_community.chat_models import BedrockChat, ChatLiteLLM
from langchain_openai import ChatOpenAI

from salesgpt.agents import SalesGPT
from salesgpt.models import BedrockCustomModel


class SalesGPTAPI:
Expand All @@ -20,7 +22,15 @@ def __init__(
self.config_path = config_path
self.verbose = verbose
self.max_num_turns = max_num_turns
self.llm = ChatLiteLLM(temperature=0.2, model_name=model_name)
self.model_name = model_name
if "anthropic" in model_name:
self.llm = BedrockCustomModel(
type="bedrock-model",
model=model_name,
system_prompt="You are a helpful assistant.",
)
else:
self.llm = ChatLiteLLM(temperature=0.2, model=model_name)
self.product_catalog = product_catalog
self.conversation_history = []
self.use_tools = use_tools
Expand Down Expand Up @@ -131,6 +141,7 @@ def do(self, human_input=None):
"tool_input": tool_input,
"action_output": action_output,
"action_input": action_input,
"model_name": self.model_name,
}
return payload

Expand Down
Loading

0 comments on commit 4c2b621

Please sign in to comment.