Local LLMs #1306

Merged
merged 24 commits into from
Oct 18, 2023
4 changes: 2 additions & 2 deletions gui/pages/Content/Models/AddModel.js
@@ -1,14 +1,14 @@
import React, {useEffect, useState} from "react";
import ModelForm from "./ModelForm";

-export default function AddModel({internalId, getModels, sendModelData}){
+export default function AddModel({internalId, getModels, sendModelData, env}){

return(
<div id="add_model">
<div className="row">
<div className="col-3" />
<div className="col-6 col-6-scrollable">
-<ModelForm internalId={internalId} getModels={getModels} sendModelData={sendModelData}/>
+<ModelForm internalId={internalId} getModels={getModels} sendModelData={sendModelData} env={env}/>
</div>
<div className="col-3" />
</div>
73 changes: 58 additions & 15 deletions gui/pages/Content/Models/ModelForm.js
@@ -1,22 +1,25 @@
import React, {useEffect, useRef, useState} from "react";
-import {removeTab, openNewTab, createInternalId, modelGetAuth, getUserClick} from "@/utils/utils";
+import {removeTab, openNewTab, createInternalId, getUserClick} from "@/utils/utils";
import Image from "next/image";
-import {fetchApiKey, storeModel, verifyEndPoint} from "@/pages/api/DashboardService";
+import {fetchApiKey, storeModel, testModel, verifyEndPoint} from "@/pages/api/DashboardService";
import {BeatLoader, ClipLoader} from "react-spinners";
import {ToastContainer, toast} from 'react-toastify';

-export default function ModelForm({internalId, getModels, sendModelData}){
-const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm'];
+export default function ModelForm({internalId, getModels, sendModelData, env}){
+const models = env === 'DEV' ? ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM'] : ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm'];
const [selectedModel, setSelectedModel] = useState('Select a Model');
const [modelName, setModelName] = useState('');
const [modelDescription, setModelDescription] = useState('');
const [modelTokenLimit, setModelTokenLimit] = useState(4096);
const [modelEndpoint, setModelEndpoint] = useState('');
const [modelDropdown, setModelDropdown] = useState(false);
const [modelVersion, setModelVersion] = useState('');
+const [modelContextLength, setContextLength] = useState(4096);
const [tokenError, setTokenError] = useState(false);
const [lockAddition, setLockAddition] = useState(true);
const [isLoading, setIsLoading] = useState(false)
+const [modelStatus, setModelStatus] = useState(null);
+const [createClickable, setCreateClickable] = useState(true);
const modelRef = useRef(null);

useEffect(() => {
@@ -66,11 +69,9 @@ export default function ModelForm({internalId, getModels, sendModelData}){
{
const modelProviderId = response.data[0].id
verifyEndPoint(response.data[0].api_key, modelEndpoint, selectedModel).then((response) =>{
-if(response.data.success) {
+if(response.data.success)
storeModelDetails(modelProviderId)
-getUserClick("Model Added Successfully",{'type': selectedModel})
-}
-else {
+else{
toast.error("The Endpoint is not Valid",{autoClose: 1800});
setIsLoading(false);
}
@@ -81,13 +82,31 @@ export default function ModelForm({internalId, getModels, sendModelData}){
})
}

+const handleModelStatus = async () => {
+try {
+setCreateClickable(false);
+const response = await testModel();
+if(response.status === 200) {
+setModelStatus(true);
+setCreateClickable(true);
+} else {
+setModelStatus(false);
+setCreateClickable(true);
+}
+} catch(error) {
+console.log("Error Message:: " + error);
+setModelStatus(false);
+setCreateClickable(true);
+}
+}
+
const handleModelSuccess = (model) => {
model.contentType = 'Model'
sendModelData(model)
}

const storeModelDetails = (modelProviderId) => {
-storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{
+storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{
setIsLoading(false)
let data = response.data
if (data.error) {
@@ -155,18 +174,42 @@ export default function ModelForm({internalId, getModels, sendModelData}){
onChange={(event) => setModelVersion(event.target.value)}/>
</div>}

+{(selectedModel === 'Local LLM') && <div className="mt_24">
+<span>Model Context Length</span>
+<input className="input_medium mt_8" type="number" placeholder="Enter Model Context Length" value={modelContextLength}
+onChange={(event) => setContextLength(event.target.value)}/>
+</div>}
+
<div className="mt_24">
<span>Token Limit</span>
<input className="input_medium mt_8" type="number" placeholder="Enter Model Token Limit" value={modelTokenLimit}
onChange={(event) => setModelTokenLimit(parseInt(event.target.value, 10))}/>
</div>

<div className="horizontal_container justify_end mt_24">
<button className="secondary_button mr_7"
onClick={() => removeTab(-5, "new model", "Add_Model", internalId)}>Cancel</button>
<button className='primary_button' onClick={handleAddModel} disabled={lockAddition || isLoading}>
{isLoading ? <><span>Adding Model &nbsp;</span><ClipLoader size={16} color={"#000000"} /></> : 'Add Model'}
</button>
+{modelStatus===false && <div className="horizontal_container align_start error_box mt_24 gap_6">
+<Image width={16} height={16} src="/images/icon_error.svg" alt="error-icon" />
+<div className="vertical_containers">
+<span className="text_12 color_white lh_16">Test model failed</span>
+</div>
+</div>}
+
+{modelStatus===true && <div className="horizontal_container align_start success_box mt_24 gap_6">
+<Image width={16} height={16} src="/images/icon_info.svg"/>
+<div className="vertical_containers">
+<span className="text_12 color_white lh_16">Test model successful</span>
+</div>
+</div>}
+
+<div className="horizontal_container justify_space_between w_100 mt_24">
+{selectedModel==='Local LLM' && <button className="secondary_button flex_none" disabled={!createClickable}
+onClick={() => {handleModelStatus();}}>{createClickable ? 'Test Model' : 'Testing model...'}</button>}
+<div className="horizontal_container justify_end">
+<button className="secondary_button mr_7"
+onClick={() => removeTab(-5, "new model", "Add_Model", internalId)}>Cancel</button>
+<button className='primary_button' onClick={handleAddModel} disabled={lockAddition || isLoading || (selectedModel==='Local LLM' && !modelStatus)}>
+{isLoading ? <><span>Adding Model &nbsp;</span><ClipLoader size={16} color={"#000000"} /></> : 'Add Model'}
+</button>
+</div>
</div>
<ToastContainer className="text_16"/>
</div>
2 changes: 1 addition & 1 deletion gui/pages/Dashboard/Content.js
@@ -470,7 +470,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat
organisationId={organisationId} sendKnowledgeData={addTab}
sendAgentData={addTab} selectedProjectId={selectedProjectId} editAgentId={tab.id}
fetchAgents={getAgentList} toolkits={toolkits} template={null} edit={true} agents={agents}/>}
-{tab.contentType === 'Add_Model' && <AddModel internalId={tab.internalId} getModels={getModels} sendModelData={addTab}/>}
+{tab.contentType === 'Add_Model' && <AddModel internalId={tab.internalId} getModels={getModels} sendModelData={addTab} env={env}/>}
{tab.contentType === 'Model' && <ModelDetails modelId={tab.id} modelName={tab.name} />}
</div>}
</div>
8 changes: 8 additions & 0 deletions gui/pages/_app.css
@@ -1099,6 +1099,7 @@ p {
.mxh_78vh{max-height: 78vh}

.flex_dir_col{flex-direction: column}
+.flex_none{flex: none}

.justify_center{justify-content: center}
.justify_end{justify-content: flex-end}
@@ -1835,6 +1836,13 @@ tr{
padding: 12px;
}

+.success_box{
+border-radius: 8px;
+padding: 12px;
+border-left: 4px solid rgba(255, 255, 255, 0.60);
+background: rgba(255, 255, 255, 0.08);
+}
+
.horizontal_line {
margin: 16px 0 16px -16px;
border: 1px solid #ffffff20;
8 changes: 6 additions & 2 deletions gui/pages/api/DashboardService.js
@@ -358,8 +358,12 @@ export const verifyEndPoint = (model_api_key, end_point, model_provider) => {
});
}

-export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version) => {
-return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version});
+export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => {
+return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length});
}

+export const testModel = () => {
+return api.get(`/models_controller/test_local_llm`);
+}
+
export const fetchModels = () => {
11 changes: 10 additions & 1 deletion main.py
@@ -50,6 +50,7 @@
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
from superagi.models.agent_template import AgentTemplate
+from superagi.models.models_config import ModelsConfig
from superagi.models.organisation import Organisation
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
@@ -215,6 +216,13 @@ def register_toolkit_for_master_organisation():
Organisation.id == marketplace_organisation_id).first()
if marketplace_organisation is not None:
register_marketplace_toolkits(session, marketplace_organisation)

+def local_llm_model_config():
+existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Local LLM').first()
+if existing_models_config is None:
+models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Local LLM', api_key="EMPTY")
+session.add(models_config)
+session.commit()
+
IterationWorkflowSeed.build_single_step_agent(session)
IterationWorkflowSeed.build_task_based_agents(session)
@@ -238,7 +246,8 @@ def register_toolkit_for_master_organisation():
# AgentWorkflowSeed.doc_search_and_code(session)
# AgentWorkflowSeed.build_research_email_workflow(session)
replace_old_iteration_workflows(session)
-
+local_llm_model_config()
+
if env != "PROD":
register_toolkit_for_all_organisation()
else:
28 changes: 28 additions & 0 deletions migrations/versions/9270eb5a8475_local_llms.py
@@ -0,0 +1,28 @@
"""local_llms

Revision ID: 9270eb5a8475
Revises: 3867bb00a495
Create Date: 2023-10-04 09:26:33.865424

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '9270eb5a8475'
down_revision = '3867bb00a495'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('models', 'context_length')
# ### end Alembic commands ###
1 change: 1 addition & 0 deletions requirements.txt
@@ -158,3 +158,4 @@ google-generativeai==0.1.0
unstructured==0.8.1
ai21==1.2.6
typing-extensions==4.5.0
+llama_cpp_python==0.2.7
37 changes: 35 additions & 2 deletions superagi/controllers/models_controller.py
@@ -2,13 +2,15 @@
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.helper.models_helper import ModelsHelper
from superagi.apm.call_log_helper import CallLogHelper
+from superagi.lib.logger import logger
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
from superagi.controllers.types.models_types import ModelsTypes
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
+from superagi.helper.llm_loader import LLMLoader

router = APIRouter()

@@ -26,6 +28,7 @@ class StoreModelRequest(BaseModel):
token_limit: int
type: str
version: str
+context_length: int

class ModelName (BaseModel):
model: str
@@ -69,7 +72,9 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod
@router.post("/store_model", status_code=200)
async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)):
try:
-return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version)
+#context_length = 4096
+logger.info(request)
+return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, request.context_length)
except Exception as e:
logging.error(f"Error storing the Model Details: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")
@@ -164,4 +169,32 @@ def get_models_details(page: int = 0):
marketplace_models = Models.fetch_marketplace_list(page)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id,
ModelsTypes.MARKETPLACE.value)
-return marketplace_models_with_install
+return marketplace_models_with_install
+
+@router.get("/test_local_llm", status_code=200)
+def test_local_llm():
+try:
+llm_loader = LLMLoader(context_length=4096)
+llm_model = llm_loader.model
+llm_grammar = llm_loader.grammar
+if llm_model is None:
+logger.error("Model not found.")
+raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
+if llm_grammar is None:
+logger.error("Grammar not found.")
+raise HTTPException(status_code=404, detail="Grammar not found.")
+
+messages = [
+{"role":"system",
+"content":"You are an AI assistant. Give response in a proper JSON format"},
+{"role":"user",
+"content":"Hi!"}
+]
+response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar)
+content = response["choices"][0]["message"]["content"]
+logger.info(content)
+return "Model loaded successfully."
+
+except Exception as e:
+logger.info("Error: ",e)
+raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
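
For reference, the new route can be exercised directly once the backend is up. A minimal sketch, assuming the FastAPI app is reachable at http://localhost:8001 and that no extra auth headers are required (both assumptions; adjust to your deployment):

# Hypothetical smoke test for GET /models_controller/test_local_llm; BASE_URL is an assumption.
import requests

BASE_URL = "http://localhost:8001"

resp = requests.get(f"{BASE_URL}/models_controller/test_local_llm", timeout=300)
if resp.status_code == 200:
    print("Local model loaded:", resp.json())  # endpoint returns "Model loaded successfully."
else:
    print("Model test failed:", resp.status_code, resp.text)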
38 changes: 38 additions & 0 deletions superagi/helper/llm_loader.py
@@ -0,0 +1,38 @@
from llama_cpp import Llama
from llama_cpp import LlamaGrammar
from superagi.config.config import get_config
from superagi.lib.logger import logger


class LLMLoader:
_instance = None
_model = None
_grammar = None

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(LLMLoader, cls).__new__(cls)
return cls._instance

def __init__(self, context_length):
self.context_length = context_length

@property
def model(self):
if self._model is None:
try:
self._model = Llama(
model_path="/app/local_model_path", n_ctx=self.context_length)
except Exception as e:
logger.error(e)
return self._model

@property
def grammar(self):
if self._grammar is None:
try:
self._grammar = LlamaGrammar.from_file(
"superagi/llms/grammar/json.gbnf")
except Exception as e:
logger.error(e)
return self._grammar
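
The LocalLLM class imported in superagi/jobs/agent_executor.py below is part of this PR, but its file (superagi/llms/local_llm.py) is not included in this excerpt. A rough sketch of how such a wrapper could sit on top of LLMLoader -- the class name matches the import, while the method name, signature, and defaults here are assumptions, not the merged code:

# Hypothetical sketch of a LocalLLM wrapper over LLMLoader; not the PR's actual implementation.
from superagi.helper.llm_loader import LLMLoader
from superagi.lib.logger import logger


class LocalLLM:
    def __init__(self, context_length=4096):
        loader = LLMLoader(context_length=context_length)  # singleton; model/grammar load lazily
        self.llm_model = loader.model
        self.llm_grammar = loader.grammar

    def chat_completion(self, messages, max_tokens=100):
        # Grammar-constrained completion so the local model answers in JSON, as in test_local_llm above.
        if self.llm_model is None or self.llm_grammar is None:
            logger.error("Local model or grammar not loaded.")
            return {"error": "Local model not available"}
        response = self.llm_model.create_chat_completion(
            messages=messages, grammar=self.llm_grammar, max_tokens=max_tokens)
        return {"content": response["choices"][0]["message"]["content"]}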
3 changes: 3 additions & 0 deletions superagi/jobs/agent_executor.py
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta

from sqlalchemy.orm import sessionmaker
+from superagi.llms.local_llm import LocalLLM

import superagi.worker
from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
@@ -135,6 +136,8 @@ def get_embedding(cls, model_source, model_api_key):
return HuggingFace(api_key=model_api_key)
if "Replicate" in model_source:
return Replicate(api_key=model_api_key)
if "Custom" in model_source:
return LocalLLM()
return None

def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id):