Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/key auth #790

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions gui/pages/Dashboard/Settings/Settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import React, {useState, useEffect, useRef} from 'react';
import {ToastContainer, toast} from 'react-toastify';
import 'react-toastify/dist/ReactToastify.css';
import agentStyles from "@/pages/Content/Agents/Agents.module.css";
import {getOrganisationConfig, updateOrganisationConfig} from "@/pages/api/DashboardService";
import {getOrganisationConfig, updateOrganisationConfig,validateLLMApiKey} from "@/pages/api/DashboardService";
import {EventBus} from "@/utils/eventBus";
import {removeTab, setLocalStorageValue} from "@/utils/utils";
import Image from "next/image";
Expand Down Expand Up @@ -79,8 +79,16 @@ export default function Settings({organisationId}) {
return
}

updateKey("model_api_key", modelApiKey);
updateKey("model_source", source);
validateLLMApiKey(source, modelApiKey)
.then((response) => {
console.log("CHANGES",response.data)
if (response.data.status==="success") {
updateKey("model_api_key", modelApiKey);
updateKey("model_source", source);
} else {
toast.error("Invalid API key", {autoClose: 1800});
}
})
};

const handleTemperatureChange = (event) => {
Expand Down
3 changes: 3 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ export const validateAccessToken = () => {
return api.get(`/validate-access-token`);
}

// Validate an LLM provider API key against the backend.
// model_source: provider name ("OpenAi" or "Google Palm");
// model_api_key: the key to check.
// Resolves with a response whose data is {message, status: "success"|"failed"}.
export const validateLLMApiKey = (model_source, model_api_key) => {
return api.post(`/validate-llm-api-key`,{model_source, model_api_key});
}
export const checkEnvironment = () => {
return api.get(`/configs/get/env`);
}
Expand Down
18 changes: 18 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from superagi.controllers.analytics import router as analytics_router
from superagi.helper.tool_helper import register_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.openai import OpenAi
from superagi.helper.auth import get_current_user
from superagi.models.agent_workflow import AgentWorkflow
Expand All @@ -53,6 +54,7 @@
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
from superagi.models.user import User

app = FastAPI()
Expand Down Expand Up @@ -426,6 +428,22 @@ async def root(Authorize: AuthJWT = Depends()):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")


@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
    """Validate an LLM provider API key.

    Args:
        request: Body carrying ``model_source`` ("OpenAi" or "Google Palm")
            and the ``model_api_key`` to check.
        Authorize: JWT auth dependency (fastapi-jwt-auth).

    Returns:
        dict: ``{"message": ..., "status": "success" | "failed"}``.
        An unknown ``model_source`` is reported as an invalid key.

    Raises:
        HTTPException: 401 when the caller's JWT is missing or invalid.
    """
    # Defect fix: the dependency was declared but never enforced, leaving this
    # endpoint callable without a token. Mirror the 401 pattern used by the
    # neighbouring token-validation endpoint.
    try:
        Authorize.jwt_required()
    except Exception:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")

    source = request.model_source
    api_key = request.model_api_key
    # Dispatch table keeps provider handling uniform and easy to extend.
    providers = {"OpenAi": OpenAi, "Google Palm": GooglePalm}
    provider = providers.get(source)
    valid_api_key = provider(api_key=api_key).verify_access_key() if provider else False
    if valid_api_key:
        return {"message": "Valid API Key", "status": "success"}
    return {"message": "Invalid API Key", "status": "failed"}


@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
"""API to validate Open AI Key"""
Expand Down
6 changes: 5 additions & 1 deletion superagi/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@

@abstractmethod
def get_model(self):
pass
pass

Check warning on line 19 in superagi/llms/base_llm.py

View check run for this annotation

Codecov / codecov/patch

superagi/llms/base_llm.py#L19

Added line #L19 was not covered by tests

@abstractmethod
def verify_access_key(self):
pass

Check warning on line 23 in superagi/llms/base_llm.py

View check run for this annotation

Codecov / codecov/patch

superagi/llms/base_llm.py#L23

Added line #L23 was not covered by tests
16 changes: 15 additions & 1 deletion superagi/llms/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,18 @@
return {"response": completion, "content": completion.result}
except Exception as exception:
logger.info("Google palm Exception:", exception)
return {"error": "ERROR_GOOGLE_PALM", "message": "Google palm exception"}
return {"error": exception}

Check warning on line 78 in superagi/llms/google_palm.py

View check run for this annotation

Codecov / codecov/patch

superagi/llms/google_palm.py#L78

Added line #L78 was not covered by tests

def verify_access_key(self):
    """
    Verify the access key is valid.

    Returns:
        bool: True if the access key is valid, False otherwise.
    """
    # A key is considered valid iff listing the available models succeeds;
    # any failure (auth error, network error) is treated as an invalid key.
    try:
        palm.list_models()
    except Exception as exception:
        logger.info("Google palm Exception:", exception)
        return False
    return True
17 changes: 16 additions & 1 deletion superagi/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,19 @@
return {"error": "ERROR_INVALID_REQUEST", "message": "Openai invalid request error.."}
except Exception as exception:
logger.info("OpenAi Exception:", exception)
return {"error": "ERROR_OPENAI", "message": "Open ai exception"}
return {"error": exception}

Check warning on line 89 in superagi/llms/openai.py

View check run for this annotation

Codecov / codecov/patch

superagi/llms/openai.py#L89

Added line #L89 was not covered by tests

def verify_access_key(self):
    """
    Verify the access key is valid.

    Returns:
        bool: True if the access key is valid, False otherwise.
    """
    # A key is considered valid iff listing the available models succeeds;
    # any failure (auth error, network error) is treated as an invalid key.
    try:
        openai.Model.list()
    except Exception as exception:
        logger.info("OpenAi Exception:", exception)
        return False
    return True

6 changes: 6 additions & 0 deletions superagi/models/types/validate_llm_api_key_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class ValidateAPIKeyRequest(BaseModel):
    """Request body for POST /validate-llm-api-key.

    NOTE(review): field names prefixed with ``model_`` collide with pydantic
    v2's protected ``model_`` namespace; this is fine on pydantic v1 — confirm
    before any pydantic upgrade.
    """
    # LLM provider identifier, e.g. "OpenAi" or "Google Palm".
    model_source: str
    # The provider API key to validate.
    model_api_key: str
8 changes: 8 additions & 0 deletions tests/unit_tests/llms/test_google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,11 @@ def test_chat_completion(mock_palm):
top_p=palm_instance.top_p,
max_output_tokens=int(max_tokens)
)


