Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove llm class, add stop_sequences to Chat #4

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/community/langchain_community/llms/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def _completion_with_retry(**kwargs: Any) -> Any:


@deprecated(
since="0.0.30", removal="0.2.0", alternative_import="langchain_cohere.BaseCohere"
since="0.0.30", removal="0.2.0"
)
class BaseCohere(Serializable):
"""Base class for Cohere models."""
Expand Down
1 change: 0 additions & 1 deletion libs/partners/cohere/langchain_cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

__all__ = [
"ChatCohere",
"CohereVectorStore",
"CohereEmbeddings",
"CohereRagRetriever",
"CohereRerank",
Expand Down
105 changes: 69 additions & 36 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import cohere
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -9,6 +10,7 @@
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand All @@ -18,30 +20,8 @@
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain_cohere.llms import BaseCohere


def get_role(message: BaseMessage) -> str:
"""Get the role of the message.

Args:
message: The message.

Returns:
The role of the message.

Raises:
ValueError: If the message is of an unknown type.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "User"
elif isinstance(message, AIMessage):
return "Chatbot"
elif isinstance(message, SystemMessage):
return "System"
else:
raise ValueError(f"Got unknown type {message}")
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


def get_cohere_chat_request(
Expand Down Expand Up @@ -83,7 +63,7 @@ def get_cohere_chat_request(
req = {
"message": messages[-1].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
{"role": _get_role(x), "message": x.content} for x in messages[:-1]
],
"documents": documents,
"connectors": maybe_connectors,
Expand All @@ -94,7 +74,7 @@ def get_cohere_chat_request(
return {k: v for k, v in req.items() if v is not None}


class ChatCohere(BaseChatModel, BaseCohere):
class ChatCohere(BaseChatModel, Serializable):
"""`Cohere` chat large language models.

To use, you should have the ``cohere`` python package installed, and the
Expand All @@ -113,12 +93,49 @@ class ChatCohere(BaseChatModel, BaseCohere):
chat.invoke(messages)
"""

client: Any = None #: :meta private:
async_client: Any = None #: :meta private:

model: Optional[str] = Field(default=None)
"""Model name to use."""

temperature: Optional[float] = None
"""A non-negative float that tunes the degree of randomness in generation."""

cohere_api_key: Optional[SecretStr] = None
"""Cohere API key. If not provided, will be read from the environment variable."""

stop: Optional[List[str]] = None

streaming: bool = Field(default=False)
"""Whether to stream the results."""

user_agent: str = "langchain"
"""Identifier for the application making the request."""

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True
arbitrary_types_allowed = True

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the Cohere API key exists in the environment and instantiates the API clients."""
values["cohere_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
)
client_name = values["user_agent"]
values["client"] = cohere.Client(
api_key=values["cohere_api_key"].get_secret_value(),
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
api_key=values["cohere_api_key"].get_secret_value(),
client_name=client_name,
)
return values

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
Expand All @@ -130,6 +147,7 @@ def _default_params(self) -> Dict[str, Any]:
base_params = {
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.stop,
}
return {k: v for k, v in base_params.items() if v is not None}

Expand All @@ -147,11 +165,7 @@ def _stream(
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)

if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
else:
stream = self.client.chat(**request, stream=True)

stream = self.client.chat_stream(**request)
for data in stream:
if data.event_type == "text-generation":
delta = data.text
Expand All @@ -169,11 +183,7 @@ async def _astream(
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)

if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
else:
stream = self.async_client.chat(**request, stream=True)

stream = self.async_client.chat_stream(**request)
async for data in stream:
if data.event_type == "text-generation":
delta = data.text
Expand Down Expand Up @@ -247,3 +257,26 @@ async def _agenerate(
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
return len(self.client.tokenize(text).tokens)


def _get_role(message: BaseMessage) -> str:
"""
Get the Cohere API representation of a role.

Args:
message: The message.

Returns:
The role of the message.

Raises:
ValueError: If the message is of an unknown type.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "USER"
elif isinstance(message, AIMessage):
return "CHATBOT"
elif isinstance(message, SystemMessage):
return "SYSTEM"
else:
raise ValueError(f"Got unknown type {message}")
Loading
Loading