feat: support Vertex AI models via LangChain callback handler
gustavocidornelas committed Aug 13, 2024
1 parent 82cf45a commit ba83f0d
Showing 1 changed file with 47 additions and 24 deletions.
71 changes: 47 additions & 24 deletions src/openlayer/lib/integrations/langchain_callback.py
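
For context, this change lets Vertex AI chat models be traced through the same callback handler already used for OpenAI and Ollama. A minimal usage sketch, assuming the langchain-google-vertexai package, Openlayer credentials, and an illustrative model name (none of which are part of this diff):

import os

from langchain_google_vertexai import ChatVertexAI

from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

# The environment variable names below are assumptions based on Openlayer's
# tracing setup, not something this commit defines.
os.environ["OPENLAYER_API_KEY"] = "..."
os.environ["OPENLAYER_INFERENCE_PIPELINE_ID"] = "..."

chat = ChatVertexAI(model_name="gemini-1.5-pro")  # model name is illustrative
response = chat.invoke(
    "What is the meaning of life?",
    config={"callbacks": [OpenlayerHandler()]},
)
print(response.content)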
@@ -9,8 +9,12 @@

 from ..tracing import tracer

-LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI", "chat-ollama": "Ollama"}
-PROVIDER_TO_STEP_NAME = {"OpenAI": "OpenAI Chat Completion", "Ollama": "Ollama Chat Completion"}
+LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI", "chat-ollama": "Ollama", "vertexai": "Google"}
+PROVIDER_TO_STEP_NAME = {
+    "OpenAI": "OpenAI Chat Completion",
+    "Ollama": "Ollama Chat Completion",
+    "Google": "Google Vertex AI Chat Completion",
+}


 class OpenlayerHandler(BaseCallbackHandler):
@@ -29,13 +33,27 @@ def __init__(self, **kwargs: Any) -> None:
         self.prompt_tokens: int = None
         self.completion_tokens: int = None
         self.total_tokens: int = None
-        self.output: str = None
-        self.metatada: Dict[str, Any] = kwargs or {}
+        self.output: str = ""
+        self.metadata: Dict[str, Any] = kwargs or {}

     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:  # noqa: ARG002
         """Run when LLM starts running."""
-        pass
+        self._initialize_run(kwargs)
+        self.prompt = [{"role": "user", "content": text} for text in prompts]
+        self.start_time = time.time()
+
+    def _initialize_run(self, kwargs: Dict[str, Any]) -> None:
+        """Initializes an LLM (or Chat) run, extracting the provider, model name,
+        and other metadata."""
+        self.model_parameters = kwargs.get("invocation_params", {})
+        metadata = kwargs.get("metadata", {})
+
+        provider = self.model_parameters.pop("_type", None)
+        if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP:
+            self.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]
+
+        self.model = self.model_parameters.get("model_name", None) or metadata.get("ls_model_name", None)
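
For reference, the kwargs that _initialize_run consumes arrive via LangChain's callback protocol and look roughly like the sketch below for a Vertex AI run. The exact payload varies by LangChain version, so treat the values as illustrative assumptions; only the keys the code reads (invocation_params, _type, model_name, metadata, ls_model_name) are grounded in the diff:

kwargs = {
    "invocation_params": {
        "_type": "vertexai",  # popped and mapped to "Google" via LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP
        "model_name": "gemini-1.5-pro",  # illustrative
        "temperature": 0.0,
    },
    "metadata": {
        "ls_model_name": "gemini-1.5-pro",  # fallback when model_name is absent
    },
}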

     def on_chat_model_start(
         self,
@@ -44,18 +62,7 @@ def on_chat_model_start(
         **kwargs: Any,
     ) -> Any:
         """Run when Chat Model starts running."""
-        self.model_parameters = kwargs.get("invocation_params", {})
-        self.metadata = kwargs.get("metadata", {})
-
-        provider = self.model_parameters.get("_type", None)
-        if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP:
-            self.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]
-            self.model_parameters.pop("_type")
-        self.metadata.pop("ls_provider", None)
-        self.metadata.pop("ls_model_type", None)
-
-        self.model = self.model_parameters.get("model_name", None) or self.metadata.pop("ls_model_name", None)
-        self.output = ""
+        self._initialize_run(kwargs)
         self.prompt = self._langchain_messages_to_prompt(messages)
         self.start_time = time.time()

@@ -83,18 +90,20 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
     def on_llm_end(self, response: langchain_schema.LLMResult, **kwargs: Any) -> Any:  # noqa: ARG002, E501
         """Run when LLM ends running."""
         self.end_time = time.time()
-        self.latency = (self.end_time - self.start_time) * 1000
+        self.latency = (self.end_time - self.start_time) * 1000  # in milliseconds

         self._extract_token_information(response=response)
+        self._extract_output(response=response)
+        self._add_to_trace()

     def _extract_token_information(self, response: langchain_schema.LLMResult) -> None:
         """Extract token information based on provider."""
         if self.provider == "OpenAI":
             self._openai_token_information(response)
         elif self.provider == "Ollama":
             self._ollama_token_information(response)
-
-        for generations in response.generations:
-            for generation in generations:
-                self.output += generation.text.replace("\n", " ")
-
-        self._add_to_trace()
+        elif self.provider == "Google":
+            self._google_token_information(response)

     def _openai_token_information(self, response: langchain_schema.LLMResult) -> None:
         """Extracts OpenAI's token information."""
@@ -111,6 +120,20 @@ def _ollama_token_information(self, response: langchain_schema.LLMResult) -> None:
         self.completion_tokens = generation_info.get("eval_count", 0)
         self.total_tokens = self.prompt_tokens + self.completion_tokens

+    def _google_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts Google Vertex AI token information."""
+        usage_metadata = response.generations[0][0].generation_info.get("usage_metadata")
+        if usage_metadata:
+            self.prompt_tokens = usage_metadata.get("prompt_token_count", 0)
+            self.completion_tokens = usage_metadata.get("candidates_token_count", 0)
+            self.total_tokens = usage_metadata.get("total_token_count", 0)
+
+    def _extract_output(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts the output from the response."""
+        for generations in response.generations:
+            for generation in generations:
+                self.output += generation.text.replace("\n", " ")
+
     def _add_to_trace(self) -> None:
         """Adds to the trace."""
         name = PROVIDER_TO_STEP_NAME.get(self.provider, "Chat Completion Model")
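
A quick way to sanity-check the new Google code path in isolation is to feed the handler a hand-built LLMResult whose generation_info carries usage_metadata. The response shape below is an assumption reverse-engineered from the fields the new method reads, not a captured Vertex AI payload:

from langchain.schema import AIMessage, ChatGeneration, LLMResult

from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

handler = OpenlayerHandler()
handler.provider = "Google"  # normally set by _initialize_run

fake_response = LLMResult(
    generations=[[
        ChatGeneration(
            message=AIMessage(content="Hello from Vertex AI"),
            generation_info={
                "usage_metadata": {
                    "prompt_token_count": 7,
                    "candidates_token_count": 5,
                    "total_token_count": 12,
                }
            },
        )
    ]]
)

handler._extract_token_information(response=fake_response)
print(handler.prompt_tokens, handler.completion_tokens, handler.total_tokens)  # 7 5 12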