Skip to content

Commit

Permalink
Fix/key authentication (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihiragarwal24 authored Jul 18, 2023
1 parent dc35026 commit d9c29fd
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 4 deletions.
13 changes: 10 additions & 3 deletions gui/pages/Dashboard/Settings/Settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import React, {useState, useEffect, useRef} from 'react';
import {ToastContainer, toast} from 'react-toastify';
import 'react-toastify/dist/ReactToastify.css';
import agentStyles from "@/pages/Content/Agents/Agents.module.css";
import {getOrganisationConfig, updateOrganisationConfig} from "@/pages/api/DashboardService";
import {getOrganisationConfig, updateOrganisationConfig,validateLLMApiKey} from "@/pages/api/DashboardService";
import {EventBus} from "@/utils/eventBus";
import {removeTab, setLocalStorageValue} from "@/utils/utils";
import Image from "next/image";
Expand Down Expand Up @@ -83,8 +83,15 @@ export default function Settings({organisationId}) {
return
}

updateKey("model_api_key", modelApiKey);
updateKey("model_source", source);
validateLLMApiKey(source, modelApiKey)
.then((response) => {
if (response.data.status==="success") {
updateKey("model_api_key", modelApiKey);
updateKey("model_source", source);
} else {
toast.error("Invalid API key", {autoClose: 1800});
}
})
};

const handleTemperatureChange = (event) => {
Expand Down
3 changes: 3 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ export const validateAccessToken = () => {
return api.get(`/validate-access-token`);
}

// Ask the backend to validate an LLM provider API key before persisting it.
export const validateLLMApiKey = (model_source, model_api_key) => {
  const payload = {model_source, model_api_key};
  return api.post(`/validate-llm-api-key`, payload);
}
// Fetch the current environment configuration from the backend.
export const checkEnvironment = () => api.get(`/configs/get/env`);
Expand Down
18 changes: 18 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from superagi.controllers.analytics import router as analytics_router
from superagi.helper.tool_helper import register_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.openai import OpenAi
from superagi.helper.auth import get_current_user
from superagi.models.agent_workflow import AgentWorkflow
Expand All @@ -53,6 +54,7 @@
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
from superagi.models.user import User

app = FastAPI()
Expand Down Expand Up @@ -426,6 +428,22 @@ async def root(Authorize: AuthJWT = Depends()):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")


@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
    """Validate an LLM provider API key.

    Dispatches on ``request.model_source`` ("OpenAi" or "Google Palm") and
    checks the supplied key against the provider via ``verify_access_key()``.
    Unknown sources are treated as invalid.

    Returns:
        dict: {"message": ..., "status": "success"|"failed"}.
    """
    # Map each known provider name to a key-verification callable.
    verifiers = {
        "OpenAi": lambda key: OpenAi(api_key=key).verify_access_key(),
        "Google Palm": lambda key: GooglePalm(api_key=key).verify_access_key(),
    }
    verify = verifiers.get(request.model_source)
    if verify is not None and verify(request.model_api_key):
        return {"message": "Valid API Key", "status": "success"}
    return {"message": "Invalid API Key", "status": "failed"}


@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
"""API to validate Open AI Key"""
Expand Down
6 changes: 5 additions & 1 deletion superagi/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ def get_api_key(self):

@abstractmethod
def get_model(self):
    """Return the model identifier this LLM instance is configured to use."""
    pass
pass

@abstractmethod
def verify_access_key(self):
    """Check whether the configured API key is accepted by the provider.

    Returns:
        bool: True if the access key is valid, False otherwise.
    """
    pass
14 changes: 14 additions & 0 deletions superagi/llms/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,17 @@ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT
except Exception as exception:
logger.info("Google palm Exception:", exception)
return {"error": exception}

def verify_access_key(self):
    """
    Verify the access key is valid.

    Performs a lightweight ``list_models`` call; any failure (bad key,
    network error) is treated as an invalid key.

    Returns:
        bool: True if the access key is valid, False otherwise.
    """
    try:
        # A successful model listing proves the key is accepted;
        # the returned list itself is not needed.
        palm.list_models()
        return True
    except Exception as exception:
        logger.info("Google palm Exception:", exception)
        return False
14 changes: 14 additions & 0 deletions superagi/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,17 @@ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT
except Exception as exception:
logger.info("OpenAi Exception:", exception)
return {"error": exception}

def verify_access_key(self):
    """
    Verify the access key is valid.

    Performs a lightweight ``Model.list`` call; any failure (bad key,
    network error) is treated as an invalid key.

    Returns:
        bool: True if the access key is valid, False otherwise.
    """
    try:
        # A successful model listing proves the key is accepted;
        # the returned list itself is not needed.
        openai.Model.list()
        return True
    except Exception as exception:
        logger.info("OpenAi Exception:", exception)
        return False
6 changes: 6 additions & 0 deletions superagi/models/types/validate_llm_api_key_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class ValidateAPIKeyRequest(BaseModel):
    """Request payload for POST /validate-llm-api-key."""

    # Provider name; the endpoint recognises "OpenAi" and "Google Palm".
    model_source: str
    # The API key to validate against the provider.
    model_api_key: str
8 changes: 8 additions & 0 deletions tests/unit_tests/llms/test_google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,11 @@ def test_chat_completion(mock_palm):
top_p=palm_instance.top_p,
max_output_tokens=int(max_tokens)
)


@patch('superagi.llms.google_palm.palm')
def test_verify_access_key(mock_palm):
    """verify_access_key returns False when the provider rejects the key.

    Patches the palm module so the test is deterministic and never performs
    a real network call with the dummy key.
    """
    mock_palm.list_models.side_effect = Exception("Invalid API key")
    model = 'models/text-bison-001'
    api_key = 'test_key'
    palm_instance = GooglePalm(api_key, model=model)
    result = palm_instance.verify_access_key()
    assert result is False
9 changes: 9 additions & 0 deletions tests/unit_tests/llms/test_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch
from superagi.llms.openai import OpenAi


@patch('superagi.llms.openai.openai')
def test_chat_completion(mock_openai):
# Arrange
Expand Down Expand Up @@ -30,3 +31,11 @@ def test_chat_completion(mock_openai):
frequency_penalty=openai_instance.frequency_penalty,
presence_penalty=openai_instance.presence_penalty
)


@patch('superagi.llms.openai.openai')
def test_verify_access_key(mock_openai):
    """verify_access_key returns False when the provider rejects the key.

    Patches the openai module so the test is deterministic and never performs
    a real network call with the dummy key.
    """
    mock_openai.Model.list.side_effect = Exception("Invalid API key")
    model = 'gpt-4'
    api_key = 'test_key'
    openai_instance = OpenAi(api_key, model=model)
    result = openai_instance.verify_access_key()
    assert result is False

0 comments on commit d9c29fd

Please sign in to comment.