From f772a380e80a34fad67ce0e67a8880444dac448f Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 3 Oct 2023 11:52:49 +0000
Subject: [PATCH 01/22] local_llms

---
 gui/pages/Content/Models/ModelForm.js     |  2 +-
 main.py                                   | 11 ++-
 requirements.txt                          |  1 +
 superagi/controllers/models_controller.py | 26 ++++++-
 superagi/helper/llm_loader.py             | 35 +++++++++
 superagi/jobs/agent_executor.py           |  3 +
 superagi/llms/grammar/json.gbnf           | 25 +++++++
 superagi/llms/llm_model_factory.py        |  6 ++
 superagi/llms/local_llm.py                | 91 +++++++++++++++++++++++
 superagi/models/models.py                 | 14 +++-
 superagi/models/models_config.py          |  9 +++
 superagi/types/model_source_types.py      |  1 +
 12 files changed, 218 insertions(+), 6 deletions(-)
 create mode 100644 superagi/helper/llm_loader.py
 create mode 100644 superagi/llms/grammar/json.gbnf
 create mode 100644 superagi/llms/local_llm.py

diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js
index 9671cfc10..9431e6f67 100644
--- a/gui/pages/Content/Models/ModelForm.js
+++ b/gui/pages/Content/Models/ModelForm.js
@@ -6,7 +6,7 @@ import {BeatLoader, ClipLoader} from "react-spinners";
 import {ToastContainer, toast} from 'react-toastify';

 export default function ModelForm({internalId, getModels, sendModelData}){
-  const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm'];
+  const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Custom LLM'];
   const [selectedModel, setSelectedModel] = useState('Select a Model');
   const [modelName, setModelName] = useState('');
   const [modelDescription, setModelDescription] = useState('');
diff --git a/main.py b/main.py
index 2e4b11852..55ae7040b 100644
--- a/main.py
+++ b/main.py
@@ -50,6 +50,7 @@ from superagi.llms.replicate import Replicate
 from superagi.llms.hugging_face import HuggingFace
 from superagi.models.agent_template import AgentTemplate
+from superagi.models.models_config import ModelsConfig
 from superagi.models.organisation import Organisation
 from superagi.models.types.login_request import LoginRequest
 from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
@@ -215,6 +216,13 @@ def register_toolkit_for_master_organisation():
             Organisation.id == marketplace_organisation_id).first()
         if marketplace_organisation is not None:
             register_marketplace_toolkits(session, marketplace_organisation)
+
+    def local_llm_model_config():
+        existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Custom LLM').first()
+        if existing_models_config is None:
+            models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Custom LLM', api_key="EMPTY")
+            session.add(models_config)
+            session.commit()

     IterationWorkflowSeed.build_single_step_agent(session)
     IterationWorkflowSeed.build_task_based_agents(session)
@@ -238,7 +246,8 @@ def register_toolkit_for_master_organisation():
     # AgentWorkflowSeed.doc_search_and_code(session)
     # AgentWorkflowSeed.build_research_email_workflow(session)
     replace_old_iteration_workflows(session)
-
+    local_llm_model_config()
+
     if env != "PROD":
         register_toolkit_for_all_organisation()
     else:
diff --git a/requirements.txt b/requirements.txt
index ab45bb1c7..9ebdd1a49 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -158,3 +158,4 @@ google-generativeai==0.1.0
 unstructured==0.8.1
 ai21==1.2.6
 typing-extensions==4.5.0
+llama_cpp_python==0.2.7
\ No newline at end of file
diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py
index 2521f9abd..b58b1baf0 100644
--- a/superagi/controllers/models_controller.py
+++ b/superagi/controllers/models_controller.py
@@ -2,6 +2,7 @@
 from superagi.helper.auth import check_auth, get_user_organisation
 from superagi.helper.models_helper import ModelsHelper
 from superagi.apm.call_log_helper import CallLogHelper
+from superagi.lib.logger import logger
 from superagi.models.models import Models
 from superagi.models.models_config import ModelsConfig
 from superagi.config.config import get_config
@@ -9,6 +10,7 @@
 from fastapi_sqlalchemy import db
 import logging
 from pydantic import BaseModel
+from superagi.helper.llm_loader import LLMLoader

 router = APIRouter()

@@ -26,6 +28,7 @@ class StoreModelRequest(BaseModel):
     token_limit: int
     type: str
     version: str
+    context_length: int

 class ModelName (BaseModel):
     model: str
@@ -69,7 +72,7 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod
 @router.post("/store_model", status_code=200)
 async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)):
     try:
-        return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version)
+        return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, request.context_length)
     except Exception as e:
         logging.error(f"Error storing the Model Details: {str(e)}")
         raise HTTPException(status_code=500, detail="Internal Server Error")
@@ -164,4 +167,23 @@ def get_models_details(page: int = 0):
         marketplace_models = Models.fetch_marketplace_list(page)
         marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id, ModelsTypes.MARKETPLACE.value)
-    return marketplace_models_with_install
\ No newline at end of file
+    return marketplace_models_with_install
+
+@router.get("/test_local_llm", status_code=200)
+def test_local_llm():
+    try:
+        llm_loader = LLMLoader()
+        llm_model = llm_loader.model
+        llm_grammar = llm_loader.grammar
+        if llm_model is None:
+            logger.error("Model not found.")
+            raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
+        if llm_grammar is None:
+            logger.error("Grammar not found.")
+            raise HTTPException(status_code=404, detail="")
+
+        return "Model loaded successfully."
+
+    except Exception as e:
+        logger.error(f"Error: {e}")
+        raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
\ No newline at end of file
diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py
new file mode 100644
index 000000000..c9ef7a973
--- /dev/null
+++ b/superagi/helper/llm_loader.py
@@ -0,0 +1,35 @@
+from llama_cpp import Llama
+from llama_cpp import LlamaGrammar
+from superagi.config.config import get_config
+from superagi.lib.logger import logger
+
+
+class LLMLoader:
+    _instance = None
+    _model = None
+    _grammar = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super(LLMLoader, cls).__new__(cls)
+        return cls._instance
+
+    @property
+    def model(self):
+        if self._model is None:
+            try:
+                self._model = Llama(
+                    model_path="/app/local_model_path", n_ctx=int(get_config("MAX_CONTEXT_LENGTH")))
+            except Exception as e:
+                logger.info(e)
+        return self._model
+
+    @property
+    def grammar(self):
+        if self._grammar is None:
+            try:
+                self._grammar = LlamaGrammar.from_file(
+                    "superagi/llms/grammar/json.gbnf")
+            except Exception as e:
+                logger.info(e)
+        return self._grammar
diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py
index b86be941c..25e0d0d12 100644
--- a/superagi/jobs/agent_executor.py
+++ b/superagi/jobs/agent_executor.py
@@ -1,6 +1,7 @@
 from datetime import datetime, timedelta

 from sqlalchemy.orm import sessionmaker
+from superagi.llms.local_llm import LocalLLM

 import superagi.worker
 from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
@@ -135,6 +136,8 @@ def get_embedding(cls, model_source, model_api_key):
             return HuggingFace(api_key=model_api_key)
         if "Replicate" in model_source:
             return Replicate(api_key=model_api_key)
+        if "Custom" in model_source:
+            return LocalLLM()
         return None

     def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id):
diff --git a/superagi/llms/grammar/json.gbnf b/superagi/llms/grammar/json.gbnf
new file mode 100644
index 000000000..a9537cdf9
--- /dev/null
+++ b/superagi/llms/grammar/json.gbnf
@@ -0,0 +1,25 @@
+root   ::= object
+value  ::= object | array | string | number | ("true" | "false" | "null") ws
+
+object ::=
+  "{" ws (
+            string ":" ws value
+    ("," ws string ":" ws value)*
+  )? "}" ws
+
+array  ::=
+  "[" ws (
+            value
+    ("," ws value)*
+  )? "]" ws
+
+string ::=
+  "\"" (
+    [^"\\] |
+    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+  )* "\"" ws
+
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= ([ \t\n] ws)?
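The json.gbnf above is the stock JSON grammar shipped with llama.cpp. LLMLoader hands it to llama-cpp-python so that token sampling is constrained to strings the grammar accepts, which is what lets SuperAGI parse every local-model reply as JSON. A minimal sketch of the flow the new /test_local_llm endpoint exercises, using only calls that appear in this patch (the n_ctx value here is a placeholder; the patch reads it from the MAX_CONTEXT_LENGTH config):

    from llama_cpp import Llama, LlamaGrammar

    # Same calls LLMLoader wraps: a GGUF model mounted at the container path
    # used by this patch, plus the JSON grammar file added above.
    llm = Llama(model_path="/app/local_model_path", n_ctx=4096)
    grammar = LlamaGrammar.from_file("superagi/llms/grammar/json.gbnf")

    # Passing grammar=... restricts decoding so the completion must be valid JSON.
    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are an AI assistant. Give response in a proper JSON format"},
            {"role": "user", "content": "Hi!"},
        ],
        grammar=grammar,
    )
    print(response["choices"][0]["message"]["content"])
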
diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py
index 251c64a71..bd6360fdc 100644
--- a/superagi/llms/llm_model_factory.py
+++ b/superagi/llms/llm_model_factory.py
@@ -1,4 +1,5 @@
 from superagi.llms.google_palm import GooglePalm
+from superagi.llms.local_llm import LocalLLM
 from superagi.llms.openai import OpenAi
 from superagi.llms.replicate import Replicate
 from superagi.llms.hugging_face import HuggingFace
@@ -33,6 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs):
     elif provider_name == 'Hugging Face':
         print("Provider is Hugging Face")
         return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs)
+    elif provider_name == 'Custom LLM':
+        print("Provider is Custom LLM")
+        return LocalLLM(model=model_instance.model_name)
     else:
         print('Unknown provider.')
@@ -45,5 +49,7 @@ def build_model_with_api_key(provider_name, api_key):
         return GooglePalm(api_key=api_key)
     elif provider_name.lower() == 'hugging face':
         return HuggingFace(api_key=api_key)
+    elif provider_name.lower() == 'custom llm':
+        return LocalLLM(api_key=api_key)
     else:
         print('Unknown provider.')
\ No newline at end of file
diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py
new file mode 100644
index 000000000..874d4fbca
--- /dev/null
+++ b/superagi/llms/local_llm.py
@@ -0,0 +1,91 @@
+from superagi.config.config import get_config
+from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.helper.llm_loader import LLMLoader
+
+
+class LocalLLM(BaseLlm):
+    def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1,
+                 frequency_penalty=0,
+                 presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY'):
+        """
+        Args:
+            model (str): The model.
+            temperature (float): The temperature.
+            max_tokens (int): The maximum number of tokens.
+            top_p (float): The top p.
+            frequency_penalty (float): The frequency penalty.
+            presence_penalty (float): The presence penalty.
+            number_of_results (int): The number of results.
+        """
+        self.model = model
+        self.api_key = api_key
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.frequency_penalty = frequency_penalty
+        self.presence_penalty = presence_penalty
+        self.number_of_results = number_of_results
+
+        llm_loader = LLMLoader()
+        self.llm_model = llm_loader.model
+        self.llm_grammar = llm_loader.grammar
+
+    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
+        """
+        Call the chat completion.
+
+        Args:
+            messages (list): The messages.
+            max_tokens (int): The maximum number of tokens.
+
+        Returns:
+            dict: The response.
+        """
+        try:
+            if self.llm_model is None or self.llm_grammar is None:
+                logger.error("Model not found.")
+                return {"error": "Model loading error", "message": "Model not found. Please check your model path and try again."}
+            else:
+                response = self.llm_model.create_chat_completion(messages=messages, functions=None, function_call=None, temperature=self.temperature, top_p=self.top_p,
+                                                                 max_tokens=int(max_tokens), presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, grammar=self.llm_grammar)
+                content = response["choices"][0]["message"]["content"]
+                logger.info(content)
+                return {"response": response, "content": content}
+
+        except Exception as exception:
+            logger.error(f"Exception: {exception}")
+            return {"error": "ERROR", "message": "Error: " + str(exception)}
+
+    def get_source(self):
+        """
+        Get the source.
+
+        Returns:
+            str: The source.
+        """
+        return "Local LLM"
+
+    def get_api_key(self):
+        """
+        Returns:
+            str: The API key.
+        """
+        return self.api_key
+
+    def get_model(self):
+        """
+        Returns:
+            str: The model.
+        """
+        return self.model
+
+    def get_models(self):
+        """
+        Returns:
+            list: The models.
+        """
+        return self.model
+
+    def verify_access_key(self, api_key):
+        return True
diff --git a/superagi/models/models.py b/superagi/models/models.py
index 5a58b74d6..fcdba4769 100644
--- a/superagi/models/models.py
+++ b/superagi/models/models.py
@@ -1,3 +1,4 @@
+import yaml
 from sqlalchemy import Column, Integer, String, and_
 from sqlalchemy.sql import func
 from typing import List, Dict, Union
@@ -103,7 +104,7 @@ def fetch_model_tokens(cls, session, organisation_id) -> Dict[str, int]:
             return {"error": "Unexpected Error Occured"}

     @classmethod
-    def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version):
+    def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version, context_length):
         from superagi.models.models_config import ModelsConfig
         if not model_name:
             return {"error": "Model Name is empty or undefined"}
@@ -129,9 +130,17 @@ def store_model_details(cls, session, organisation_id, model_name, description,
             return model  # Return error message if model not found

         # Check the 'provider' from ModelsConfig table
-        if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate']:
+        if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Custom LLM']:
             return {"error": "End Point is empty or undefined"}

+        if context_length is not None:
+            with open('config.yaml', 'r') as file:
+                config_data = yaml.safe_load(file)
+            if 'MAX_CONTEXT_LENGTH' in config_data:
+                config_data['MAX_CONTEXT_LENGTH'] = context_length
+            with open('config.yaml', 'w') as file:
+                yaml.safe_dump(config_data, file)
+
         try:
             model = Models(
                 model_name=model_name,
@@ -229,3 +238,4 @@ def fetch_model_details(cls, session, organisation_id, model_id: int) -> Dict[st
     except Exception as e:
         logging.error(f"Unexpected Error Occured: {e}")
         return {"error": "Unexpected Error Occured"}
+
diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py
index 0c8c13b95..ba6edfb82 100644
--- a/superagi/models/models_config.py
+++ b/superagi/models/models_config.py
@@ -1,4 +1,5 @@
 from sqlalchemy import Column, Integer, String, and_, distinct
+from superagi.lib.logger import logger
 from superagi.models.base_model import DBBaseModel
 from superagi.models.organisation import Organisation
 from superagi.models.project import Project
@@ -69,6 +70,9 @@ def fetch_value_by_agent_id(cls, session, agent_id: int, model: str):
         if not config:
             return None

+        if config.provider == 'Custom LLM':
return {"provider": config.provider, "api_key": config.api_key} if config else None + return {"provider": config.provider, "api_key": decrypt_data(config.api_key)} if config else None @classmethod @@ -123,8 +127,13 @@ def fetch_api_key(cls, session, organisation_id, model_provider): api_key_data = session.query(ModelsConfig.id, ModelsConfig.provider, ModelsConfig.api_key).filter( and_(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == model_provider)).first() + logger.info(api_key_data) if api_key_data is None: return [] + elif api_key_data.provider == 'Custom LLM': + api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, + 'api_key': api_key_data.api_key}] + return api_key else: api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, 'api_key': decrypt_data(api_key_data.api_key)}] diff --git a/superagi/types/model_source_types.py b/superagi/types/model_source_types.py index f811a60c6..648c54c37 100644 --- a/superagi/types/model_source_types.py +++ b/superagi/types/model_source_types.py @@ -6,6 +6,7 @@ class ModelSourceType(Enum): OpenAI = 'OpenAi' Replicate = 'Replicate' HuggingFace = 'Hugging Face' + CustomLLM = 'Custom LLM' @classmethod def get_model_source_type(cls, name): From 427a04f78e65b8fd8dfc690f3e9c2d4e779bfea2 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Tue, 3 Oct 2023 18:00:10 +0530 Subject: [PATCH 02/22] local_llms --- tests/unit_tests/controllers/test_models_controller.py | 3 ++- tests/unit_tests/models/test_models.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index 489cff636..dd83a57bf 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -50,7 +50,8 @@ def test_store_model_success(mock_get_db): "model_provider_id": 1, "token_limit": 10, "type": "mock_type", - "version": "mock_version" + "version": "mock_version", + "context_length":4096 } with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \ patch('superagi.helper.auth.db') as mock_auth_db: diff --git a/tests/unit_tests/models/test_models.py b/tests/unit_tests/models/test_models.py index 3bdc43075..acd13eefe 100644 --- a/tests/unit_tests/models/test_models.py +++ b/tests/unit_tests/models/test_models.py @@ -133,6 +133,7 @@ def test_store_model_details_when_model_exists(mock_session): token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -161,6 +162,7 @@ def test_store_model_details_when_model_not_exists(mock_session, monkeypatch): token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -187,6 +189,7 @@ def test_store_model_details_when_unexpected_error_occurs(mock_session, monkeypa token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert From 874635cc67dce908721f5c562311e1b4d674875c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 4 Oct 2023 05:49:35 +0000 Subject: [PATCH 03/22] local_llms --- superagi/controllers/models_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index b58b1baf0..f11d22a56 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -180,7 +180,7 @@ def test_local_llm(): raise HTTPException(status_code=404, detail="Error while loading the model. 
Please check your model path and try again.") if llm_grammar is None: logger.error("Grammar not found.") - raise HTTPException(status_code=404, detail="") + raise HTTPException(status_code=404, detail="Grammar not found.") return "Model loaded successfully." From d931ac1b3062b32d3bc13380e069027ab9bed8d4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 4 Oct 2023 06:16:52 +0000 Subject: [PATCH 04/22] local_llms --- superagi/controllers/models_controller.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index f11d22a56..1971e44d5 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -181,7 +181,16 @@ def test_local_llm(): if llm_grammar is None: logger.error("Grammar not found.") raise HTTPException(status_code=404, detail="Grammar not found.") - + + messages = [ + {"role":"system", + "content":"You are an AI assistant. Give response in a proper JSON format"}, + {"role":"user", + "content":"Hi!"} + ] + response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar) + content = response["choices"][0]["message"]["content"] + logger.info(content) return "Model loaded successfully." except Exception as e: From e746c1e23f6d0dbb273566ee6aa3f1ec4e17b35f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 4 Oct 2023 10:09:02 +0000 Subject: [PATCH 05/22] local_llms --- docker-compose.yaml | 2 + gui/pages/Content/Models/ModelForm.js | 11 ++- gui/pages/api/DashboardService.js | 4 +- main.py | 4 +- .../versions/9270eb5a8475_local_llms.py | 84 +++++++++++++++++++ superagi/controllers/models_controller.py | 4 +- superagi/helper/llm_loader.py | 11 ++- superagi/llms/llm_model_factory.py | 8 +- superagi/llms/local_llm.py | 5 +- superagi/models/models.py | 16 ++-- superagi/models/models_config.py | 4 +- superagi/types/model_source_types.py | 2 +- 12 files changed, 126 insertions(+), 29 deletions(-) create mode 100644 migrations/versions/9270eb5a8475_local_llms.py diff --git a/docker-compose.yaml b/docker-compose.yaml index 94044916b..7f37296fd 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -3,6 +3,7 @@ services: backend: volumes: - "./:/app" + - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . depends_on: - super__redis @@ -14,6 +15,7 @@ services: volumes: - "./:/app" - "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext" + - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . 
depends_on: - super__redis diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index 9431e6f67..0e790e38f 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -6,7 +6,7 @@ import {BeatLoader, ClipLoader} from "react-spinners"; import {ToastContainer, toast} from 'react-toastify'; export default function ModelForm({internalId, getModels, sendModelData}){ - const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Custom LLM']; + const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM']; const [selectedModel, setSelectedModel] = useState('Select a Model'); const [modelName, setModelName] = useState(''); const [modelDescription, setModelDescription] = useState(''); @@ -14,6 +14,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [modelEndpoint, setModelEndpoint] = useState(''); const [modelDropdown, setModelDropdown] = useState(false); const [modelVersion, setModelVersion] = useState(''); + const [modelContextLength, setContextLength] = useState(4096); const [tokenError, setTokenError] = useState(false); const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) @@ -87,7 +88,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ } const storeModelDetails = (modelProviderId) => { - storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{ + storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{ setIsLoading(false) let data = response.data if (data.error) { @@ -155,6 +156,12 @@ export default function ModelForm({internalId, getModels, sendModelData}){ onChange={(event) => setModelVersion(event.target.value)}/> } + {(selectedModel === 'Local LLM') &&
+ Model Context Length + setContextLength(event.target.value)}/> +
} +
Token Limit { }); } -export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version) => { - return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version}); +export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => { + return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length}); } export const fetchModels = () => { diff --git a/main.py b/main.py index 55ae7040b..cf4807483 100644 --- a/main.py +++ b/main.py @@ -218,9 +218,9 @@ def register_toolkit_for_master_organisation(): register_marketplace_toolkits(session, marketplace_organisation) def local_llm_model_config(): - existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Custom LLM').first() + existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Local LLM').first() if existing_models_config is None: - models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Custom LLM', api_key="EMPTY") + models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Local LLM', api_key="EMPTY") session.add(models_config) session.commit() diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py new file mode 100644 index 000000000..865561f0b --- /dev/null +++ b/migrations/versions/9270eb5a8475_local_llms.py @@ -0,0 +1,84 @@ +"""local_llms + +Revision ID: 9270eb5a8475 +Revises: 3867bb00a495 +Create Date: 2023-10-04 09:26:33.865424 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9270eb5a8475' +down_revision = '3867bb00a495' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_agent_schedule_agent_id', table_name='agent_schedule') + op.drop_index('ix_agent_schedule_expiry_date', table_name='agent_schedule') + op.drop_index('ix_agent_schedule_status', table_name='agent_schedule') + op.alter_column('agent_workflow_steps', 'unique_id', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('agent_workflow_steps', 'step_type', + existing_type=sa.VARCHAR(), + nullable=True) + op.drop_column('agent_workflows', 'organisation_id') + op.drop_index('ix_events_agent_id', table_name='events') + op.drop_index('ix_events_event_property', table_name='events') + op.drop_index('ix_events_org_id', table_name='events') + op.alter_column('knowledge_configs', 'knowledge_id', + existing_type=sa.INTEGER(), + nullable=True) + op.alter_column('knowledges', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True)) + op.alter_column('vector_db_configs', 'vector_db_id', + existing_type=sa.INTEGER(), + nullable=True) + op.alter_column('vector_db_indices', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('vector_dbs', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column('vector_dbs', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('vector_db_indices', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('vector_db_configs', 'vector_db_id', + existing_type=sa.INTEGER(), + nullable=False) + op.drop_column('models', 'context_length') + op.alter_column('knowledges', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('knowledge_configs', 'knowledge_id', + existing_type=sa.INTEGER(), + nullable=False) + op.create_index('ix_events_org_id', 'events', ['org_id'], unique=False) + op.create_index('ix_events_event_property', 'events', ['event_property'], unique=False) + op.create_index('ix_events_agent_id', 'events', ['agent_id'], unique=False) + op.add_column('agent_workflows', sa.Column('organisation_id', sa.INTEGER(), autoincrement=False, nullable=True)) + op.alter_column('agent_workflow_steps', 'step_type', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('agent_workflow_steps', 'unique_id', + existing_type=sa.VARCHAR(), + nullable=False) + op.create_index('ix_agent_schedule_status', 'agent_schedule', ['status'], unique=False) + op.create_index('ix_agent_schedule_expiry_date', 'agent_schedule', ['expiry_date'], unique=False) + op.create_index('ix_agent_schedule_agent_id', 'agent_schedule', ['agent_id'], unique=False) + # ### end Alembic commands ### diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index 1971e44d5..40f202461 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -72,6 +72,8 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod @router.post("/store_model", status_code=200) async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)): try: + #context_length = 4096 + logger.info(request) return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, request.context_length) except Exception as e: logging.error(f"Error storing the Model Details: {str(e)}") @@ -172,7 +174,7 @@ def get_models_details(page: int = 0): @router.get("/test_local_llm", status_code=200) def test_local_llm(): try: - llm_loader = LLMLoader() + llm_loader = LLMLoader(context_length=4096) llm_model = llm_loader.model llm_grammar = llm_loader.grammar if llm_model is None: diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py index c9ef7a973..8c2b19e45 100644 --- a/superagi/helper/llm_loader.py +++ b/superagi/helper/llm_loader.py @@ -9,19 +9,22 @@ class LLMLoader: _model = None _grammar = None - def __new__(cls): + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(LLMLoader, cls).__new__(cls) return cls._instance + def __init__(self, context_length): + self.context_length = context_length + @property def model(self): if self._model is None: try: self._model = Llama( - model_path="/app/local_model_path", n_ctx=int(get_config("MAX_CONTEXT_LENGTH"))) + model_path="/app/local_model_path", n_ctx=self.context_length) except Exception as e: - logger.info(e) + logger.error(e) return self._model @property @@ -31,5 +34,5 @@ def grammar(self): self._grammar = LlamaGrammar.from_file( "superagi/llms/grammar/json.gbnf") except Exception as e: - logger.info(e) + logger.error(e) return self._grammar diff --git 
a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py index bd6360fdc..345c4f8c7 100644 --- a/superagi/llms/llm_model_factory.py +++ b/superagi/llms/llm_model_factory.py @@ -34,9 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs): elif provider_name == 'Hugging Face': print("Provider is Hugging Face") return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs) - elif provider_name == 'Custom LLM': - print("Provider is Custom LLM") - return LocalLLM(model=model_instance.model_name) + elif provider_name == 'Local LLM': + print("Provider is Local LLM") + return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length) else: print('Unknown provider.') @@ -49,7 +49,7 @@ def build_model_with_api_key(provider_name, api_key): return GooglePalm(api_key=api_key) elif provider_name.lower() == 'hugging face': return HuggingFace(api_key=api_key) - elif provider_name.lower() == 'custom llm': + elif provider_name.lower() == 'local llm': return LocalLLM(api_key=api_key) else: print('Unknown provider.') \ No newline at end of file diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py index 874d4fbca..608afa289 100644 --- a/superagi/llms/local_llm.py +++ b/superagi/llms/local_llm.py @@ -7,7 +7,7 @@ class LocalLLM(BaseLlm): def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1, frequency_penalty=0, - presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY'): + presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY', context_length=4096): """ Args: model (str): The model. @@ -26,8 +26,9 @@ def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty self.number_of_results = number_of_results + self.context_length = context_length - llm_loader = LLMLoader() + llm_loader = LLMLoader(self.context_length) self.llm_model = llm_loader.model self.llm_grammar = llm_loader.grammar diff --git a/superagi/models/models.py b/superagi/models/models.py index fcdba4769..559d6d677 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -6,6 +6,7 @@ from superagi.controllers.types.models_types import ModelsTypes from superagi.helper.encyption_helper import decrypt_data import requests, logging +from superagi.lib.logger import logger marketplace_url = "https://app.superagi.com/api" # marketplace_url = "http://localhost:8001" @@ -40,6 +41,7 @@ class Models(DBBaseModel): version = Column(String, nullable=False) org_id = Column(Integer, nullable=False) model_features = Column(String, nullable=False) + context_length = Column(Integer, nullable=True) def __repr__(self): """ @@ -130,16 +132,11 @@ def store_model_details(cls, session, organisation_id, model_name, description, return model # Return error message if model not found # Check the 'provider' from ModelsConfig table - if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Custom LLM']: + if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Local LLM']: return {"error": "End Point is empty or undefined"} - if context_length is not None: - with open('config.yaml', 'r') as file: - config_data = yaml.safe_load(file) - if 'MAX_CONTEXT_LENGTH' in config_data: - config_data['MAX_CONTEXT_LENGTH'] = context_length - with open('config.yaml', 'w') as file: - yaml.safe_dump(config_data, file) + 
if context_length is None: + context_length = 0 try: model = Models( @@ -151,7 +148,8 @@ def store_model_details(cls, session, organisation_id, model_name, description, type=type, version=version, org_id=organisation_id, - model_features='' + model_features='', + context_length=context_length ) session.add(model) session.commit() diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py index ba6edfb82..1c2091489 100644 --- a/superagi/models/models_config.py +++ b/superagi/models/models_config.py @@ -70,7 +70,7 @@ def fetch_value_by_agent_id(cls, session, agent_id: int, model: str): if not config: return None - if config.provider == 'Custom LLM': + if config.provider == 'Local LLM': return {"provider": config.provider, "api_key": config.api_key} if config else None return {"provider": config.provider, "api_key": decrypt_data(config.api_key)} if config else None @@ -130,7 +130,7 @@ def fetch_api_key(cls, session, organisation_id, model_provider): logger.info(api_key_data) if api_key_data is None: return [] - elif api_key_data.provider == 'Custom LLM': + elif api_key_data.provider == 'Local LLM': api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, 'api_key': api_key_data.api_key}] return api_key diff --git a/superagi/types/model_source_types.py b/superagi/types/model_source_types.py index 648c54c37..6e9de18ad 100644 --- a/superagi/types/model_source_types.py +++ b/superagi/types/model_source_types.py @@ -6,7 +6,7 @@ class ModelSourceType(Enum): OpenAI = 'OpenAi' Replicate = 'Replicate' HuggingFace = 'Hugging Face' - CustomLLM = 'Custom LLM' + LocalLLM = 'Local LLM' @classmethod def get_model_source_type(cls, name): From 2cdd551057e8a9faaa95860b63ddc57ff97bdbe2 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 4 Oct 2023 15:44:01 +0530 Subject: [PATCH 06/22] fixes --- docker-compose.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 7f37296fd..94044916b 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -3,7 +3,6 @@ services: backend: volumes: - "./:/app" - - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . depends_on: - super__redis @@ -15,7 +14,6 @@ services: volumes: - "./:/app" - "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext" - - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . 
depends_on: - super__redis From f8d60842b604a3e0431d57c13371d38553df12cf Mon Sep 17 00:00:00 2001 From: namansleeps <122260931+namansleeps@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:46:30 +0530 Subject: [PATCH 07/22] models error fixed (#1308) --- gui/pages/Content/Models/ModelForm.js | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index 9671cfc10..d8b248c56 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -1,5 +1,5 @@ import React, {useEffect, useRef, useState} from "react"; -import {removeTab, openNewTab, createInternalId, modelGetAuth, getUserClick} from "@/utils/utils"; +import {removeTab, openNewTab, createInternalId, getUserClick} from "@/utils/utils"; import Image from "next/image"; import {fetchApiKey, storeModel, verifyEndPoint} from "@/pages/api/DashboardService"; import {BeatLoader, ClipLoader} from "react-spinners"; @@ -66,11 +66,9 @@ export default function ModelForm({internalId, getModels, sendModelData}){ { const modelProviderId = response.data[0].id verifyEndPoint(response.data[0].api_key, modelEndpoint, selectedModel).then((response) =>{ - if(response.data.success) { + if(response.data.success) storeModelDetails(modelProviderId) - getUserClick("Model Added Successfully",{'type': selectedModel}) - } - else { + else{ toast.error("The Endpoint is not Valid",{autoClose: 1800}); setIsLoading(false); } From 9e7e6861e8f274644ad9b5cc324ac787620ea605 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 4 Oct 2023 10:18:32 +0000 Subject: [PATCH 08/22] local_llms --- docker-compose.yaml | 2 + gui/pages/Content/Models/ModelForm.js | 11 ++- gui/pages/api/DashboardService.js | 4 +- main.py | 4 +- .../versions/9270eb5a8475_local_llms.py | 84 +++++++++++++++++++ superagi/controllers/models_controller.py | 4 +- superagi/helper/llm_loader.py | 11 ++- superagi/llms/llm_model_factory.py | 8 +- superagi/llms/local_llm.py | 5 +- superagi/models/models.py | 16 ++-- superagi/models/models_config.py | 4 +- superagi/types/model_source_types.py | 2 +- 12 files changed, 126 insertions(+), 29 deletions(-) create mode 100644 migrations/versions/9270eb5a8475_local_llms.py diff --git a/docker-compose.yaml b/docker-compose.yaml index 94044916b..7f37296fd 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -3,6 +3,7 @@ services: backend: volumes: - "./:/app" + - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . depends_on: - super__redis @@ -14,6 +15,7 @@ services: volumes: - "./:/app" - "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext" + - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . 
depends_on: - super__redis diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index 9431e6f67..0e790e38f 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -6,7 +6,7 @@ import {BeatLoader, ClipLoader} from "react-spinners"; import {ToastContainer, toast} from 'react-toastify'; export default function ModelForm({internalId, getModels, sendModelData}){ - const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Custom LLM']; + const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM']; const [selectedModel, setSelectedModel] = useState('Select a Model'); const [modelName, setModelName] = useState(''); const [modelDescription, setModelDescription] = useState(''); @@ -14,6 +14,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [modelEndpoint, setModelEndpoint] = useState(''); const [modelDropdown, setModelDropdown] = useState(false); const [modelVersion, setModelVersion] = useState(''); + const [modelContextLength, setContextLength] = useState(4096); const [tokenError, setTokenError] = useState(false); const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) @@ -87,7 +88,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ } const storeModelDetails = (modelProviderId) => { - storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{ + storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{ setIsLoading(false) let data = response.data if (data.error) { @@ -155,6 +156,12 @@ export default function ModelForm({internalId, getModels, sendModelData}){ onChange={(event) => setModelVersion(event.target.value)}/>
} + {(selectedModel === 'Local LLM') &&
+ Model Context Length + setContextLength(event.target.value)}/> +
} +
Token Limit { }); } -export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version) => { - return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version}); +export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => { + return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length}); } export const fetchModels = () => { diff --git a/main.py b/main.py index 55ae7040b..cf4807483 100644 --- a/main.py +++ b/main.py @@ -218,9 +218,9 @@ def register_toolkit_for_master_organisation(): register_marketplace_toolkits(session, marketplace_organisation) def local_llm_model_config(): - existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Custom LLM').first() + existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Local LLM').first() if existing_models_config is None: - models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Custom LLM', api_key="EMPTY") + models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Local LLM', api_key="EMPTY") session.add(models_config) session.commit() diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py new file mode 100644 index 000000000..865561f0b --- /dev/null +++ b/migrations/versions/9270eb5a8475_local_llms.py @@ -0,0 +1,84 @@ +"""local_llms + +Revision ID: 9270eb5a8475 +Revises: 3867bb00a495 +Create Date: 2023-10-04 09:26:33.865424 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9270eb5a8475' +down_revision = '3867bb00a495' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_agent_schedule_agent_id', table_name='agent_schedule') + op.drop_index('ix_agent_schedule_expiry_date', table_name='agent_schedule') + op.drop_index('ix_agent_schedule_status', table_name='agent_schedule') + op.alter_column('agent_workflow_steps', 'unique_id', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('agent_workflow_steps', 'step_type', + existing_type=sa.VARCHAR(), + nullable=True) + op.drop_column('agent_workflows', 'organisation_id') + op.drop_index('ix_events_agent_id', table_name='events') + op.drop_index('ix_events_event_property', table_name='events') + op.drop_index('ix_events_org_id', table_name='events') + op.alter_column('knowledge_configs', 'knowledge_id', + existing_type=sa.INTEGER(), + nullable=True) + op.alter_column('knowledges', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True)) + op.alter_column('vector_db_configs', 'vector_db_id', + existing_type=sa.INTEGER(), + nullable=True) + op.alter_column('vector_db_indices', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('vector_dbs', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column('vector_dbs', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('vector_db_indices', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('vector_db_configs', 'vector_db_id', + existing_type=sa.INTEGER(), + nullable=False) + op.drop_column('models', 'context_length') + op.alter_column('knowledges', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('knowledge_configs', 'knowledge_id', + existing_type=sa.INTEGER(), + nullable=False) + op.create_index('ix_events_org_id', 'events', ['org_id'], unique=False) + op.create_index('ix_events_event_property', 'events', ['event_property'], unique=False) + op.create_index('ix_events_agent_id', 'events', ['agent_id'], unique=False) + op.add_column('agent_workflows', sa.Column('organisation_id', sa.INTEGER(), autoincrement=False, nullable=True)) + op.alter_column('agent_workflow_steps', 'step_type', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('agent_workflow_steps', 'unique_id', + existing_type=sa.VARCHAR(), + nullable=False) + op.create_index('ix_agent_schedule_status', 'agent_schedule', ['status'], unique=False) + op.create_index('ix_agent_schedule_expiry_date', 'agent_schedule', ['expiry_date'], unique=False) + op.create_index('ix_agent_schedule_agent_id', 'agent_schedule', ['agent_id'], unique=False) + # ### end Alembic commands ### diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index 1971e44d5..40f202461 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -72,6 +72,8 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod @router.post("/store_model", status_code=200) async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)): try: + #context_length = 4096 + logger.info(request) return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, request.context_length) except Exception as e: logging.error(f"Error storing the Model Details: {str(e)}") @@ -172,7 +174,7 @@ def get_models_details(page: int = 0): @router.get("/test_local_llm", status_code=200) def test_local_llm(): try: - llm_loader = LLMLoader() + llm_loader = LLMLoader(context_length=4096) llm_model = llm_loader.model llm_grammar = llm_loader.grammar if llm_model is None: diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py index c9ef7a973..8c2b19e45 100644 --- a/superagi/helper/llm_loader.py +++ b/superagi/helper/llm_loader.py @@ -9,19 +9,22 @@ class LLMLoader: _model = None _grammar = None - def __new__(cls): + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(LLMLoader, cls).__new__(cls) return cls._instance + def __init__(self, context_length): + self.context_length = context_length + @property def model(self): if self._model is None: try: self._model = Llama( - model_path="/app/local_model_path", n_ctx=int(get_config("MAX_CONTEXT_LENGTH"))) + model_path="/app/local_model_path", n_ctx=self.context_length) except Exception as e: - logger.info(e) + logger.error(e) return self._model @property @@ -31,5 +34,5 @@ def grammar(self): self._grammar = LlamaGrammar.from_file( "superagi/llms/grammar/json.gbnf") except Exception as e: - logger.info(e) + logger.error(e) return self._grammar diff --git 
a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py index bd6360fdc..345c4f8c7 100644 --- a/superagi/llms/llm_model_factory.py +++ b/superagi/llms/llm_model_factory.py @@ -34,9 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs): elif provider_name == 'Hugging Face': print("Provider is Hugging Face") return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs) - elif provider_name == 'Custom LLM': - print("Provider is Custom LLM") - return LocalLLM(model=model_instance.model_name) + elif provider_name == 'Local LLM': + print("Provider is Local LLM") + return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length) else: print('Unknown provider.') @@ -49,7 +49,7 @@ def build_model_with_api_key(provider_name, api_key): return GooglePalm(api_key=api_key) elif provider_name.lower() == 'hugging face': return HuggingFace(api_key=api_key) - elif provider_name.lower() == 'custom llm': + elif provider_name.lower() == 'local llm': return LocalLLM(api_key=api_key) else: print('Unknown provider.') \ No newline at end of file diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py index 874d4fbca..608afa289 100644 --- a/superagi/llms/local_llm.py +++ b/superagi/llms/local_llm.py @@ -7,7 +7,7 @@ class LocalLLM(BaseLlm): def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1, frequency_penalty=0, - presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY'): + presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY', context_length=4096): """ Args: model (str): The model. @@ -26,8 +26,9 @@ def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty self.number_of_results = number_of_results + self.context_length = context_length - llm_loader = LLMLoader() + llm_loader = LLMLoader(self.context_length) self.llm_model = llm_loader.model self.llm_grammar = llm_loader.grammar diff --git a/superagi/models/models.py b/superagi/models/models.py index fcdba4769..559d6d677 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -6,6 +6,7 @@ from superagi.controllers.types.models_types import ModelsTypes from superagi.helper.encyption_helper import decrypt_data import requests, logging +from superagi.lib.logger import logger marketplace_url = "https://app.superagi.com/api" # marketplace_url = "http://localhost:8001" @@ -40,6 +41,7 @@ class Models(DBBaseModel): version = Column(String, nullable=False) org_id = Column(Integer, nullable=False) model_features = Column(String, nullable=False) + context_length = Column(Integer, nullable=True) def __repr__(self): """ @@ -130,16 +132,11 @@ def store_model_details(cls, session, organisation_id, model_name, description, return model # Return error message if model not found # Check the 'provider' from ModelsConfig table - if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Custom LLM']: + if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Local LLM']: return {"error": "End Point is empty or undefined"} - if context_length is not None: - with open('config.yaml', 'r') as file: - config_data = yaml.safe_load(file) - if 'MAX_CONTEXT_LENGTH' in config_data: - config_data['MAX_CONTEXT_LENGTH'] = context_length - with open('config.yaml', 'w') as file: - yaml.safe_dump(config_data, file) + 
if context_length is None: + context_length = 0 try: model = Models( @@ -151,7 +148,8 @@ def store_model_details(cls, session, organisation_id, model_name, description, type=type, version=version, org_id=organisation_id, - model_features='' + model_features='', + context_length=context_length ) session.add(model) session.commit() diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py index ba6edfb82..1c2091489 100644 --- a/superagi/models/models_config.py +++ b/superagi/models/models_config.py @@ -70,7 +70,7 @@ def fetch_value_by_agent_id(cls, session, agent_id: int, model: str): if not config: return None - if config.provider == 'Custom LLM': + if config.provider == 'Local LLM': return {"provider": config.provider, "api_key": config.api_key} if config else None return {"provider": config.provider, "api_key": decrypt_data(config.api_key)} if config else None @@ -130,7 +130,7 @@ def fetch_api_key(cls, session, organisation_id, model_provider): logger.info(api_key_data) if api_key_data is None: return [] - elif api_key_data.provider == 'Custom LLM': + elif api_key_data.provider == 'Local LLM': api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, 'api_key': api_key_data.api_key}] return api_key diff --git a/superagi/types/model_source_types.py b/superagi/types/model_source_types.py index 648c54c37..6e9de18ad 100644 --- a/superagi/types/model_source_types.py +++ b/superagi/types/model_source_types.py @@ -6,7 +6,7 @@ class ModelSourceType(Enum): OpenAI = 'OpenAi' Replicate = 'Replicate' HuggingFace = 'Hugging Face' - CustomLLM = 'Custom LLM' + LocalLLM = 'Local LLM' @classmethod def get_model_source_type(cls, name): From b1ddffb4e0a0bc850738e7dce930a857968996dd Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 4 Oct 2023 10:30:39 +0000 Subject: [PATCH 09/22] local_llms --- .../controllers/test_models_controller.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index dd83a57bf..0f4b23b98 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -101,3 +101,16 @@ def test_get_marketplace_models_list_success(mock_get_db): patch('superagi.helper.auth.db') as mock_auth_db: response = client.get("/models_controller/marketplace/list/0") assert response.status_code == 200 + +@patch('superagi.helper.llm_loader.LLMLoader') +def test_get_llm(mocked_loader): + mocked_model = MagicMock() + mocked_grammar = MagicMock() + + instance = mocked_loader.return_value + instance.model = mocked_model + instance.grammar = mocked_grammar + + response = client.get("models_controller/test_local_llm") + + assert response.status_code == 200, "Unexpected response code" From 43732f70577e81d21f07241efc2095508c5ac04b Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 4 Oct 2023 16:45:54 +0530 Subject: [PATCH 10/22] local_llms --- .../controllers/test_models_controller.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index 0f4b23b98..82123ba89 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -2,6 +2,11 @@ import pytest from fastapi.testclient import TestClient from main import app +from llama_cpp import Llama +from llama_cpp import 
LlamaGrammar +import llama_cpp + +from superagi.helper.llm_loader import LLMLoader client = TestClient(app) @@ -102,15 +107,12 @@ def test_get_marketplace_models_list_success(mock_get_db): response = client.get("/models_controller/marketplace/list/0") assert response.status_code == 200 -@patch('superagi.helper.llm_loader.LLMLoader') -def test_get_llm(mocked_loader): - mocked_model = MagicMock() - mocked_grammar = MagicMock() - - instance = mocked_loader.return_value - instance.model = mocked_model - instance.grammar = mocked_grammar - - response = client.get("models_controller/test_local_llm") - - assert response.status_code == 200, "Unexpected response code" +def test_get_llm(): + with(patch.object(LLMLoader, 'model', new_callable=MagicMock)) as mock_model: + with(patch.object(LLMLoader, 'grammar', new_callable=MagicMock)) as mock_grammar: + + mock_model.create_chat_completion.return_value = {"choices": [{"message": {"content": "Hello!"}}]} + + response = client.get("/models_controller/test_local_llm") + + assert response.status_code == 200 \ No newline at end of file From 882c1972a1061057027df11b1235e42da629d6a9 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 4 Oct 2023 16:47:15 +0530 Subject: [PATCH 11/22] local_llms --- docker-compose.yaml | 2 -- tests/unit_tests/controllers/test_models_controller.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 7f37296fd..94044916b 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -3,7 +3,6 @@ services: backend: volumes: - "./:/app" - - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . depends_on: - super__redis @@ -15,7 +14,6 @@ services: volumes: - "./:/app" - "${EXTERNAL_RESOURCE_DIR:-./workspace}:/app/ext" - - "/home/ubuntu/local_models/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q8_0.gguf:/app/local_model_path" build: . depends_on: - super__redis diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index 82123ba89..790229789 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -107,7 +107,7 @@ def test_get_marketplace_models_list_success(mock_get_db): response = client.get("/models_controller/marketplace/list/0") assert response.status_code == 200 -def test_get_llm(): +def test_get_local_llm(): with(patch.object(LLMLoader, 'model', new_callable=MagicMock)) as mock_model: with(patch.object(LLMLoader, 'grammar', new_callable=MagicMock)) as mock_grammar: From ab1d96c4896921cdb16f67332215db2b9d55b768 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Thu, 5 Oct 2023 11:32:55 +0530 Subject: [PATCH 12/22] local_llms --- .../versions/9270eb5a8475_local_llms.py | 56 ------------------- 1 file changed, 56 deletions(-) diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py index 865561f0b..7e6371e8a 100644 --- a/migrations/versions/9270eb5a8475_local_llms.py +++ b/migrations/versions/9270eb5a8475_local_llms.py @@ -18,67 +18,11 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
From ab1d96c4896921cdb16f67332215db2b9d55b768 Mon Sep 17 00:00:00 2001
From: Rounak Bhatia
Date: Thu, 5 Oct 2023 11:32:55 +0530
Subject: [PATCH 12/22] local_llms

---
 .../versions/9270eb5a8475_local_llms.py | 56 -------------------
 1 file changed, 56 deletions(-)

diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py
index 865561f0b..7e6371e8a 100644
--- a/migrations/versions/9270eb5a8475_local_llms.py
+++ b/migrations/versions/9270eb5a8475_local_llms.py
@@ -18,67 +18,11 @@
 def upgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
-    op.drop_index('ix_agent_schedule_agent_id', table_name='agent_schedule')
-    op.drop_index('ix_agent_schedule_expiry_date', table_name='agent_schedule')
-    op.drop_index('ix_agent_schedule_status', table_name='agent_schedule')
-    op.alter_column('agent_workflow_steps', 'unique_id',
-               existing_type=sa.VARCHAR(),
-               nullable=True)
-    op.alter_column('agent_workflow_steps', 'step_type',
-               existing_type=sa.VARCHAR(),
-               nullable=True)
-    op.drop_column('agent_workflows', 'organisation_id')
-    op.drop_index('ix_events_agent_id', table_name='events')
-    op.drop_index('ix_events_event_property', table_name='events')
-    op.drop_index('ix_events_org_id', table_name='events')
-    op.alter_column('knowledge_configs', 'knowledge_id',
-               existing_type=sa.INTEGER(),
-               nullable=True)
-    op.alter_column('knowledges', 'name',
-               existing_type=sa.VARCHAR(),
-               nullable=True)
     op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True))
-    op.alter_column('vector_db_configs', 'vector_db_id',
-               existing_type=sa.INTEGER(),
-               nullable=True)
-    op.alter_column('vector_db_indices', 'name',
-               existing_type=sa.VARCHAR(),
-               nullable=True)
-    op.alter_column('vector_dbs', 'name',
-               existing_type=sa.VARCHAR(),
-               nullable=True)
     # ### end Alembic commands ###


 def downgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
-    op.alter_column('vector_dbs', 'name',
-               existing_type=sa.VARCHAR(),
-               nullable=False)
-    op.alter_column('vector_db_indices', 'name',
-               existing_type=sa.VARCHAR(),
-               nullable=False)
-    op.alter_column('vector_db_configs', 'vector_db_id',
-               existing_type=sa.INTEGER(),
-               nullable=False)
     op.drop_column('models', 'context_length')
-    op.alter_column('knowledges', 'name',
-               existing_type=sa.VARCHAR(),
-               nullable=False)
-    op.alter_column('knowledge_configs', 'knowledge_id',
-               existing_type=sa.INTEGER(),
-               nullable=False)
-    op.create_index('ix_events_org_id', 'events', ['org_id'], unique=False)
-    op.create_index('ix_events_event_property', 'events', ['event_property'], unique=False)
-    op.create_index('ix_events_agent_id', 'events', ['agent_id'], unique=False)
-    op.add_column('agent_workflows', sa.Column('organisation_id', sa.INTEGER(), autoincrement=False, nullable=True))
-    op.alter_column('agent_workflow_steps', 'step_type',
-               existing_type=sa.VARCHAR(),
-               nullable=False)
-    op.alter_column('agent_workflow_steps', 'unique_id',
-               existing_type=sa.VARCHAR(),
-               nullable=False)
-    op.create_index('ix_agent_schedule_status', 'agent_schedule', ['status'], unique=False)
-    op.create_index('ix_agent_schedule_expiry_date', 'agent_schedule', ['expiry_date'], unique=False)
-    op.create_index('ix_agent_schedule_agent_id', 'agent_schedule', ['agent_id'], unique=False)
     # ### end Alembic commands ###
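After this hand-trimming, the revision adds and drops nothing but the nullable models.context_length column; the unrelated auto-generated index and column churn is gone. As a hypothetical illustration of reading that column, where the helper, the filter fields, and the 4096 fallback are assumptions for this sketch rather than code from the series:

from superagi.models.models import Models


def get_context_length(session, model_name: str, org_id: int, default: int = 4096) -> int:
    """Look up a model's stored context window, falling back when unset."""
    row = (session.query(Models)
                  .filter(Models.model_name == model_name, Models.org_id == org_id)
                  .first())
    # The column is nullable (and stored as 0 when the form sends none), hence the fallback.
    return row.context_length if row and row.context_length else default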
"redis/redis-stack-server:latest" networks: diff --git a/gui/package-lock.json b/gui/package-lock.json index 244a7d0b2..a10e70f1e 100644 --- a/gui/package-lock.json +++ b/gui/package-lock.json @@ -22,7 +22,7 @@ "mixpanel-browser": "^2.47.0", "moment": "^2.29.4", "moment-timezone": "^0.5.43", - "next": "13.4.2", + "next": "^13.4.2", "react": "18.2.0", "react-datetime": "^3.2.0", "react-dom": "18.2.0", @@ -3577,7 +3577,6 @@ "version": "13.4.2", "resolved": "https://registry.npmjs.org/next/-/next-13.4.2.tgz", "integrity": "sha512-aNFqLs3a3nTGvLWlO9SUhCuMUHVPSFQC0+tDNGAsDXqx+WJDFSbvc233gOJ5H19SBc7nw36A9LwQepOJ2u/8Kg==", - "license": "MIT", "dependencies": { "@next/env": "13.4.2", "@swc/helpers": "0.5.1", diff --git a/gui/package.json b/gui/package.json index 1dcb77583..00b860550 100644 --- a/gui/package.json +++ b/gui/package.json @@ -24,7 +24,7 @@ "mixpanel-browser": "^2.47.0", "moment": "^2.29.4", "moment-timezone": "^0.5.43", - "next": "13.4.2", + "next": "^13.4.2", "react": "18.2.0", "react-datetime": "^3.2.0", "react-dom": "18.2.0", diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index d8b248c56..a927afa6b 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -17,6 +17,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [tokenError, setTokenError] = useState(false); const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) + const [modelStatus, setModelStatus] = useState(false); const modelRef = useRef(null); useEffect(() => { @@ -159,12 +160,15 @@ export default function ModelForm({internalId, getModels, sendModelData}){ onChange={(event) => setModelTokenLimit(parseInt(event.target.value, 10))}/>
-
- - +
+ +
+ + +
diff --git a/gui/pages/_app.css b/gui/pages/_app.css index c2318065e..a5c49ee56 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -1099,6 +1099,7 @@ p { .mxh_78vh{max-height: 78vh} .flex_dir_col{flex-direction: column} +.flex_none{flex: none} .justify_center{justify-content: center} .justify_end{justify-content: flex-end} diff --git a/gui/pages/api/apiConfig.js b/gui/pages/api/apiConfig.js index 2f9e736b7..3fbcf600d 100644 --- a/gui/pages/api/apiConfig.js +++ b/gui/pages/api/apiConfig.js @@ -1,7 +1,7 @@ import axios from 'axios'; const GITHUB_CLIENT_ID = process.env.GITHUB_CLIENT_ID; -const API_BASE_URL = process.env.NEXT_PUBLIC_API_BASE_URL || 'http://localhost:8001'; +const API_BASE_URL = process.env.NEXT_PUBLIC_API_BASE_URL || 'http://localhost:3000/api'; const GOOGLE_ANALYTICS_MEASUREMENT_ID = process.env.GOOGLE_ANALYTICS_MEASUREMENT_ID; const GOOGLE_ANALYTICS_API_SECRET = process.env.GOOGLE_ANALYTICS_API_SECRET; const MIXPANEL_AUTH_ID = process.env.MIXPANEL_AUTH_ID From 6271d8e75e51e4d408437546b617d8043c39b08f Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Mon, 9 Oct 2023 20:32:26 +0530 Subject: [PATCH 14/22] local_llms --- docker-compose.yaml | 2 +- gui/pages/Content/Models/ModelForm.js | 31 ++++++++++++++++++++++++--- gui/pages/api/DashboardService.js | 4 ++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index d3caf6029..22c4f20c2 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -43,7 +43,7 @@ services: - redis_data:/data super__postgres: - image: "docker.io/library/postgres:latest" + image: "docker.io/library/postgres:15" environment: - POSTGRES_USER=superagi - POSTGRES_PASSWORD=password diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index 792d44355..cd9ce8e54 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -1,7 +1,7 @@ import React, {useEffect, useRef, useState} from "react"; import {removeTab, openNewTab, createInternalId, getUserClick} from "@/utils/utils"; import Image from "next/image"; -import {fetchApiKey, storeModel, verifyEndPoint} from "@/pages/api/DashboardService"; +import {fetchApiKey, storeModel, testModel, verifyEndPoint} from "@/pages/api/DashboardService"; import {BeatLoader, ClipLoader} from "react-spinners"; import {ToastContainer, toast} from 'react-toastify'; @@ -18,7 +18,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [tokenError, setTokenError] = useState(false); const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) - const [modelStatus, setModelStatus] = useState(false); + const [modelStatus, setModelStatus] = useState(null); const modelRef = useRef(null); useEffect(() => { @@ -81,6 +81,18 @@ export default function ModelForm({internalId, getModels, sendModelData}){ }) } + const handleModelStatus = () => { + testModel().then((response) =>{ + if(response.data.success) + setModelStatus(true) + else + setModelStatus(false) + }).catch((error) => { + console.log("Error Message:: " + error) + setModelStatus(false) + }) + } + const handleModelSuccess = (model) => { model.contentType = 'Model' sendModelData(model) @@ -167,8 +179,21 @@ export default function ModelForm({internalId, getModels, sendModelData}){ onChange={(event) => setModelTokenLimit(parseInt(event.target.value, 10))}/> + {modelStatus===false &&
+ error-icon +
+ Test model failed +
+
} + + {modelStatus===true &&
+
+ Test model successful +
+
} +
- +
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index 129f0e838..2e5f93869 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -362,6 +362,10 @@ export const storeModel = (model_name, description, end_point, model_provider_id return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length}); } +export const testModel = () => { + return api.get(`/models_controller/test_local_llm`); +} + export const fetchModels = () => { return api.get(`/models_controller/fetch_models`); } From fea8fc5ea1b86465a1a6a41ea83326a66e764d54 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Tue, 10 Oct 2023 12:47:41 +0530 Subject: [PATCH 15/22] local_llms --- gui/pages/Content/Models/ModelForm.js | 13 +++++++++---- gui/pages/_app.css | 7 +++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index cd9ce8e54..27662e7c8 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -83,10 +83,14 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const handleModelStatus = () => { testModel().then((response) =>{ - if(response.data.success) + if(response.status === 200) + { setModelStatus(true) + } else + { setModelStatus(false) + } }).catch((error) => { console.log("Error Message:: " + error) setModelStatus(false) @@ -186,18 +190,19 @@ export default function ModelForm({internalId, getModels, sendModelData}){
} - {modelStatus===true &&
+ {modelStatus===true &&
+
Test model successful
}
- + {selectedModel==='Local LLM' && }
-
diff --git a/gui/pages/_app.css b/gui/pages/_app.css index a5c49ee56..03f0917f2 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -1836,6 +1836,13 @@ tr{ padding: 12px; } +.success_box{ + border-radius: 8px; + padding: 12px; + border-left: 4px solid rgba(255, 255, 255, 0.60); + background: rgba(255, 255, 255, 0.08); +} + .horizontal_line { margin: 16px 0 16px -16px; border: 1px solid #ffffff20; From 648f530e4af690796b13be24970ff9b4ec1b2f72 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Tue, 10 Oct 2023 08:02:16 +0000 Subject: [PATCH 16/22] local_llms --- gui/pages/Content/Models/ModelForm.js | 32 +++++++++++++++------------ 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index 27662e7c8..00b730bf8 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -19,6 +19,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) const [modelStatus, setModelStatus] = useState(null); + const [createClickable, setCreateClickable] = useState(true); const modelRef = useRef(null); useEffect(() => { @@ -81,20 +82,22 @@ export default function ModelForm({internalId, getModels, sendModelData}){ }) } - const handleModelStatus = () => { - testModel().then((response) =>{ - if(response.status === 200) - { - setModelStatus(true) - } - else - { - setModelStatus(false) + const handleModelStatus = async () => { + try { + setCreateClickable(false); + const response = await testModel(); + if(response.status === 200) { + setModelStatus(true); + setCreateClickable(true); + } else { + setModelStatus(false); + setCreateClickable(true); } - }).catch((error) => { - console.log("Error Message:: " + error) - setModelStatus(false) - }) + } catch(error) { + console.log("Error Message:: " + error); + setModelStatus(false); + setCreateClickable(true); + } } const handleModelSuccess = (model) => { @@ -198,7 +201,8 @@ export default function ModelForm({internalId, getModels, sendModelData}){
}
- {selectedModel==='Local LLM' && } + {selectedModel==='Local LLM' && }
From 5734b8c7391d9c7c127d3336d0b237dccc7f4d26 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Tue, 10 Oct 2023 08:04:33 +0000 Subject: [PATCH 17/22] local_llms --- docker-compose.yaml | 8 ++++---- gui/pages/api/apiConfig.js | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 22c4f20c2..b4a789717 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -28,10 +28,10 @@ services: NEXT_PUBLIC_API_BASE_URL: "/api" networks: - super_network - volumes: - - ./gui:/app - - /app/node_modules/ - - /app/.next/ + # volumes: + # - ./gui:/app + # - /app/node_modules/ + # - /app/.next/ super__redis: image: "redis/redis-stack-server:latest" networks: diff --git a/gui/pages/api/apiConfig.js b/gui/pages/api/apiConfig.js index 3fbcf600d..2f9e736b7 100644 --- a/gui/pages/api/apiConfig.js +++ b/gui/pages/api/apiConfig.js @@ -1,7 +1,7 @@ import axios from 'axios'; const GITHUB_CLIENT_ID = process.env.GITHUB_CLIENT_ID; -const API_BASE_URL = process.env.NEXT_PUBLIC_API_BASE_URL || 'http://localhost:3000/api'; +const API_BASE_URL = process.env.NEXT_PUBLIC_API_BASE_URL || 'http://localhost:8001'; const GOOGLE_ANALYTICS_MEASUREMENT_ID = process.env.GOOGLE_ANALYTICS_MEASUREMENT_ID; const GOOGLE_ANALYTICS_API_SECRET = process.env.GOOGLE_ANALYTICS_API_SECRET; const MIXPANEL_AUTH_ID = process.env.MIXPANEL_AUTH_ID From dd2b04a704e31a8b3126c06a538156e2765331de Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Tue, 10 Oct 2023 09:30:23 +0000 Subject: [PATCH 18/22] local_llms_frontend --- gui/pages/Content/Models/AddModel.js | 4 ++-- gui/pages/Content/Models/ModelForm.js | 4 ++-- gui/pages/Dashboard/Content.js | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/gui/pages/Content/Models/AddModel.js b/gui/pages/Content/Models/AddModel.js index e596cb80c..0ef3d5497 100644 --- a/gui/pages/Content/Models/AddModel.js +++ b/gui/pages/Content/Models/AddModel.js @@ -1,14 +1,14 @@ import React, {useEffect, useState} from "react"; import ModelForm from "./ModelForm"; -export default function AddModel({internalId, getModels, sendModelData}){ +export default function AddModel({internalId, getModels, sendModelData, env}){ return(
- +
diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index 00b730bf8..45794bf18 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -5,8 +5,8 @@ import {fetchApiKey, storeModel, testModel, verifyEndPoint} from "@/pages/api/Da import {BeatLoader, ClipLoader} from "react-spinners"; import {ToastContainer, toast} from 'react-toastify'; -export default function ModelForm({internalId, getModels, sendModelData}){ - const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM']; +export default function ModelForm({internalId, getModels, sendModelData, env}){ + const models = env === 'DEV' ? ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM'] : ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm']; const [selectedModel, setSelectedModel] = useState('Select a Model'); const [modelName, setModelName] = useState(''); const [modelDescription, setModelDescription] = useState(''); diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js index 0611a7be0..5ad7740b6 100644 --- a/gui/pages/Dashboard/Content.js +++ b/gui/pages/Dashboard/Content.js @@ -470,7 +470,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat organisationId={organisationId} sendKnowledgeData={addTab} sendAgentData={addTab} selectedProjectId={selectedProjectId} editAgentId={tab.id} fetchAgents={getAgentList} toolkits={toolkits} template={null} edit={true} agents={agents}/>} - {tab.contentType === 'Add_Model' && } + {tab.contentType === 'Add_Model' && } {tab.contentType === 'Model' && }
}
From 6ee4359c6defc289a8c2893dbc0c432787f7b79e Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 18 Oct 2023 11:23:19 +0530 Subject: [PATCH 19/22] fixes --- docker-compose.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index b4a789717..94044916b 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -28,10 +28,10 @@ services: NEXT_PUBLIC_API_BASE_URL: "/api" networks: - super_network - # volumes: - # - ./gui:/app - # - /app/node_modules/ - # - /app/.next/ +# volumes: +# - ./gui:/app +# - /app/node_modules/ +# - /app/.next/ super__redis: image: "redis/redis-stack-server:latest" networks: @@ -43,7 +43,7 @@ services: - redis_data:/data super__postgres: - image: "docker.io/library/postgres:15" + image: "docker.io/library/postgres:latest" environment: - POSTGRES_USER=superagi - POSTGRES_PASSWORD=password From b72447ef6bb298a2dae76cd4a8fce0e34a20f0c7 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 18 Oct 2023 11:25:35 +0530 Subject: [PATCH 20/22] fixes --- gui/package-lock.json | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gui/package-lock.json b/gui/package-lock.json index a10e70f1e..ad32a9d85 100644 --- a/gui/package-lock.json +++ b/gui/package-lock.json @@ -22,7 +22,7 @@ "mixpanel-browser": "^2.47.0", "moment": "^2.29.4", "moment-timezone": "^0.5.43", - "next": "^13.4.2", + "next": "13.4.2", "react": "18.2.0", "react-datetime": "^3.2.0", "react-dom": "18.2.0", @@ -3577,6 +3577,7 @@ "version": "13.4.2", "resolved": "https://registry.npmjs.org/next/-/next-13.4.2.tgz", "integrity": "sha512-aNFqLs3a3nTGvLWlO9SUhCuMUHVPSFQC0+tDNGAsDXqx+WJDFSbvc233gOJ5H19SBc7nw36A9LwQepOJ2u/8Kg==", + "license": "MIT", "dependencies": { "@next/env": "13.4.2", "@swc/helpers": "0.5.1", @@ -5221,4 +5222,4 @@ "integrity": "sha512-N82ooyxVNm6h1riLCoyS9e3fuJ3AMG2zIZs2Gd1ATcSFjSA23Q0fzjjZeh0jbJvWVDZ0cJT8yaNNaaXHzueNjg==" } } -} +} \ No newline at end of file From 9a1b7ad8fabfa650ed72b24411135bf8b0b86e68 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 18 Oct 2023 11:29:51 +0530 Subject: [PATCH 21/22] fixes --- gui/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui/package.json b/gui/package.json index 00b860550..1dcb77583 100644 --- a/gui/package.json +++ b/gui/package.json @@ -24,7 +24,7 @@ "mixpanel-browser": "^2.47.0", "moment": "^2.29.4", "moment-timezone": "^0.5.43", - "next": "^13.4.2", + "next": "13.4.2", "react": "18.2.0", "react-datetime": "^3.2.0", "react-dom": "18.2.0", From 84c8dc6e42d573d6e10f43b17fbcc3b1f0792711 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Wed, 18 Oct 2023 12:00:12 +0530 Subject: [PATCH 22/22] fixes --- gui/package-lock.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui/package-lock.json b/gui/package-lock.json index ad32a9d85..244a7d0b2 100644 --- a/gui/package-lock.json +++ b/gui/package-lock.json @@ -5222,4 +5222,4 @@ "integrity": "sha512-N82ooyxVNm6h1riLCoyS9e3fuJ3AMG2zIZs2Gd1ATcSFjSA23Q0fzjjZeh0jbJvWVDZ0cJT8yaNNaaXHzueNjg==" } } -} \ No newline at end of file +}
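Taken together, the series leaves the GUI's Test Model button issuing a GET against the backend's /models_controller/test_local_llm route. The endpoint can also be exercised without the GUI; a minimal sketch, assuming the backend listens on apiConfig.js's default localhost:8001 and that the route needs no auth headers in your deployment:

import requests

# The first hit loads the GGUF from disk, which can take minutes; allow a long timeout.
resp = requests.get("http://localhost:8001/models_controller/test_local_llm", timeout=300)
print(resp.status_code, resp.text)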