add ollama llm and model listing
sarahwooders committed Oct 7, 2024
1 parent 05d9dea commit c389253
Showing 2 changed files with 68 additions and 30 deletions.
72 changes: 51 additions & 21 deletions letta/providers.py
@@ -122,34 +122,64 @@ def get_model_context_window(self, model_name: str):
        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
        response_json = response.json()

-        # thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # ChatGLM2
-            "seq_length",
-            # Command-R
-            "model_max_length",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
+        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
+        # possible_keys = [
+        #     # OPT
+        #     "max_position_embeddings",
+        #     # GPT-2
+        #     "n_positions",
+        #     # MPT
+        #     "max_seq_len",
+        #     # ChatGLM2
+        #     "seq_length",
+        #     # Command-R
+        #     "model_max_length",
+        #     # Others
+        #     "max_sequence_length",
+        #     "max_seq_length",
+        #     "seq_len",
+        # ]
+        # max_position_embeddings
+        # parse model cards: nous, dolphin, llama
        for key, value in response_json["model_info"].items():
-            if "context_window" in key:
+            if "context_length" in key:
                return value
        return None

+    def get_model_embedding_dim(self, model_name: str):
+        import requests
+
+        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
+        response_json = response.json()
+        for key, value in response_json["model_info"].items():
+            if "embedding_length" in key:
+                return value
+        return None
+
    def list_embedding_models(self) -> List[EmbeddingConfig]:
-        # TODO: filter embedding models
-        return []
+        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+        import requests
+
+        response = requests.get(f"{self.base_url}/api/tags")
+        if response.status_code != 200:
+            raise Exception(f"Failed to list Ollama models: {response.text}")
+        response_json = response.json()
+
+        configs = []
+        for model in response_json["models"]:
+            embedding_dim = self.get_model_embedding_dim(model["name"])
+            if not embedding_dim:
+                continue
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model["name"],
+                    embedding_endpoint_type="ollama",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=embedding_dim,
+                    embedding_chunk_size=300,
+                )
+            )
+        return configs


class GroqProvider(OpenAIProvider):
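Both new lookups lean on Ollama's /api/show endpoint, the same call the diff makes: it returns a model_info map whose keys are architecture-prefixed (e.g. llama.context_length, llama.embedding_length), which is why the provider matches on substrings rather than exact key names. A minimal sketch for inspecting those keys against a local server; the base URL and the model name "llama3" are assumptions for illustration, not part of the commit:

# Minimal sketch: see which model_info keys Ollama reports via /api/show.
# Assumes a local Ollama server at http://localhost:11434 and a pulled
# model named "llama3" (hypothetical); neither comes from the commit.
import requests

resp = requests.post(
    "http://localhost:11434/api/show",
    json={"name": "llama3", "verbose": True},
)
resp.raise_for_status()
model_info = resp.json().get("model_info", {})

# Keys look like "llama.context_length" or "llama.embedding_length".
for key, value in sorted(model_info.items()):
    if "context_length" in key or "embedding_length" in key:
        print(f"{key} = {value}")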
26 changes: 17 additions & 9 deletions tests/test_providers.py
@@ -1,6 +1,11 @@
import os

-from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider
+from letta.providers import (
+    AnthropicProvider,
+    GoogleAIProvider,
+    OllamaProvider,
+    OpenAIProvider,
+)


def test_openai():
@@ -24,12 +29,15 @@ def test_anthropic():
# print(models)
#
#
-# def test_ollama():
-#     provider = OllamaProvider()
-#     models = provider.list_llm_models()
-#     print(models)
-#
-#
+def test_ollama():
+    provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
+    models = provider.list_llm_models()
+    print(models)
+
+    embedding_models = provider.list_embedding_models()
+    print(embedding_models)


def test_googleai():
    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
    models = provider.list_llm_models()
@@ -40,8 +48,8 @@ def test_googleai():

#
#
-test_googleai()
-# test_ollama()
+# test_googleai()
+test_ollama()
# test_groq()
# test_openai()
# test_anthropic()

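The new test_ollama needs a reachable server, since the provider is constructed from the OLLAMA_BASE_URL environment variable. A minimal sketch of driving the provider the same way outside pytest; the localhost fallback is an assumption, not something the commit sets:

# Minimal sketch mirroring test_ollama; the default URL is an assumption.
import os

from letta.providers import OllamaProvider

provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))

print(provider.list_llm_models())        # one config per local model
print(provider.list_embedding_models())  # only models that report an embedding_length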