From 11a2c3799db8c37bdd4145717e6832d8dc145018 Mon Sep 17 00:00:00 2001 From: Maverick-F35 Date: Tue, 8 Aug 2023 11:16:01 +0530 Subject: [PATCH] made changes according to comments of code review --- .../446884dcae58_add_api_key_and_web_hook.py | 12 +++--- superagi/controllers/api/agent.py | 40 +++++++++++------- superagi/controllers/webhook.py | 42 ++++++------------- superagi/helper/auth.py | 4 +- superagi/helper/s3_helper.py | 5 +-- superagi/helper/webhook_manager.py | 8 ++-- superagi/models/api_key.py | 2 +- superagi/models/resource.py | 4 +- superagi/models/web_hook_events.py | 2 +- superagi/models/web_hooks.py | 4 +- 10 files changed, 57 insertions(+), 66 deletions(-) diff --git a/migrations/versions/446884dcae58_add_api_key_and_web_hook.py b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py index c967bcd9d..e8ee27d3d 100644 --- a/migrations/versions/446884dcae58_add_api_key_and_web_hook.py +++ b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py @@ -25,10 +25,10 @@ def upgrade() -> None: sa.Column('key', sa.String(), nullable=True), sa.Column('created_at', sa.DateTime(), nullable=True), sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('revoked',sa.Boolean(),nullable=True,default=False), + sa.Column('is_expired',sa.Boolean(),nullable=True,default=False), sa.PrimaryKeyConstraint('id') ) - op.create_table('web_hooks', + op.create_table('webhooks', sa.Column('id', sa.Integer(), nullable=False), sa.Column('name', sa.String(), nullable=True), sa.Column('org_id', sa.Integer(), nullable=True), @@ -36,10 +36,10 @@ def upgrade() -> None: sa.Column('headers', sa.JSON(), nullable=True), sa.Column('created_at', sa.DateTime(), nullable=True), sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('isDeleted',sa.Boolean(),nullable=True), + sa.Column('is_deleted',sa.Boolean(),nullable=True), sa.PrimaryKeyConstraint('id') ) - op.create_table('web_hook_events', + op.create_table('webhook_events', sa.Column('id', sa.Integer(), nullable=False), sa.Column('agent_id', sa.Integer(), nullable=True), sa.Column('run_id', sa.Integer(), nullable=True), @@ -58,8 +58,8 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('web_hooks') + op.drop_table('webhooks') op.drop_table('api_key') - op.drop_table('web_hook_events') + op.drop_table('webhook_events') # ### end Alembic commands ### diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py index cc20f6e36..ae0fce85e 100644 --- a/superagi/controllers/api/agent.py +++ b/superagi/controllers/api/agent.py @@ -77,27 +77,23 @@ class Config: @router.post("",status_code=200) def create_agent_with_config(agent_with_config: AgentConfigExtInput, api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): - project=Project.get_project_from_org_id(db.session,organisation.id) - tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) - invalid_tools = Tool.get_invalid_tools(tools_arr, db.session) - if len(invalid_tools) > 0: # If the returned value is not True (then it is an invalid tool_id) - raise HTTPException(status_code=404, - detail=f"Tool with IDs {str(invalid_tools)} does not exist. 404 Not Found.") - + try: + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + except Exception as e: + raise HTTPException(status_code=404,detail=str(e)) + agent_with_config.tools=tools_arr agent_with_config.project_id=project.id agent_with_config.exit="No exit criterion" agent_with_config.permission_type="God Mode" agent_with_config.LTM_DB=None - db_agent = Agent.create_agent_with_config(db, agent_with_config) if agent_with_config.schedule is not None: agent_schedule = AgentSchedule.get_schedule_from_config(db.session,db_agent,agent_with_config.schedule) if agent_schedule is None: raise HTTPException(status_code=500, detail="Failed to schedule agent") - EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name, 'model': agent_with_config.model}, db_agent.id, organisation.id if organisation else 0) @@ -196,12 +192,11 @@ def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api db_schedule=AgentSchedule.get_schedule_from_agent_id(db.session,agent_id) if db_schedule is not None: raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update") - - tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) - invalid_tools = Tool.get_invalid_tools(tools_arr, db.session) - if len(invalid_tools) > 0: # If the returned value is not True (then it is an invalid tool_id) - raise HTTPException(status_code=404, - detail=f"Tool with IDs {str(invalid_tools)} does not exist.") + + try: + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + except Exception as e: + raise HTTPException(status_code=404,detail=str(e)) agent_with_config.tools=tools_arr agent_with_config.project_id=project.id @@ -270,6 +265,13 @@ def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCha if project.organisation_id!=organisation.id: raise HTTPException(status_code=404, detail="Agent not found") + #Checking if the run_ids whose output files are requested belong to the organisation + if execution_state_change_input.run_ids is not None: + try: + AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail=str(e)) + db_execution_arr=AgentExecution.get_all_executions_with_status_and_agent_id(db.session,agent.id,execution_state_change_input,"RUNNING") for ind_execution in db_execution_arr: ind_execution.status="PAUSED" @@ -289,6 +291,12 @@ def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCh if project.organisation_id!=organisation.id: raise HTTPException(status_code=404, detail="Agent not found") + if execution_state_change_input.run_ids is not None: + try: + AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail=str(e)) + db_execution_arr=AgentExecution.get_all_executions_with_status_and_agent_id(db.session,agent.id,execution_state_change_input,"PAUSED") for ind_execution in db_execution_arr: ind_execution.status="RUNNING" @@ -312,6 +320,6 @@ def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate raise HTTPException(status_code=404, detail=str(e)) db_resources_arr=Resource.get_all_resources_from_run_ids(db.session,run_ids_arr) - response_obj=S3Helper.get_download_url_of_resources(db_resources_arr) + response_obj=S3Helper().get_download_url_of_resources(db_resources_arr) return response_obj diff --git a/superagi/controllers/webhook.py b/superagi/controllers/webhook.py index ebdfc91f8..62cb32d3c 100644 --- a/superagi/controllers/webhook.py +++ b/superagi/controllers/webhook.py @@ -1,41 +1,25 @@ -import json + from datetime import datetime from fastapi import APIRouter -from fastapi import HTTPException, Depends ,Security +from fastapi import Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db -from pydantic import BaseModel,Json - -from jsonmerge import merge -from pytz import timezone -from sqlalchemy import func, or_ -from superagi.models.agent_execution_permission import AgentExecutionPermission -from superagi.worker import execute_agent -from superagi.helper.auth import check_auth,validate_api_key -from superagi.models.agent import Agent -from superagi.models.agent_execution_config import AgentExecutionConfiguration -from superagi.models.agent_config import AgentConfiguration -from superagi.models.agent_schedule import AgentSchedule -from superagi.models.agent_template import AgentTemplate -from superagi.models.project import Project -from superagi.models.agent_execution import AgentExecution -from superagi.models.tool import Tool +from pydantic import BaseModel + + + + +from superagi.helper.auth import check_auth + from superagi.models.web_hooks import WebHooks -from superagi.controllers.types.agent_schedule import AgentScheduleInput -from superagi.controllers.types.agent_with_config import AgentConfigInput -from superagi.controllers.types.agent_with_config_schedule import AgentConfigSchedule -from jsonmerge import merge + from datetime import datetime -import json -from superagi.models.toolkit import Toolkit -from superagi.models.knowledges import Knowledges -from sqlalchemy import func # from superagi.types.db import AgentOut, AgentIn from superagi.helper.auth import check_auth, get_user_organisation -from superagi.apm.event_handler import EventHandler + router = APIRouter() @@ -54,7 +38,7 @@ class WebHookOut(BaseModel): name: str url: str headers: dict - isDeleted: bool + is_deleted: bool created_at: datetime updated_at: datetime @@ -76,7 +60,7 @@ def create_webhook(webhook: WebHookIn,Authorize: AuthJWT = Depends(check_auth),o Raises: HTTPException (Status Code=404): If the associated project is not found. """ - db_webhook=WebHooks(name=webhook.name,url=webhook.url,headers=webhook.headers,org_id=organisation.id,isDeleted=False) + db_webhook=WebHooks(name=webhook.name,url=webhook.url,headers=webhook.headers,org_id=organisation.id,is_deleted=False) db.session.add(db_webhook) db.session.commit() db.session.flush() diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py index a30eb7b4b..f916a3ae5 100644 --- a/superagi/helper/auth.py +++ b/superagi/helper/auth.py @@ -63,7 +63,7 @@ def get_current_user(Authorize: AuthJWT = Depends(check_auth)): def validate_api_key(api_key: str = Security(api_key_header)) -> str: query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key, - or_(ApiKey.revoked == False, ApiKey.revoked == None)).first() + or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() if query_result is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -75,7 +75,7 @@ def validate_api_key(api_key: str = Security(api_key_header)) -> str: def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Organisation: query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key, - or_(ApiKey.revoked == False, ApiKey.revoked == None)).first() + or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() if query_result is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/superagi/helper/s3_helper.py b/superagi/helper/s3_helper.py index 59eae8bac..5c526ee95 100644 --- a/superagi/helper/s3_helper.py +++ b/superagi/helper/s3_helper.py @@ -113,8 +113,7 @@ def upload_file_content(self, content, file_path): except: raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") - @classmethod - def get_download_url_of_resources(cls,db_resources_arr): + def get_download_url_of_resources(self,db_resources_arr): s3 = boto3.client( 's3', aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"), @@ -122,7 +121,7 @@ def get_download_url_of_resources(cls,db_resources_arr): ) response_obj={} for db_resource in db_resources_arr: - response = s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path) + response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path) content = response["Body"].read() bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME") file_name=db_resource.path.split('/')[-1] diff --git a/superagi/helper/webhook_manager.py b/superagi/helper/webhook_manager.py index 163479fb8..690d0aa15 100644 --- a/superagi/helper/webhook_manager.py +++ b/superagi/helper/webhook_manager.py @@ -21,15 +21,15 @@ def agentStatusChangeCallback(self,agent_execution_id, val,old_val): for webhook_obj in org_webhooks: webhook_obj_body={"agent_id":agent_id,"org_id":org_id,"event":f"{old_val} to {val}"} error=None - r=None + request=None status='sent' try: - r = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers) + request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers) except Exception as e: logger.error(f"Exception occured in webhooks {e}") error=str(e) - if r is not None and r.status_code not in [200,201] and error is None: - error=r.text + if request is not None and request.status_code not in [200,201] and error is None: + error=request.text if error is not None: status='Error' web_hook_event=WebHookEvents(agent_id=agent_id,run_id=agent_execution_id,event=f"{old_val} to {val}",status=status,errors=error) diff --git a/superagi/models/api_key.py b/superagi/models/api_key.py index f9a5fe3f0..dfea65b01 100644 --- a/superagi/models/api_key.py +++ b/superagi/models/api_key.py @@ -18,7 +18,7 @@ class ApiKey(DBBaseModel): org_id = Column(Integer) key_name = Column(String) key = Column(String) - revoked= Column(Boolean) + is_expired= Column(Boolean) diff --git a/superagi/models/resource.py b/superagi/models/resource.py index b4452128c..c0c50bc45 100644 --- a/superagi/models/resource.py +++ b/superagi/models/resource.py @@ -60,8 +60,8 @@ def validate_resource_type(storage_type): raise InvalidResourceType("Invalid resource type") @classmethod - def get_all_resources_from_run_ids(cls,session,run_ids_arr): - db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(run_ids_arr)).all() + def get_all_resources_from_run_ids(cls,session,execution_ids): + db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(execution_ids)).all() return db_resources_arr class InvalidResourceType(Exception): diff --git a/superagi/models/web_hook_events.py b/superagi/models/web_hook_events.py index a56f79270..ecca62910 100644 --- a/superagi/models/web_hook_events.py +++ b/superagi/models/web_hook_events.py @@ -12,7 +12,7 @@ class WebHookEvents(DBBaseModel): Methods: """ - __tablename__ = 'web_hook_events' + __tablename__ = 'webhook_events' id = Column(Integer, primary_key=True) agent_id=Column(Integer) diff --git a/superagi/models/web_hooks.py b/superagi/models/web_hooks.py index e7fdec4c3..a876040bc 100644 --- a/superagi/models/web_hooks.py +++ b/superagi/models/web_hooks.py @@ -12,11 +12,11 @@ class WebHooks(DBBaseModel): Methods: """ - __tablename__ = 'web_hooks' + __tablename__ = 'webhooks' id = Column(Integer, primary_key=True) name=Column(String) org_id = Column(Integer) url = Column(String) headers=Column(JSON) - isDeleted=Column(Boolean) + is_deleted=Column(Boolean)