// --- gui/pages/api/DashboardService.js ---

export const validateAccessToken = () => {
  return api.get(`/validate-access-token`);
}

/**
 * Ask the backend to validate an LLM provider API key.
 *
 * @param {string} model_source - Provider name, e.g. "OpenAi" or "Google Palm".
 * @param {string} model_api_key - The API key to check.
 * @returns {Promise} Axios promise resolving to a {message, status} payload,
 *   where status is "success" for a valid key and "failed" otherwise.
 *   NOTE(review): callers chaining .then() on this (e.g. the Settings model-key
 *   handler) should also attach a .catch() so a network failure surfaces.
 */
export const validateLLMApiKey = (model_source, model_api_key) => {
  const payload = {model_source, model_api_key};
  return api.post(`/validate-llm-api-key`, payload);
}

export const checkEnvironment = () => {
  return api.get(`/configs/get/env`);
}
# --- main.py ---

@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
    """API to validate LLM API Key.

    Dispatches on ``request.model_source`` to the matching provider client
    and checks the key via ``verify_access_key()``. A source that is not
    recognised is reported as an invalid key, matching the original
    if/elif behaviour.

    Returns:
        dict: ``{"message", "status"}`` where status is "success" or "failed".
    """
    # Provider-name -> client-class dispatch table.
    providers = {
        "OpenAi": OpenAi,
        "Google Palm": GooglePalm,
    }
    llm_class = providers.get(request.model_source)
    if llm_class is None:
        key_is_valid = False
    else:
        key_is_valid = llm_class(api_key=request.model_api_key).verify_access_key()

    if key_is_valid:
        return {"message": "Valid API Key", "status": "success"}
    return {"message": "Invalid API Key", "status": "failed"}
# --- superagi/llms/google_palm.py : GooglePalm method ---

def verify_access_key(self):
    """
    Verify the access key is valid.

    Issues a lightweight ``palm.list_models()`` call; any exception
    (authentication failure, network error, quota) is treated as an
    invalid key.

    Returns:
        bool: True if the access key is valid, False otherwise.
    """
    try:
        # The return value is deliberately not bound — reaching the API
        # without an auth error is the entire check (fixes the unused
        # `models` local in the original).
        palm.list_models()
        return True
    except Exception as exception:
        # Same logging pattern as chat_completion in this module.
        logger.info("Google palm Exception:", exception)
        return False
+ """ + try: + models = openai.Model.list() + return True + except Exception as exception: + logger.info("OpenAi Exception:", exception) + return False diff --git a/superagi/models/types/validate_llm_api_key_request.py b/superagi/models/types/validate_llm_api_key_request.py new file mode 100644 index 000000000..4a8153de8 --- /dev/null +++ b/superagi/models/types/validate_llm_api_key_request.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ValidateAPIKeyRequest(BaseModel): + model_source: str + model_api_key: str diff --git a/tests/unit_tests/llms/test_google_palm.py b/tests/unit_tests/llms/test_google_palm.py index 6c0a9d964..e9848ac88 100644 --- a/tests/unit_tests/llms/test_google_palm.py +++ b/tests/unit_tests/llms/test_google_palm.py @@ -28,3 +28,11 @@ def test_chat_completion(mock_palm): top_p=palm_instance.top_p, max_output_tokens=int(max_tokens) ) + + +def test_verify_access_key(): + model = 'models/text-bison-001' + api_key = 'test_key' + palm_instance = GooglePalm(api_key, model=model) + result = palm_instance.verify_access_key() + assert result is False diff --git a/tests/unit_tests/llms/test_open_ai.py b/tests/unit_tests/llms/test_open_ai.py index 955581799..9882092f4 100644 --- a/tests/unit_tests/llms/test_open_ai.py +++ b/tests/unit_tests/llms/test_open_ai.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch from superagi.llms.openai import OpenAi + @patch('superagi.llms.openai.openai') def test_chat_completion(mock_openai): # Arrange @@ -30,3 +31,11 @@ def test_chat_completion(mock_openai): frequency_penalty=openai_instance.frequency_penalty, presence_penalty=openai_instance.presence_penalty ) + + +def test_verify_access_key(): + model = 'gpt-4' + api_key = 'test_key' + openai_instance = OpenAi(api_key, model=model) + result = openai_instance.verify_access_key() + assert result is False