from unittest.mock import patch


@patch('superagi.llms.google_palm.palm')
def test_verify_access_key(mock_palm):
    """verify_access_key returns False when the PaLM API rejects the key.

    Defect fix: the original test performed a real network call to Google
    PaLM with a fake key, making the unit test slow and flaky. The module's
    ``palm`` client is mocked so the failure path is deterministic.
    """
    model = 'models/text-bison-001'
    api_key = 'test_key'
    palm_instance = GooglePalm(api_key, model=model)
    # Simulate the provider rejecting the key.
    mock_palm.list_models.side_effect = Exception("Invalid API key")
    result = palm_instance.verify_access_key()
    assert result is False
9 changes: 9 additions & 0 deletions tests/unit_tests/llms/test_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch
from superagi.llms.openai import OpenAi


@patch('superagi.llms.openai.openai')
def test_chat_completion(mock_openai):
# Arrange
Expand Down Expand Up @@ -30,3 +31,11 @@ def test_chat_completion(mock_openai):
frequency_penalty=openai_instance.frequency_penalty,
presence_penalty=openai_instance.presence_penalty
)


@patch('superagi.llms.openai.openai')
def test_verify_access_key(mock_openai):
    """verify_access_key returns False when the OpenAI API rejects the key.

    Defect fix: the original test performed a real network call to OpenAI
    with a fake key, making the unit test slow and flaky. The module's
    ``openai`` client is mocked so the failure path is deterministic.
    """
    model = 'gpt-4'
    api_key = 'test_key'
    openai_instance = OpenAi(api_key, model=model)
    # Simulate the provider rejecting the key.
    mock_openai.Model.list.side_effect = Exception("Invalid API key")
    result = openai_instance.verify_access_key()
    assert result is False