From d0b4f4c44133e414f5368bcb9c8158c9cde63816 Mon Sep 17 00:00:00 2001
From: marginal23326 <58261815+marginal23326@users.noreply.github.com>
Date: Thu, 30 Jan 2025 18:24:32 +0600
Subject: [PATCH] refactor: simplify LLM tests and remove duplication

---
 tests/test_llm_api.py | 214 ++++++++++++++++++------------------------
 1 file changed, 89 insertions(+), 125 deletions(-)

diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py
index 6075896..45d5775 100644
--- a/tests/test_llm_api.py
+++ b/tests/test_llm_api.py
@@ -1,7 +1,10 @@
 import os
 import pdb
+from dataclasses import dataclass
 
 from dotenv import load_dotenv
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_ollama import ChatOllama
 
 load_dotenv()
 
@@ -9,154 +12,115 @@
 sys.path.append(".")
 
-
-def test_openai_model():
-    from langchain_core.messages import HumanMessage
+@dataclass
+class LLMConfig:
+    provider: str
+    model_name: str
+    temperature: float = 0.8
+    base_url: str = None
+    api_key: str = None
+
+def create_message_content(text, image_path=None):
+    content = [{"type": "text", "text": text}]
+
+    if image_path:
+        from src.utils import utils
+        image_data = utils.encode_image(image_path)
+        content.append({
+            "type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}
+        })
+
+    return content
+
+def get_env_value(key, provider):
+    env_mappings = {
+        "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
+        "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
+        "gemini": {"api_key": "GOOGLE_API_KEY"},
+        "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"}
+    }
+
+    if provider in env_mappings and key in env_mappings[provider]:
+        return os.getenv(env_mappings[provider][key], "")
+    return ""
+
+def test_llm(config, query, image_path=None, system_message=None):
     from src.utils import utils
 
+    # Special handling for Ollama-based models
+    if config.provider == "ollama":
+        if "deepseek-r1" in config.model_name:
+            from src.utils.llm import DeepSeekR1ChatOllama
+            llm = DeepSeekR1ChatOllama(model=config.model_name)
+        else:
+            llm = ChatOllama(model=config.model_name)
+
+        ai_msg = llm.invoke(query)
+        print(ai_msg.content)
+        if "deepseek-r1" in config.model_name:
+            pdb.set_trace()
+        return
+
+    # For other providers, use the standard configuration
     llm = utils.get_llm_model(
-        provider="openai",
-        model_name="gpt-4o",
-        temperature=0.8,
-        base_url=os.getenv("OPENAI_ENDPOINT", ""),
-        api_key=os.getenv("OPENAI_API_KEY", "")
-    )
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
+        provider=config.provider,
+        model_name=config.model_name,
+        temperature=config.temperature,
+        base_url=config.base_url or get_env_value("base_url", config.provider),
+        api_key=config.api_key or get_env_value("api_key", config.provider)
     )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
-
-def test_gemini_model():
-    # you need to enable your api key first: https://ai.google.dev/palm_docs/oauth_quickstart
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="gemini",
-        model_name="gemini-2.0-flash-exp",
-        temperature=0.8,
-        api_key=os.getenv("GOOGLE_API_KEY", "")
-    )
+    # Prepare messages for non-Ollama models
+    messages = []
+    if system_message:
+        messages.append(SystemMessage(content=create_message_content(system_message)))
+    messages.append(HumanMessage(content=create_message_content(query, image_path)))
+    ai_msg = llm.invoke(messages)
 
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
+    # Handle different response types
+    if hasattr(ai_msg, "reasoning_content"):
+        print(ai_msg.reasoning_content)
     print(ai_msg.content)
 
+    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
+        print(llm.model_name)
+        pdb.set_trace()
 
-def test_azure_openai_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
+def test_openai_model():
+    config = LLMConfig(provider="openai", model_name="gpt-4o")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
 
-    llm = utils.get_llm_model(
-        provider="azure_openai",
-        model_name="gpt-4o",
-        temperature=0.8,
-        base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""),
-        api_key=os.getenv("AZURE_OPENAI_API_KEY", "")
-    )
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+def test_gemini_model():
+    # Enable your API key first if you haven't: https://ai.google.dev/palm_docs/oauth_quickstart
+    config = LLMConfig(provider="gemini", model_name="gemini-2.0-flash-exp")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
 
+def test_azure_openai_model():
+    config = LLMConfig(provider="azure_openai", model_name="gpt-4o")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
 
 def test_deepseek_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-chat",
-        temperature=0.8,
-        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
-        api_key=os.getenv("DEEPSEEK_API_KEY", "")
-    )
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "who are you?"}
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+    config = LLMConfig(provider="deepseek", model_name="deepseek-chat")
+    test_llm(config, "Who are you?")
 
 def test_deepseek_r1_model():
-    from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-reasoner",
-        temperature=0.8,
-        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
-        api_key=os.getenv("DEEPSEEK_API_KEY", "")
-    )
-    messages = []
-    sys_message = SystemMessage(
-        content=[{"type": "text", "text": "you are a helpful AI assistant"}]
-    )
-    messages.append(sys_message)
-    user_message = HumanMessage(
-        content=[
-            {"type": "text", "text": "9.11 and 9.8, which is greater?"}
-        ]
-    )
-    messages.append(user_message)
-    ai_msg = llm.invoke(messages)
-    print(ai_msg.reasoning_content)
-    print(ai_msg.content)
-    print(llm.model_name)
-    pdb.set_trace()
+    config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
+    test_llm(config, "Which is greater, 9.11 or 9.8?", system_message="You are a helpful AI assistant.")
 
 def test_ollama_model():
-    from langchain_ollama import ChatOllama
+    config = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
+    test_llm(config, "Sing a ballad of LangChain.")
 
-    llm = ChatOllama(model="qwen2.5:7b")
-    ai_msg = llm.invoke("Sing a ballad of LangChain.")
-    print(ai_msg.content)
-
 def test_deepseek_r1_ollama_model():
-    from src.utils.llm import DeepSeekR1ChatOllama
-
-    llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
-    ai_msg = llm.invoke("how many r in strawberry?")
-    print(ai_msg.content)
-    pdb.set_trace()
-
+    config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
+    test_llm(config, "How many 'r's are in the word 'strawberry'?")
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # test_openai_model()
     # test_gemini_model()
     # test_azure_openai_model()
     test_deepseek_model()
     # test_ollama_model()
     # test_deepseek_r1_model()
-    # test_deepseek_r1_ollama_model()
\ No newline at end of file
+    # test_deepseek_r1_ollama_model()
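
Usage note: the refactored test_llm() resolves connection details as "explicit config value, else environment variable" via get_env_value(), so a one-off run against a non-default endpoint only needs a config override rather than an environment change. A minimal sketch, assuming the patched tests/test_llm_api.py is importable from the repo root; the endpoint URL and API key below are hypothetical placeholders:

    # Explicit base_url/api_key short-circuit the get_env_value() fallback.
    from tests.test_llm_api import LLMConfig, test_llm

    config = LLMConfig(
        provider="deepseek",
        model_name="deepseek-chat",
        temperature=0.2,                     # optional; the patch defaults to 0.8
        base_url="https://api.example.com",  # placeholder; normally read from DEEPSEEK_ENDPOINT
        api_key="sk-placeholder",            # placeholder; normally read from DEEPSEEK_API_KEY
    )
    test_llm(config, "Who are you?")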