diff --git a/tests/models/test_factory.py b/tests/models/test_factory.py index 647acee2..2ac8437e 100644 --- a/tests/models/test_factory.py +++ b/tests/models/test_factory.py @@ -1,59 +1,66 @@ +from unittest.mock import patch + import pytest from readmeai.config.settings import ConfigLoader from readmeai.core.errors import UnsupportedServiceError from readmeai.extractors.models import RepositoryContext -from readmeai.models.anthropic import AnthropicHandler from readmeai.models.enums import LLMProviders from readmeai.models.factory import ModelFactory -from readmeai.models.gemini import GeminiHandler -from readmeai.models.offline import OfflineHandler -from readmeai.models.openai import OpenAIHandler -def test_get_backend_anthropic( - mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext +def test_get_backend_openai( + mock_config_loader: ConfigLoader, + mock_repository_context: RepositoryContext, + monkeypatch: pytest.MonkeyPatch, ): - mock_config_loader.config.llm.api = LLMProviders.ANTHROPIC.value + """Test getting OpenAI backend with proper environment setup.""" + mock_config_loader.config.llm.api = LLMProviders.OPENAI.value + monkeypatch.setenv("OPENAI_API_KEY", "test_key") handler = ModelFactory.get_backend(mock_config_loader, mock_repository_context) - assert isinstance(handler, AnthropicHandler) + assert handler is not None + assert handler.__class__.__name__ == "OpenAIHandler" -def test_get_backend_gemini( - mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext +def test_get_backend_anthropic( + mock_config_loader: ConfigLoader, + mock_repository_context: RepositoryContext, + monkeypatch: pytest.MonkeyPatch, ): - mock_config_loader.config.llm.api = LLMProviders.GEMINI.value + """Test getting Anthropic backend.""" + mock_config_loader.config.llm.api = LLMProviders.ANTHROPIC.value + monkeypatch.setenv("ANTHROPIC_API_KEY", "test_key") handler = ModelFactory.get_backend(mock_config_loader, mock_repository_context) - assert isinstance(handler, GeminiHandler) + assert handler is not None + assert handler.__class__.__name__ == "AnthropicHandler" -def test_get_backend_openai( - mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext +def test_get_backend_gemini( + mock_config_loader: ConfigLoader, + mock_repository_context: RepositoryContext, ): - mock_config_loader.config.llm.api = LLMProviders.OPENAI.value - handler = ModelFactory.get_backend(mock_config_loader, mock_repository_context) - assert isinstance(handler, OpenAIHandler) + """Test getting Gemini backend.""" + mock_config_loader.config.llm.api = LLMProviders.GEMINI.value + with patch.dict("os.environ", {"GOOGLE_API_KEY": "test_key"}, clear=True): + handler = ModelFactory.get_backend(mock_config_loader, mock_repository_context) + assert handler is not None + assert handler.__class__.__name__ == "GeminiHandler" def test_get_backend_offline( - mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext + mock_config_loader: ConfigLoader, + mock_repository_context: RepositoryContext, ): + """Test getting Offline backend.""" mock_config_loader.config.llm.api = LLMProviders.OFFLINE.value handler = ModelFactory.get_backend(mock_config_loader, mock_repository_context) - assert isinstance(handler, OfflineHandler) + assert handler is not None + assert handler.__class__.__name__ == "OfflineHandler" def test_get_backend_unsupported_service( mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext ): - """ - Work around test for unsupported service error. Create mock object that has similar - structure to the original config but isn't constrained by the same validation rules. - - This allows us to: - - Not trigger Pydantic validation errors - - Still test the error handling for unsupported services - - Maintain the original intent of the test case - """ + """Test getting a backend with an unsupported service.""" class UnsupportedConfig(ConfigLoader): """ @@ -68,6 +75,8 @@ def __init__(self) -> None: )() unsupported_config = UnsupportedConfig() + with pytest.raises(UnsupportedServiceError) as e: ModelFactory.get_backend(unsupported_config, mock_repository_context) + assert isinstance(e.value, UnsupportedServiceError)