Skip to content

Commit

Permalink
made changes according to comments of code review
Browse files Browse the repository at this point in the history
  • Loading branch information
Aryan-Singh-14 committed Aug 8, 2023
1 parent 60ce35f commit 11a2c37
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 66 deletions.
12 changes: 6 additions & 6 deletions migrations/versions/446884dcae58_add_api_key_and_web_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ def upgrade() -> None:
sa.Column('key', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('revoked',sa.Boolean(),nullable=True,default=False),
sa.Column('is_expired',sa.Boolean(),nullable=True,default=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('web_hooks',
op.create_table('webhooks',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('org_id', sa.Integer(), nullable=True),
sa.Column('url', sa.String(), nullable=True),
sa.Column('headers', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('isDeleted',sa.Boolean(),nullable=True),
sa.Column('is_deleted',sa.Boolean(),nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('web_hook_events',
op.create_table('webhook_events',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('agent_id', sa.Integer(), nullable=True),
sa.Column('run_id', sa.Integer(), nullable=True),
Expand All @@ -58,8 +58,8 @@ def upgrade() -> None:
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###

op.drop_table('web_hooks')
op.drop_table('webhooks')
op.drop_table('api_key')
op.drop_table('web_hook_events')
op.drop_table('webhook_events')

# ### end Alembic commands ###
40 changes: 24 additions & 16 deletions superagi/controllers/api/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,23 @@ class Config:
@router.post("",status_code=200)
def create_agent_with_config(agent_with_config: AgentConfigExtInput,
api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):

project=Project.get_project_from_org_id(db.session,organisation.id)
tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools)
invalid_tools = Tool.get_invalid_tools(tools_arr, db.session)
if len(invalid_tools) > 0: # If the returned value is not True (then it is an invalid tool_id)
raise HTTPException(status_code=404,
detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.")

try:
tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools)
except Exception as e:
raise HTTPException(status_code=404,detail=str(e))

agent_with_config.tools=tools_arr
agent_with_config.project_id=project.id
agent_with_config.exit="No exit criterion"
agent_with_config.permission_type="God Mode"
agent_with_config.LTM_DB=None

db_agent = Agent.create_agent_with_config(db, agent_with_config)

if agent_with_config.schedule is not None:
agent_schedule = AgentSchedule.get_schedule_from_config(db.session,db_agent,agent_with_config.schedule)
if agent_schedule is None:
raise HTTPException(status_code=500, detail="Failed to schedule agent")

EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name,
'model': agent_with_config.model}, db_agent.id,
organisation.id if organisation else 0)
Expand Down Expand Up @@ -196,12 +192,11 @@ def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api
db_schedule=AgentSchedule.get_schedule_from_agent_id(db.session,agent_id)
if db_schedule is not None:
raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update")

tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools)
invalid_tools = Tool.get_invalid_tools(tools_arr, db.session)
if len(invalid_tools) > 0: # If the returned value is not True (then it is an invalid tool_id)
raise HTTPException(status_code=404,
detail=f"Tool with IDs {str(invalid_tools)} does not exist.")

try:
tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools)
except Exception as e:
raise HTTPException(status_code=404,detail=str(e))

agent_with_config.tools=tools_arr
agent_with_config.project_id=project.id
Expand Down Expand Up @@ -270,6 +265,13 @@ def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCha
if project.organisation_id!=organisation.id:
raise HTTPException(status_code=404, detail="Agent not found")

#Checking if the run_ids whose output files are requested belong to the organisation
if execution_state_change_input.run_ids is not None:
try:
AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id)
except Exception as e:
raise HTTPException(status_code=404, detail=str(e))

db_execution_arr=AgentExecution.get_all_executions_with_status_and_agent_id(db.session,agent.id,execution_state_change_input,"RUNNING")
for ind_execution in db_execution_arr:
ind_execution.status="PAUSED"
Expand All @@ -289,6 +291,12 @@ def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCh
if project.organisation_id!=organisation.id:
raise HTTPException(status_code=404, detail="Agent not found")

if execution_state_change_input.run_ids is not None:
try:
AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id)
except Exception as e:
raise HTTPException(status_code=404, detail=str(e))

db_execution_arr=AgentExecution.get_all_executions_with_status_and_agent_id(db.session,agent.id,execution_state_change_input,"PAUSED")
for ind_execution in db_execution_arr:
ind_execution.status="RUNNING"
Expand All @@ -312,6 +320,6 @@ def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate
raise HTTPException(status_code=404, detail=str(e))

db_resources_arr=Resource.get_all_resources_from_run_ids(db.session,run_ids_arr)
response_obj=S3Helper.get_download_url_of_resources(db_resources_arr)
response_obj=S3Helper().get_download_url_of_resources(db_resources_arr)
return response_obj

42 changes: 13 additions & 29 deletions superagi/controllers/webhook.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,25 @@
import json

from datetime import datetime

from fastapi import APIRouter
from fastapi import HTTPException, Depends ,Security
from fastapi import Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel,Json

from jsonmerge import merge
from pytz import timezone
from sqlalchemy import func, or_
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.worker import execute_agent
from superagi.helper.auth import check_auth,validate_api_key
from superagi.models.agent import Agent
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_schedule import AgentSchedule
from superagi.models.agent_template import AgentTemplate
from superagi.models.project import Project
from superagi.models.agent_execution import AgentExecution
from superagi.models.tool import Tool
from pydantic import BaseModel




from superagi.helper.auth import check_auth

from superagi.models.web_hooks import WebHooks
from superagi.controllers.types.agent_schedule import AgentScheduleInput
from superagi.controllers.types.agent_with_config import AgentConfigInput
from superagi.controllers.types.agent_with_config_schedule import AgentConfigSchedule
from jsonmerge import merge

from datetime import datetime
import json

from superagi.models.toolkit import Toolkit
from superagi.models.knowledges import Knowledges

from sqlalchemy import func
# from superagi.types.db import AgentOut, AgentIn
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.apm.event_handler import EventHandler


router = APIRouter()

Expand All @@ -54,7 +38,7 @@ class WebHookOut(BaseModel):
name: str
url: str
headers: dict
isDeleted: bool
is_deleted: bool
created_at: datetime
updated_at: datetime

Expand All @@ -76,7 +60,7 @@ def create_webhook(webhook: WebHookIn,Authorize: AuthJWT = Depends(check_auth),o
Raises:
HTTPException (Status Code=404): If the associated project is not found.
"""
db_webhook=WebHooks(name=webhook.name,url=webhook.url,headers=webhook.headers,org_id=organisation.id,isDeleted=False)
db_webhook=WebHooks(name=webhook.name,url=webhook.url,headers=webhook.headers,org_id=organisation.id,is_deleted=False)
db.session.add(db_webhook)
db.session.commit()
db.session.flush()
Expand Down
4 changes: 2 additions & 2 deletions superagi/helper/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_current_user(Authorize: AuthJWT = Depends(check_auth)):

def validate_api_key(api_key: str = Security(api_key_header)) -> str:
query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key,
or_(ApiKey.revoked == False, ApiKey.revoked == None)).first()
or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first()
if query_result is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -75,7 +75,7 @@ def validate_api_key(api_key: str = Security(api_key_header)) -> str:

def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Organisation:
query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key,
or_(ApiKey.revoked == False, ApiKey.revoked == None)).first()
or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first()
if query_result is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down
5 changes: 2 additions & 3 deletions superagi/helper/s3_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,15 @@ def upload_file_content(self, content, file_path):
except:
raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.")

@classmethod
def get_download_url_of_resources(cls,db_resources_arr):
def get_download_url_of_resources(self,db_resources_arr):
s3 = boto3.client(
's3',
aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
)
response_obj={}
for db_resource in db_resources_arr:
response = s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path)
response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path)
content = response["Body"].read()
bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME")
file_name=db_resource.path.split('/')[-1]
Expand Down
8 changes: 4 additions & 4 deletions superagi/helper/webhook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def agentStatusChangeCallback(self,agent_execution_id, val,old_val):
for webhook_obj in org_webhooks:
webhook_obj_body={"agent_id":agent_id,"org_id":org_id,"event":f"{old_val} to {val}"}
error=None
r=None
request=None
status='sent'
try:
r = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers)
request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers)
except Exception as e:
logger.error(f"Exception occured in webhooks {e}")
error=str(e)
if r is not None and r.status_code not in [200,201] and error is None:
error=r.text
if request is not None and request.status_code not in [200,201] and error is None:
error=request.text
if error is not None:
status='Error'
web_hook_event=WebHookEvents(agent_id=agent_id,run_id=agent_execution_id,event=f"{old_val} to {val}",status=status,errors=error)
Expand Down
2 changes: 1 addition & 1 deletion superagi/models/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ApiKey(DBBaseModel):
org_id = Column(Integer)
key_name = Column(String)
key = Column(String)
revoked= Column(Boolean)
is_expired= Column(Boolean)



4 changes: 2 additions & 2 deletions superagi/models/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def validate_resource_type(storage_type):
raise InvalidResourceType("Invalid resource type")

@classmethod
def get_all_resources_from_run_ids(cls,session,run_ids_arr):
db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(run_ids_arr)).all()
def get_all_resources_from_run_ids(cls,session,execution_ids):
db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(execution_ids)).all()
return db_resources_arr

class InvalidResourceType(Exception):
Expand Down
2 changes: 1 addition & 1 deletion superagi/models/web_hook_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class WebHookEvents(DBBaseModel):
Methods:
"""
__tablename__ = 'web_hook_events'
__tablename__ = 'webhook_events'

id = Column(Integer, primary_key=True)
agent_id=Column(Integer)
Expand Down
4 changes: 2 additions & 2 deletions superagi/models/web_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ class WebHooks(DBBaseModel):
Methods:
"""
__tablename__ = 'web_hooks'
__tablename__ = 'webhooks'

id = Column(Integer, primary_key=True)
name=Column(String)
org_id = Column(Integer)
url = Column(String)
headers=Column(JSON)
isDeleted=Column(Boolean)
is_deleted=Column(Boolean)

0 comments on commit 11a2c37

Please sign in to comment.