diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js
index c0c8afc1f..0de46038a 100644
--- a/gui/pages/Content/Agents/AgentCreate.js
+++ b/gui/pages/Content/Agents/AgentCreate.js
@@ -5,6 +5,7 @@ import 'react-toastify/dist/ReactToastify.css';
import styles from './Agents.module.css';
import {
createAgent,
+ editAgentTemplate,
fetchAgentTemplateConfigLocal,
getOrganisationConfig,
updateExecution,
@@ -21,17 +22,11 @@ import {EventBus} from "@/utils/eventBus";
import 'moment-timezone';
import AgentSchedule from "@/pages/Content/Agents/AgentSchedule";
-export default function AgentCreate({
- sendAgentData,
- selectedProjectId,
- fetchAgents,
- toolkits,
- organisationId,
- template,
- internalId
- }) {
+export default function AgentCreate({sendAgentData,selectedProjectId,fetchAgents,toolkits,organisationId,template,internalId}) {
+
const [advancedOptions, setAdvancedOptions] = useState(false);
const [agentName, setAgentName] = useState("");
+ const [agentTemplateId, setAgentTemplateId] = useState(null);
const [agentDescription, setAgentDescription] = useState("");
const [longTermMemory, setLongTermMemory] = useState(true);
const [addResources, setAddResources] = useState(true);
@@ -42,6 +37,7 @@ export default function AgentCreate({
const [maxIterations, setIterations] = useState(25);
const [toolkitList, setToolkitList] = useState(toolkits)
const [searchValue, setSearchValue] = useState('');
+ const [showButton, setShowButton] = useState(false);
const constraintsArray = [
"If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
@@ -137,6 +133,7 @@ export default function AgentCreate({
setLocalStorageValue("agent_name_" + String(internalId), template.name, setAgentName);
setLocalStorageValue("agent_description_" + String(internalId), template.description, setAgentDescription);
setLocalStorageValue("advanced_options_" + String(internalId), true, setAdvancedOptions);
+ setLocalStorageValue("agent_template_id_" + String(internalId), template.id, setAgentTemplateId);
fetchAgentTemplateConfigLocal(template.id)
.then((response) => {
@@ -151,6 +148,8 @@ export default function AgentCreate({
setLocalStorageValue("agent_database_" + String(internalId), data.LTM_DB, setDatabase);
setLocalStorageValue("agent_model_" + String(internalId), data.model, setModel);
setLocalStorageArray("tool_names_" + String(internalId), data.tools, setToolNames);
+ setLocalStorageValue("is_agent_template_" + String(internalId), true, setShowButton);
+ setShowButton(true);
})
.catch((error) => {
console.error('Error fetching template details:', error);
@@ -362,33 +361,35 @@ export default function AgentCreate({
}
}, [scheduleData]);
- const handleAddAgent = () => {
- if (!hasAPIkey) {
+ const validateAgentData = (isNewAgent) => {
+ if (isNewAgent && !hasAPIkey) {
toast.error("Your OpenAI/Palm API key is empty!", {autoClose: 1800});
openNewTab(-3, "Settings", "Settings", false);
- return
+ return false;
}
-
- if (agentName.replace(/\s/g, '') === '') {
+
+ if (agentName?.replace(/\s/g, '') === '') {
toast.error("Agent name can't be blank", {autoClose: 1800});
- return
+ return false;
}
-
- if (agentDescription.replace(/\s/g, '') === '') {
+ if (agentDescription?.replace(/\s/g, '') === '') {
toast.error("Agent description can't be blank", {autoClose: 1800});
- return
+ return false;
}
-
const isEmptyGoal = goals.some((goal) => goal.replace(/\s/g, '') === '');
if (isEmptyGoal) {
toast.error("Goal can't be empty", {autoClose: 1800});
- return;
+ return false;
}
-
if (selectedTools.length <= 0) {
toast.error("Add atleast one tool", {autoClose: 1800});
- return
+ return false;
}
+ return true;
+ }
+
+ const handleAddAgent = () => {
+ if (!validateAgentData(true)) return;
setCreateClickable(false);
@@ -531,6 +532,46 @@ export default function AgentCreate({
event.preventDefault();
};
+ function updateTemplate() {
+
+ if (!validateAgentData(false)) return;
+
+ let permission_type = permission;
+ if (permission.includes("RESTRICTED")) {
+ permission_type = "RESTRICTED";
+ }
+
+ const agentTemplateConfigData = {
+ "goal": goals,
+ "instruction": instructions,
+ "agent_type": agentType,
+ "constraints": constraints,
+ "tools": toolNames,
+ "exit": exitCriterion,
+ "iteration_interval": stepTime,
+ "model": model,
+ "max_iterations": maxIterations,
+ "permission_type": permission_type,
+ "LTM_DB": longTermMemory ? database : null,
+ }
+ const editTemplateData = {
+ "name": agentName,
+ "description": agentDescription,
+ "agent_configs": agentTemplateConfigData
+ }
+
+ editAgentTemplate(agentTemplateId, editTemplateData)
+ .then((response) => {
+ if (response.status === 200) {
+ toast.success('Agent template has been updated successfully!', {autoClose: 1800});
+ }
+ })
+ .catch((error) => {
+ toast.error("Error updating agent template")
+ console.error('Error updating agent template:', error);
+ });
+ };
+
function setFileData(files) {
if (files.length > 0) {
const fileData = {
@@ -579,11 +620,21 @@ export default function AgentCreate({
setAdvancedOptions(JSON.parse(advanced_options));
}
+ const is_agent_template = localStorage.getItem("is_agent_template_" + String(internalId));
+ if (is_agent_template) {
+ setShowButton(true);
+ }
+
const agent_name = localStorage.getItem("agent_name_" + String(internalId));
if (agent_name) {
setAgentName(agent_name);
}
+ const agent_template_id = localStorage.getItem("agent_template_id_"+ String(internalId));
+ if(agent_template_id){
+ setAgentTemplateId(agent_template_id)
+ }
+
const agent_description = localStorage.getItem("agent_description_" + String(internalId));
if (agent_description) {
setAgentDescription(agent_description);
@@ -986,6 +1037,12 @@ export default function AgentCreate({
+ {showButton && (
+
+ )}
{createDropdown && (
{
return api.put(`/agentexecutions/update/${executionId}`, executionData);
};
+export const editAgentTemplate = (agentTemplateId, agentTemplateData) => {
+ return api.put(`/agent_templates/update_agent_template/${agentTemplateId}`, agentTemplateData)
+}
+
export const addExecution = (executionData) => {
return api.post(`/agentexecutions/add`, executionData);
};
diff --git a/superagi/controllers/agent_template.py b/superagi/controllers/agent_template.py
index 140622aee..100bee8d0 100644
--- a/superagi/controllers/agent_template.py
+++ b/superagi/controllers/agent_template.py
@@ -13,6 +13,7 @@
from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.tool import Tool
+import json
# from superagi.types.db import AgentTemplateIn, AgentTemplateOut
router = APIRouter()
@@ -144,6 +145,59 @@ def update_agent_template(agent_template_id: int,
return db_agent_template
+@router.put("/update_agent_template/{agent_template_id}", status_code=200)
+def edit_agent_template(agent_template_id: int,
+ updated_agent_configs: dict,
+ organisation=Depends(get_user_organisation)):
+
+ """
+ Update the details of an agent template.
+
+ Args:
+ agent_template_id (int): The ID of the agent template to update.
+ edited_agent_configs (dict): The updated agent configurations.
+ organisation (Depends): Dependency to get the user organisation.
+
+ Returns:
+ HTTPException (status_code=200): If the agent gets successfully edited.
+
+ Raises:
+ HTTPException (status_code=404): If the agent template is not found.
+ """
+
+ db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
+ AgentTemplate.id == agent_template_id).first()
+ if db_agent_template is None:
+ raise HTTPException(status_code=404, detail="Agent Template not found")
+
+ db_agent_template.name = updated_agent_configs["name"]
+ db_agent_template.description = updated_agent_configs["description"]
+
+ db.session.commit()
+
+ agent_config_values = updated_agent_configs.get('agent_configs', {})
+
+ for key, value in agent_config_values.items():
+ if isinstance(value, (list, dict)):
+ value = json.dumps(value)
+ config = db.session.query(AgentTemplateConfig).filter(
+ AgentTemplateConfig.agent_template_id == agent_template_id,
+ AgentTemplateConfig.key == key
+ ).first()
+
+ if config is not None:
+ config.value = value
+ else:
+ new_config = AgentTemplateConfig(
+ agent_template_id=agent_template_id,
+ key=key,
+ value= value
+ )
+ db.session.add(new_config)
+
+ db.session.commit()
+ db.session.flush()
+
@router.post("/save_agent_as_template/{agent_id}")
def save_agent_as_template(agent_id: str,
diff --git a/tests/unit_tests/controllers/test_agent_template.py b/tests/unit_tests/controllers/test_agent_template.py
new file mode 100644
index 000000000..050e9e216
--- /dev/null
+++ b/tests/unit_tests/controllers/test_agent_template.py
@@ -0,0 +1,127 @@
+from unittest.mock import patch, MagicMock
+from superagi.models.agent_template import AgentTemplate
+from superagi.models.agent_template_config import AgentTemplateConfig
+from fastapi.testclient import TestClient
+from main import app
+
+client = TestClient(app)
+
+@patch('superagi.controllers.agent_template.db')
+@patch('superagi.helper.auth.db')
+@patch('superagi.helper.auth.get_user_organisation')
+def test_edit_agent_template_success(mock_get_user_org, mock_auth_db, mock_db):
+ # Create a mock agent template
+ mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
+ # mock_agent_goals = AgentTemplateConfig()
+
+ # Create a mock edited agent configuration
+ mock_updated_agent_configs = {
+ "name": "Updated Agent Template",
+ "description": "Updated Description",
+ "agent_configs": {
+ "goal": ["Create a simple pacman game for me.", "Write all files properly."],
+ "instruction": ["write spec","write code","improve the code","write test"],
+ "agent_type": "Don't Maintain Task Queue",
+ "constraints": ["If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.","Ensure the tool and args are as per current plan and reasoning","Exclusively use the tools listed under \"TOOLS\"","REMEMBER to format your response as JSON, using double quotes (\"\") around keys and string values, and commas (,) to separate items in arrays and objects. IMPORTANTLY, to use a JSON object as a string in another JSON object, you need to escape the double quotes."],
+ "tools": ["Read Email", "Send Email", "Write File"],
+ "exit": "No exit criterion",
+ "iteration_interval": 500,
+ "model": "gpt-4",
+ "max_iterations": 25,
+ "permission_type": "God Mode",
+ "LTM_DB": "Pinecone"
+ }
+ }
+
+ # Mocking the user organisation
+ mock_get_user_org.return_value = MagicMock(id=1)
+
+ # Create a session mock
+ session_mock = MagicMock()
+ mock_db.session = session_mock
+ mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
+ mock_db.session.commit.return_value = None
+ mock_db.session.add.return_value = None
+ mock_db.session.flush.return_value = None
+
+ mock_agent_template_config = AgentTemplateConfig(agent_template_id = 1, key="goal", value=["Create a simple pacman game for me.", "Write all files properly."])
+
+
+ # Call the endpoint
+ response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)
+
+ assert response.status_code == 200
+
+ # Verify changes in the mock agent template
+ assert mock_agent_template.name == "Updated Agent Template"
+ assert mock_agent_template.description == "Updated Description"
+ assert mock_agent_template_config.key == "goal"
+ assert mock_agent_template_config.value == ["Create a simple pacman game for me.", "Write all files properly."]
+
+
+ session_mock.commit.assert_called()
+ session_mock.flush.assert_called()
+
+
+@patch('superagi.controllers.agent_template.db')
+@patch('superagi.helper.auth.db')
+@patch('superagi.helper.auth.get_user_organisation')
+def test_edit_agent_template_failure(mock_get_user_org, mock_auth_db, mock_db):
+ # Setup: The user organisation exists, but the agent template does not exist.
+ mock_get_user_org.return_value = MagicMock(id=1)
+
+ # Create a session mock
+ session_mock = MagicMock()
+ mock_db.session = session_mock
+ mock_db.session.query.return_value.filter.return_value.first.return_value = None
+
+ # Call the endpoint
+ response = client.put("agent_templates/update_agent_template/1", json={})
+
+ # Verify: The response status code should be 404, indicating that the agent template was not found.
+ assert response.status_code == 404
+ assert response.json() == {"detail": "Agent Template not found"}
+
+ # Verify: The database commit method should not have been called because the agent template was not found.
+ session_mock.commit.assert_not_called()
+ session_mock.flush.assert_not_called()
+
+
+@patch('superagi.controllers.agent_template.db')
+@patch('superagi.helper.auth.db')
+@patch('superagi.helper.auth.get_user_organisation')
+def test_edit_agent_template_with_new_config_success(mock_get_user_org, mock_auth_db, mock_db):
+ # Create a mock agent template
+ mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
+
+ # Create a mock edited agent configuration
+ mock_updated_agent_configs = {
+ "name": "Updated Agent Template",
+ "description": "Updated Description",
+ "agent_configs": {
+ "new_config_key": "New config value" # This is a new config
+ }
+ }
+
+ # Mocking the user organisation
+ mock_get_user_org.return_value = MagicMock(id=1)
+
+ # Create a session mock
+ session_mock = MagicMock()
+ mock_db.session = session_mock
+ mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
+ mock_db.session.commit.return_value = None
+ mock_db.session.add.return_value = None
+ mock_db.session.flush.return_value = None
+
+ # Call the endpoint
+ response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)
+
+ assert response.status_code == 200
+
+ # Verify changes in the mock agent template
+ assert mock_agent_template.name == "Updated Agent Template"
+ assert mock_agent_template.description == "Updated Description"
+
+ session_mock.commit.assert_called()
+ session_mock.flush.assert_called()
\ No newline at end of file