refactor: simplify LLM tests and remove duplication #206

Open · wants to merge 1 commit into base: main (showing changes from all commits)
214 changes: 89 additions & 125 deletions tests/test_llm_api.py
@@ -1,162 +1,126 @@
 import os
 import pdb
+from dataclasses import dataclass
 
 from dotenv import load_dotenv
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_ollama import ChatOllama
 
 load_dotenv()
 
 import sys
 
 sys.path.append(".")
 
 
-def test_openai_model():
-    from langchain_core.messages import HumanMessage
+@dataclass
+class LLMConfig:
+    provider: str
+    model_name: str
+    temperature: float = 0.8
+    base_url: str = None
+    api_key: str = None
+
+def create_message_content(text, image_path=None):
+    content = [{"type": "text", "text": text}]
+
+    if image_path:
+        from src.utils import utils
+        image_data = utils.encode_image(image_path)
+        content.append({
+            "type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}
+        })
+
+    return content
+
+def get_env_value(key, provider):
+    env_mappings = {
+        "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
+        "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
+        "gemini": {"api_key": "GOOGLE_API_KEY"},
+        "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"}
+    }
+
+    if provider in env_mappings and key in env_mappings[provider]:
+        return os.getenv(env_mappings[provider][key], "")
+    return ""
+
+def test_llm(config, query, image_path=None, system_message=None):
     from src.utils import utils
 
+    # Special handling for Ollama-based models
+    if config.provider == "ollama":
+        if "deepseek-r1" in config.model_name:
+            from src.utils.llm import DeepSeekR1ChatOllama
+            llm = DeepSeekR1ChatOllama(model=config.model_name)
+        else:
+            llm = ChatOllama(model=config.model_name)
+
+        ai_msg = llm.invoke(query)
+        print(ai_msg.content)
+        if "deepseek-r1" in config.model_name:
+            pdb.set_trace()
+        return
+
+    # For other providers, use the standard configuration
     llm = utils.get_llm_model(
-        provider="openai",
-        model_name="gpt-4o",
-        temperature=0.8,
-        base_url=os.getenv("OPENAI_ENDPOINT", ""),
-        api_key=os.getenv("OPENAI_API_KEY", "")
-    )
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
+        provider=config.provider,
+        model_name=config.model_name,
+        temperature=config.temperature,
+        base_url=config.base_url or get_env_value("base_url", config.provider),
+        api_key=config.api_key or get_env_value("api_key", config.provider)
     )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
-
-
-def test_gemini_model():
-    # you need to enable your api key first: https://ai.google.dev/palm_docs/oauth_quickstart
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="gemini",
-        model_name="gemini-2.0-flash-exp",
-        temperature=0.8,
-        api_key=os.getenv("GOOGLE_API_KEY", "")
-    )
+
+    # Prepare messages for non-Ollama models
+    messages = []
+    if system_message:
+        messages.append(SystemMessage(content=create_message_content(system_message)))
+    messages.append(HumanMessage(content=create_message_content(query, image_path)))
+    ai_msg = llm.invoke(messages)
 
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
+    # Handle different response types
+    if hasattr(ai_msg, "reasoning_content"):
+        print(ai_msg.reasoning_content)
     print(ai_msg.content)
 
+    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
+        print(llm.model_name)
+        pdb.set_trace()
 
-def test_azure_openai_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
+def test_openai_model():
+    config = LLMConfig(provider="openai", model_name="gpt-4o")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
 
-    llm = utils.get_llm_model(
-        provider="azure_openai",
-        model_name="gpt-4o",
-        temperature=0.8,
-        base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""),
-        api_key=os.getenv("AZURE_OPENAI_API_KEY", "")
-    )
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+def test_gemini_model():
+    # Enable your API key first if you haven't: https://ai.google.dev/palm_docs/oauth_quickstart
+    config = LLMConfig(provider="gemini", model_name="gemini-2.0-flash-exp")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
 
+def test_azure_openai_model():
+    config = LLMConfig(provider="azure_openai", model_name="gpt-4o")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
 
 def test_deepseek_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-chat",
-        temperature=0.8,
-        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
-        api_key=os.getenv("DEEPSEEK_API_KEY", "")
-    )
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "who are you?"}
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+    config = LLMConfig(provider="deepseek", model_name="deepseek-chat")
+    test_llm(config, "Who are you?")
 
 def test_deepseek_r1_model():
-    from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-reasoner",
-        temperature=0.8,
-        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
-        api_key=os.getenv("DEEPSEEK_API_KEY", "")
-    )
-    messages = []
-    sys_message = SystemMessage(
-        content=[{"type": "text", "text": "you are a helpful AI assistant"}]
-    )
-    messages.append(sys_message)
-    user_message = HumanMessage(
-        content=[
-            {"type": "text", "text": "9.11 and 9.8, which is greater?"}
-        ]
-    )
-    messages.append(user_message)
-    ai_msg = llm.invoke(messages)
-    print(ai_msg.reasoning_content)
-    print(ai_msg.content)
-    print(llm.model_name)
-    pdb.set_trace()
+    config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
+    test_llm(config, "Which is greater, 9.11 or 9.8?", system_message="You are a helpful AI assistant.")
 
 def test_ollama_model():
-    from langchain_ollama import ChatOllama
+    config = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
+    test_llm(config, "Sing a ballad of LangChain.")
 
-    llm = ChatOllama(model="qwen2.5:7b")
-    ai_msg = llm.invoke("Sing a ballad of LangChain.")
-    print(ai_msg.content)
 
 def test_deepseek_r1_ollama_model():
-    from src.utils.llm import DeepSeekR1ChatOllama
-
-    llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
-    ai_msg = llm.invoke("how many r in strawberry?")
-    print(ai_msg.content)
-    pdb.set_trace()
-
+    config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
+    test_llm(config, "How many 'r's are in the word 'strawberry'?")
 
-if __name__ == '__main__':
+if __name__ == "__main__":
    # test_openai_model()
    # test_gemini_model()
    # test_azure_openai_model()
    test_deepseek_model()
    # test_ollama_model()
    # test_deepseek_r1_model()
    # test_deepseek_r1_ollama_model()
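
For reference, a minimal usage sketch of the helpers this diff introduces. It assumes the repo root is on sys.path (the module itself calls sys.path.append(".")) and that the import path tests.test_llm_api resolves; the temperature value is a hypothetical override, not a default from this PR:

    from tests.test_llm_api import LLMConfig, test_llm  # assumed import path

    # Explicit fields win; anything left as None falls back to get_env_value(),
    # which maps e.g. provider "openai" to OPENAI_API_KEY / OPENAI_ENDPOINT.
    config = LLMConfig(
        provider="openai",
        model_name="gpt-4o",
        temperature=0.2,  # hypothetical override of the 0.8 default
    )
    test_llm(config, "Describe this image", image_path="assets/examples/test.png")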
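Since each per-provider test now only builds an LLMConfig and delegates to test_llm, a natural follow-up (not part of this PR, sketched here under the same import assumption) would be to collapse the wrappers into a single parametrized test:

    import pytest

    from tests.test_llm_api import LLMConfig, test_llm  # assumed import path

    # Skips the deepseek-r1 variants, whose branches in test_llm drop into pdb.
    @pytest.mark.parametrize("config, query, image_path", [
        (LLMConfig(provider="openai", model_name="gpt-4o"), "Describe this image", "assets/examples/test.png"),
        (LLMConfig(provider="deepseek", model_name="deepseek-chat"), "Who are you?", None),
        (LLMConfig(provider="ollama", model_name="qwen2.5:7b"), "Sing a ballad of LangChain.", None),
    ])
    def test_providers(config, query, image_path):
        test_llm(config, query, image_path)

Like the existing tests, this only smoke-tests the call path (test_llm prints rather than asserts), so it still needs live credentials or a local Ollama server to run.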