-
Notifications
You must be signed in to change notification settings - Fork 468
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #129 from filip-michalsky/add_bedrock
add bedrock
- Loading branch information
Showing
13 changed files
with
306 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional | ||
|
||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForLLMRun, | ||
CallbackManagerForLLMRun, | ||
) | ||
from langchain_core.language_models import BaseChatModel, SimpleChatModel | ||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage | ||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | ||
from langchain_core.runnables import run_in_executor | ||
from langchain_openai import ChatOpenAI | ||
|
||
from salesgpt.tools import completion_bedrock | ||
|
||
|
||
class BedrockCustomModel(ChatOpenAI):
    """Chat model that routes generation through AWS Bedrock.

    Subclasses ``ChatOpenAI`` so it can be dropped into existing LangChain
    pipelines, but ``_generate`` delegates to
    ``salesgpt.tools.completion_bedrock`` instead of the OpenAI API.

    Example:
        .. code-block:: python

            model = BedrockCustomModel(
                model="anthropic.claude-v2",
                system_prompt="You are a helpful assistant.",
            )
            result = model.invoke([HumanMessage(content="hello")])
    """

    # Bedrock model identifier passed through as ``model_id``
    # (e.g. "anthropic.claude-v2").
    model: str
    # System prompt sent with every Bedrock request.
    system_prompt: str

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single chat completion via Bedrock.

        Args:
            messages: The prompt as a list of messages. NOTE: only the
                *last* message is forwarded to Bedrock; earlier turns are
                not included in the request.
            stop: Stop sequences. Currently not forwarded to Bedrock.
            run_manager: A run manager with callbacks for the LLM.

        Returns:
            A ``ChatResult`` wrapping one ``AIMessage`` generation.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            raise ValueError("messages must contain at least one message")
        last_message = messages[-1]

        response = completion_bedrock(
            model_id=self.model,
            system_prompt=self.system_prompt,
            messages=[{"content": last_message.content, "role": "user"}],
            max_tokens=1000,
        )
        # Bedrock (Anthropic messages format) returns the text under
        # response["content"][0]["text"] — presumably a single text block;
        # verify against completion_bedrock's contract.
        content = response["content"][0]["text"]
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.