Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edit agent templates #793

Merged
merged 9 commits into from
Jul 20, 2023
101 changes: 79 additions & 22 deletions gui/pages/Content/Agents/AgentCreate.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import 'react-toastify/dist/ReactToastify.css';
import styles from './Agents.module.css';
import {
createAgent,
editAgentTemplate,
fetchAgentTemplateConfigLocal,
getOrganisationConfig,
updateExecution,
Expand All @@ -21,17 +22,11 @@ import {EventBus} from "@/utils/eventBus";
import 'moment-timezone';
import AgentSchedule from "@/pages/Content/Agents/AgentSchedule";

export default function AgentCreate({
sendAgentData,
selectedProjectId,
fetchAgents,
toolkits,
organisationId,
template,
internalId
}) {
export default function AgentCreate({sendAgentData,selectedProjectId,fetchAgents,toolkits,organisationId,template,internalId}) {

const [advancedOptions, setAdvancedOptions] = useState(false);
const [agentName, setAgentName] = useState("");
const [agentTemplateId, setAgentTemplateId] = useState(null);
const [agentDescription, setAgentDescription] = useState("");
const [longTermMemory, setLongTermMemory] = useState(true);
const [addResources, setAddResources] = useState(true);
Expand All @@ -42,6 +37,7 @@ export default function AgentCreate({
const [maxIterations, setIterations] = useState(25);
const [toolkitList, setToolkitList] = useState(toolkits)
const [searchValue, setSearchValue] = useState('');
const [showButton, setShowButton] = useState(false);

const constraintsArray = [
"If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
Expand Down Expand Up @@ -137,6 +133,7 @@ export default function AgentCreate({
setLocalStorageValue("agent_name_" + String(internalId), template.name, setAgentName);
setLocalStorageValue("agent_description_" + String(internalId), template.description, setAgentDescription);
setLocalStorageValue("advanced_options_" + String(internalId), true, setAdvancedOptions);
setLocalStorageValue("agent_template_id_" + String(internalId), template.id, setAgentTemplateId);

fetchAgentTemplateConfigLocal(template.id)
.then((response) => {
Expand All @@ -151,6 +148,8 @@ export default function AgentCreate({
setLocalStorageValue("agent_database_" + String(internalId), data.LTM_DB, setDatabase);
setLocalStorageValue("agent_model_" + String(internalId), data.model, setModel);
setLocalStorageArray("tool_names_" + String(internalId), data.tools, setToolNames);
setLocalStorageValue("is_agent_template_" + String(internalId), true, setShowButton);
setShowButton(true);
})
.catch((error) => {
console.error('Error fetching template details:', error);
Expand Down Expand Up @@ -362,33 +361,35 @@ export default function AgentCreate({
}
}, [scheduleData]);

const handleAddAgent = () => {
if (!hasAPIkey) {
const validateAgentData = (isNewAgent) => {
if (isNewAgent && !hasAPIkey) {
toast.error("Your OpenAI/Palm API key is empty!", {autoClose: 1800});
openNewTab(-3, "Settings", "Settings", false);
return
return false;
}

if (agentName.replace(/\s/g, '') === '') {
if (agentName?.replace(/\s/g, '') === '') {
toast.error("Agent name can't be blank", {autoClose: 1800});
return
return false;
}

if (agentDescription.replace(/\s/g, '') === '') {
if (agentDescription?.replace(/\s/g, '') === '') {
toast.error("Agent description can't be blank", {autoClose: 1800});
return
return false;
}

const isEmptyGoal = goals.some((goal) => goal.replace(/\s/g, '') === '');
if (isEmptyGoal) {
toast.error("Goal can't be empty", {autoClose: 1800});
return;
return false;
}

if (selectedTools.length <= 0) {
toast.error("Add atleast one tool", {autoClose: 1800});
return
return false;
}
return true;
}

const handleAddAgent = () => {
if (!validateAgentData(true)) return;

setCreateClickable(false);

Expand Down Expand Up @@ -531,6 +532,46 @@ export default function AgentCreate({
event.preventDefault();
};

function updateTemplate() {

if (!validateAgentData(false)) return;

let permission_type = permission;
if (permission.includes("RESTRICTED")) {
permission_type = "RESTRICTED";
}

const agentTemplateConfigData = {
"goal": goals,
"instruction": instructions,
"agent_type": agentType,
"constraints": constraints,
"tools": toolNames,
"exit": exitCriterion,
"iteration_interval": stepTime,
"model": model,
"max_iterations": maxIterations,
"permission_type": permission_type,
"LTM_DB": longTermMemory ? database : null,
}
const editTemplateData = {
"name": agentName,
"description": agentDescription,
"agent_configs": agentTemplateConfigData
}

editAgentTemplate(agentTemplateId, editTemplateData)
.then((response) => {
if (response.status === 200) {
toast.success('Agent template has been updated successfully!', {autoClose: 1800});
}
})
.catch((error) => {
toast.error("Error updating agent template")
console.error('Error updating agent template:', error);
});
};

function setFileData(files) {
if (files.length > 0) {
const fileData = {
Expand Down Expand Up @@ -579,11 +620,21 @@ export default function AgentCreate({
setAdvancedOptions(JSON.parse(advanced_options));
}

const is_agent_template = localStorage.getItem("is_agent_template_" + String(internalId));
if (is_agent_template) {
setShowButton(true);
}

const agent_name = localStorage.getItem("agent_name_" + String(internalId));
if (agent_name) {
setAgentName(agent_name);
}

const agent_template_id = localStorage.getItem("agent_template_id_"+ String(internalId));
if(agent_template_id){
setAgentTemplateId(agent_template_id)
}

const agent_description = localStorage.getItem("agent_description_" + String(internalId));
if (agent_description) {
setAgentDescription(agent_description);
Expand Down Expand Up @@ -986,6 +1037,12 @@ export default function AgentCreate({
<button style={{marginRight: '7px'}} className="secondary_button"
onClick={() => removeTab(-1, "new agent", "Create_Agent", internalId)}>Cancel
</button>
{showButton && (
<button style={{ marginRight: '7px' }} className="secondary_button"
onClick={() => {updateTemplate()}}>
Update Template
</button>
)}
<div style={{display: 'flex', position: 'relative'}}>
{createDropdown && (<div className="custom_select_option" style={{
background: '#3B3B49',
Expand Down
4 changes: 4 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ export const updateExecution = (executionId, executionData) => {
return api.put(`/agentexecutions/update/${executionId}`, executionData);
};

export const editAgentTemplate = (agentTemplateId, agentTemplateData) => {
return api.put(`/agent_templates/update_agent_template/${agentTemplateId}`, agentTemplateData)
}

export const addExecution = (executionData) => {
return api.post(`/agentexecutions/add`, executionData);
};
Expand Down
54 changes: 54 additions & 0 deletions superagi/controllers/agent_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.tool import Tool
import json
# from superagi.types.db import AgentTemplateIn, AgentTemplateOut

router = APIRouter()
Expand Down Expand Up @@ -144,6 +145,59 @@ def update_agent_template(agent_template_id: int,

return db_agent_template

@router.put("/update_agent_template/{agent_template_id}", status_code=200)
def edit_agent_template(agent_template_id: int,
updated_agent_configs: dict,
organisation=Depends(get_user_organisation)):

"""
Update the details of an agent template.

Args:
agent_template_id (int): The ID of the agent template to update.
edited_agent_configs (dict): The updated agent configurations.
organisation (Depends): Dependency to get the user organisation.

Returns:
HTTPException (status_code=200): If the agent gets successfully edited.

Raises:
HTTPException (status_code=404): If the agent template is not found.
"""

db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
AgentTemplate.id == agent_template_id).first()
if db_agent_template is None:
raise HTTPException(status_code=404, detail="Agent Template not found")

db_agent_template.name = updated_agent_configs["name"]
db_agent_template.description = updated_agent_configs["description"]

db.session.commit()

agent_config_values = updated_agent_configs.get('agent_configs', {})

for key, value in agent_config_values.items():
if isinstance(value, (list, dict)):
value = json.dumps(value)
config = db.session.query(AgentTemplateConfig).filter(
AgentTemplateConfig.agent_template_id == agent_template_id,
AgentTemplateConfig.key == key
).first()

if config is not None:
config.value = value
else:
new_config = AgentTemplateConfig(
agent_template_id=agent_template_id,
key=key,
value= value
)
db.session.add(new_config)

db.session.commit()
db.session.flush()


@router.post("/save_agent_as_template/{agent_id}")
def save_agent_as_template(agent_id: str,
Expand Down
127 changes: 127 additions & 0 deletions tests/unit_tests/controllers/test_agent_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from unittest.mock import patch, MagicMock
from superagi.models.agent_template import AgentTemplate
from superagi.models.agent_template_config import AgentTemplateConfig
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

@patch('superagi.controllers.agent_template.db')
@patch('superagi.helper.auth.db')
@patch('superagi.helper.auth.get_user_organisation')
def test_edit_agent_template_success(mock_get_user_org, mock_auth_db, mock_db):
# Create a mock agent template
mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
# mock_agent_goals = AgentTemplateConfig()

# Create a mock edited agent configuration
mock_updated_agent_configs = {
"name": "Updated Agent Template",
"description": "Updated Description",
"agent_configs": {
"goal": ["Create a simple pacman game for me.", "Write all files properly."],
"instruction": ["write spec","write code","improve the code","write test"],
"agent_type": "Don't Maintain Task Queue",
"constraints": ["If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.","Ensure the tool and args are as per current plan and reasoning","Exclusively use the tools listed under \"TOOLS\"","REMEMBER to format your response as JSON, using double quotes (\"\") around keys and string values, and commas (,) to separate items in arrays and objects. IMPORTANTLY, to use a JSON object as a string in another JSON object, you need to escape the double quotes."],
"tools": ["Read Email", "Send Email", "Write File"],
"exit": "No exit criterion",
"iteration_interval": 500,
"model": "gpt-4",
"max_iterations": 25,
"permission_type": "God Mode",
"LTM_DB": "Pinecone"
}
}

# Mocking the user organisation
mock_get_user_org.return_value = MagicMock(id=1)

# Create a session mock
session_mock = MagicMock()
mock_db.session = session_mock
mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
mock_db.session.commit.return_value = None
mock_db.session.add.return_value = None
mock_db.session.flush.return_value = None

mock_agent_template_config = AgentTemplateConfig(agent_template_id = 1, key="goal", value=["Create a simple pacman game for me.", "Write all files properly."])


# Call the endpoint
response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)

assert response.status_code == 200

# Verify changes in the mock agent template
assert mock_agent_template.name == "Updated Agent Template"
assert mock_agent_template.description == "Updated Description"
assert mock_agent_template_config.key == "goal"
assert mock_agent_template_config.value == ["Create a simple pacman game for me.", "Write all files properly."]


session_mock.commit.assert_called()
session_mock.flush.assert_called()


@patch('superagi.controllers.agent_template.db')
@patch('superagi.helper.auth.db')
@patch('superagi.helper.auth.get_user_organisation')
def test_edit_agent_template_failure(mock_get_user_org, mock_auth_db, mock_db):
# Setup: The user organisation exists, but the agent template does not exist.
mock_get_user_org.return_value = MagicMock(id=1)

# Create a session mock
session_mock = MagicMock()
mock_db.session = session_mock
mock_db.session.query.return_value.filter.return_value.first.return_value = None

# Call the endpoint
response = client.put("agent_templates/update_agent_template/1", json={})

# Verify: The response status code should be 404, indicating that the agent template was not found.
assert response.status_code == 404
assert response.json() == {"detail": "Agent Template not found"}

# Verify: The database commit method should not have been called because the agent template was not found.
session_mock.commit.assert_not_called()
session_mock.flush.assert_not_called()


@patch('superagi.controllers.agent_template.db')
@patch('superagi.helper.auth.db')
@patch('superagi.helper.auth.get_user_organisation')
def test_edit_agent_template_with_new_config_success(mock_get_user_org, mock_auth_db, mock_db):
# Create a mock agent template
mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")

# Create a mock edited agent configuration
mock_updated_agent_configs = {
"name": "Updated Agent Template",
"description": "Updated Description",
"agent_configs": {
"new_config_key": "New config value" # This is a new config
}
}

# Mocking the user organisation
mock_get_user_org.return_value = MagicMock(id=1)

# Create a session mock
session_mock = MagicMock()
mock_db.session = session_mock
mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
mock_db.session.commit.return_value = None
mock_db.session.add.return_value = None
mock_db.session.flush.return_value = None

# Call the endpoint
response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)

assert response.status_code == 200

# Verify changes in the mock agent template
assert mock_agent_template.name == "Updated Agent Template"
assert mock_agent_template.description == "Updated Description"

session_mock.commit.assert_called()
session_mock.flush.assert_called()