diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8b91797..04bb6d1 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -17,11 +17,6 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10.10" - - id: 'auth' - name: 'Authenticate to Google Cloud' - uses: 'google-github-actions/auth@v1' - with: - credentials_json: '${{ secrets.GOOGLE_CREDENTIALS }}' - name: Install dependencies run: | make install-poetry diff --git a/tests/conftest.py b/tests/conftest.py index 8ff7b2f..0843bb7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ import asyncio import typing from dataclasses import dataclass +from unittest.mock import patch import pytest from aioresponses import aioresponses +from langchain_community.llms.fake import FakeListLLM from allms.domain.configuration import ( AzureOpenAIConfiguration, AzureSelfDeployedConfiguration, VertexAIConfiguration, VertexAIModelGardenConfiguration) @@ -25,59 +27,70 @@ class GenerativeModels: vertex_palm: typing.Optional[VertexAIPalmModel] = None +class VertexAIMock(FakeListLLM): + def __init__(self, *args, **kwargs): + super().__init__(responses=["{}"]) + + @pytest.fixture(scope="function") def models(): event_loop = asyncio.new_event_loop() - return { - "azure_open_ai": AzureOpenAIModel( - config=AzureOpenAIConfiguration( - api_key="dummy_api_key", - base_url=AzureOpenAIEnv.OPENAI_API_BASE, - api_version=AzureOpenAIEnv.OPENAI_API_VERSION, - deployment=AzureOpenAIEnv.OPENAI_DEPLOYMENT_NAME, - model_name="gpt-4" - ), - event_loop=event_loop - ), - "vertex_palm": VertexAIPalmModel( - config=VertexAIConfiguration( - cloud_project="dummy-project-id", - cloud_location="us-central1" - ), - event_loop=event_loop - ), - "vertex_gemini": VertexAIGeminiModel( - config=VertexAIConfiguration( - cloud_project="dummy-project-id", - cloud_location="us-central1" - ), - event_loop=event_loop - ), - "vertex_gemma": VertexAIGemmaModel( - config=VertexAIModelGardenConfiguration( - cloud_project="dummy-project-id", - cloud_location="us-central1", - endpoint_id="dummy-endpoint-id" - ), - event_loop=event_loop - ), - "azure_llama2": AzureLlama2Model( - config=AzureSelfDeployedConfiguration( - api_key="dummy_api_key", - endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score", - deployment="dummy_deployment_name" - ), - event_loop=event_loop - ), - "azure_mistral": AzureMistralModel( - config=AzureSelfDeployedConfiguration( - api_key="dummy_api_key", - endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score", - deployment="dummy_deployment_name" - ), - event_loop=event_loop - ) - } + + with ( + patch("allms.models.vertexai_palm.CustomVertexAI", VertexAIMock), + patch("allms.models.vertexai_gemini.CustomVertexAI", VertexAIMock), + patch("allms.models.vertexai_gemma.VertexAIModelGardenWrapper", VertexAIMock) + ): + return { + "azure_open_ai": AzureOpenAIModel( + config=AzureOpenAIConfiguration( + api_key="dummy_api_key", + base_url=AzureOpenAIEnv.OPENAI_API_BASE, + api_version=AzureOpenAIEnv.OPENAI_API_VERSION, + deployment=AzureOpenAIEnv.OPENAI_DEPLOYMENT_NAME, + model_name="gpt-4" + ), + event_loop=event_loop + ), + "vertex_palm": VertexAIPalmModel( + config=VertexAIConfiguration( + cloud_project="dummy-project-id", + cloud_location="us-central1" + ), + event_loop=event_loop + ), + "vertex_gemini": VertexAIGeminiModel( + config=VertexAIConfiguration( + cloud_project="dummy-project-id", + cloud_location="us-central1" + ), + event_loop=event_loop + ), + "vertex_gemma": VertexAIGemmaModel( + config=VertexAIModelGardenConfiguration( + cloud_project="dummy-project-id", + cloud_location="us-central1", + endpoint_id="dummy-endpoint-id" + ), + event_loop=event_loop + ), + "azure_llama2": AzureLlama2Model( + config=AzureSelfDeployedConfiguration( + api_key="dummy_api_key", + endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score", + deployment="dummy_deployment_name" + ), + event_loop=event_loop + ), + "azure_mistral": AzureMistralModel( + config=AzureSelfDeployedConfiguration( + api_key="dummy_api_key", + endpoint_url="https://dummy-endpoint.dummy-region.inference.ml.azure.com/score", + deployment="dummy_deployment_name" + ), + event_loop=event_loop + ) + } @pytest.fixture