From ba83f0d3b105e0d757d0990f2546226af6e8bdc0 Mon Sep 17 00:00:00 2001
From: Gustavo Cid Ornelas
Date: Tue, 13 Aug 2024 09:40:37 -0300
Subject: [PATCH] feat: support Vertex AI models via LangChain callback
 handler

---
 .../lib/integrations/langchain_callback.py | 71 ++++++++++++-------
 1 file changed, 47 insertions(+), 24 deletions(-)

diff --git a/src/openlayer/lib/integrations/langchain_callback.py b/src/openlayer/lib/integrations/langchain_callback.py
index 89eb3e04..c996c125 100644
--- a/src/openlayer/lib/integrations/langchain_callback.py
+++ b/src/openlayer/lib/integrations/langchain_callback.py
@@ -9,8 +9,12 @@
 from ..tracing import tracer
 
-LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI", "chat-ollama": "Ollama"}
-PROVIDER_TO_STEP_NAME = {"OpenAI": "OpenAI Chat Completion", "Ollama": "Ollama Chat Completion"}
+LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI", "chat-ollama": "Ollama", "vertexai": "Google"}
+PROVIDER_TO_STEP_NAME = {
+    "OpenAI": "OpenAI Chat Completion",
+    "Ollama": "Ollama Chat Completion",
+    "Google": "Google Vertex AI Chat Completion",
+}
 
 
 class OpenlayerHandler(BaseCallbackHandler):
@@ -29,13 +33,27 @@ def __init__(self, **kwargs: Any) -> None:
         self.prompt_tokens: int = None
         self.completion_tokens: int = None
         self.total_tokens: int = None
-        self.output: str = None
-        self.metatada: Dict[str, Any] = kwargs or {}
+        self.output: str = ""
+        self.metadata: Dict[str, Any] = kwargs or {}  # noqa arg002
 
     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
         """Run when LLM starts running."""
-        pass
+        self._initialize_run(kwargs)
+        self.prompt = [{"role": "user", "content": text} for text in prompts]
+        self.start_time = time.time()
+
+    def _initialize_run(self, kwargs: Dict[str, Any]) -> None:
+        """Initializes an LLM (or Chat) run, extracting the provider, model name,
+        and other metadata."""
+        self.model_parameters = kwargs.get("invocation_params", {})
+        metadata = kwargs.get("metadata", {})
+
+        provider = self.model_parameters.pop("_type", None)
+        if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP:
+            self.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]
+
+        self.model = self.model_parameters.get("model_name", None) or metadata.get("ls_model_name", None)
 
     def on_chat_model_start(
         self,
@@ -44,18 +62,7 @@ def on_chat_model_start(
         **kwargs: Any,
     ) -> Any:
         """Run when Chat Model starts running."""
-        self.model_parameters = kwargs.get("invocation_params", {})
-        self.metadata = kwargs.get("metadata", {})
-
-        provider = self.model_parameters.get("_type", None)
-        if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP:
-            self.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]
-            self.model_parameters.pop("_type")
-        self.metadata.pop("ls_provider", None)
-        self.metadata.pop("ls_model_type", None)
-
-        self.model = self.model_parameters.get("model_name", None) or self.metadata.pop("ls_model_name", None)
-        self.output = ""
+        self._initialize_run(kwargs)
         self.prompt = self._langchain_messages_to_prompt(messages)
         self.start_time = time.time()
 
@@ -83,18 +90,20 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
     def on_llm_end(self, response: langchain_schema.LLMResult, **kwargs: Any) -> Any:  # noqa: ARG002, E501
         """Run when LLM ends running."""
         self.end_time = time.time()
-        self.latency = (self.end_time - self.start_time) * 1000
+        self.latency = (self.end_time - self.start_time) * 1000  # in milliseconds
+        self._extract_token_information(response=response)
+        self._extract_output(response=response)
+        self._add_to_trace()
+
+    def _extract_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extract token information based on provider."""
         if self.provider == "OpenAI":
             self._openai_token_information(response)
         elif self.provider == "Ollama":
             self._ollama_token_information(response)
-
-        for generations in response.generations:
-            for generation in generations:
-                self.output += generation.text.replace("\n", " ")
-
-        self._add_to_trace()
+        elif self.provider == "Google":
+            self._google_token_information(response)
 
     def _openai_token_information(self, response: langchain_schema.LLMResult) -> None:
         """Extracts OpenAI's token information."""
@@ -111,6 +120,20 @@ def _ollama_token_information(self, response: langchain_schema.LLMResult) -> Non
         self.completion_tokens = generation_info.get("eval_count", 0)
         self.total_tokens = self.prompt_tokens + self.completion_tokens
 
+    def _google_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts Google Vertex AI token information."""
+        usage_metadata = response.generations[0][0].generation_info["usage_metadata"]
+        if usage_metadata:
+            self.prompt_tokens = usage_metadata.get("prompt_token_count", 0)
+            self.completion_tokens = usage_metadata.get("candidates_token_count", 0)
+            self.total_tokens = usage_metadata.get("total_token_count", 0)
+
+    def _extract_output(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts the output from the response."""
+        for generations in response.generations:
+            for generation in generations:
+                self.output += generation.text.replace("\n", " ")
+
     def _add_to_trace(self) -> None:
         """Adds to the trace."""
         name = PROVIDER_TO_STEP_NAME.get(self.provider, "Chat Completion Model")
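
Usage note (not part of the patch): a minimal sketch of how the new Vertex AI
support is exercised, assuming the langchain-google-vertexai package is
installed and Google Cloud credentials are configured. The model name and the
environment kwarg below are illustrative. ChatVertexAI reports "vertexai" as
its LangChain LLM type, which _initialize_run now maps to the "Google"
provider, so _google_token_information reads the prompt/candidates/total
token counts from the generation's usage_metadata.

    from langchain_google_vertexai import ChatVertexAI

    from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

    # Extra kwargs passed to the handler are stored as trace metadata;
    # "environment" here is an illustrative key, not a required one.
    handler = OpenlayerHandler(environment="development")

    # Any Vertex AI chat model works; the model name is an example.
    chat = ChatVertexAI(model_name="gemini-1.5-flash")

    # Passing the handler as a callback lets it trace the run: provider and
    # model are captured on start, tokens/latency/output on end.
    response = chat.invoke(
        "What is the meaning of life?",
        config={"callbacks": [handler]},
    )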