Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sorting to task api #1018

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus
from skyvern.forge.sdk.schemas.totp_codes import TOTPCode
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
Expand Down Expand Up @@ -461,6 +461,8 @@ async def get_tasks(
workflow_run_id: str | None = None,
organization_id: str | None = None,
only_standalone_tasks: bool = False,
order_by_column: OrderBy = OrderBy.created_at,
order: SortDirection = SortDirection.desc,
) -> list[Task]:
"""
Get all tasks.
Expand All @@ -469,6 +471,8 @@ async def get_tasks(
:param task_status:
:param workflow_run_id:
:param only_standalone_tasks:
:param order_by_column:
:param order:
:return:
"""
if page < 1:
Expand All @@ -484,7 +488,12 @@ async def get_tasks(
query = query.filter(TaskModel.workflow_run_id == workflow_run_id)
if only_standalone_tasks:
query = query.filter(TaskModel.workflow_run_id.is_(None))
query = query.order_by(TaskModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
order_by_col = getattr(TaskModel, order_by_column)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding validation for order_by_column to ensure it is a valid attribute of TaskModel. This will prevent runtime errors if an invalid column name is passed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider validating order_by_column before using getattr to ensure it is a valid attribute of TaskModel. This can prevent potential runtime errors.

query = (
query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc())
.limit(page_size)
.offset(db_page * page_size)
)
tasks = (await session.scalars(query)).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
Expand Down
16 changes: 15 additions & 1 deletion skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@
OrganizationUpdate,
)
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse,
OrderBy,
SortDirection,
Task,
TaskRequest,
TaskResponse,
TaskStatus,
)
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.exceptions import FailedToCreateWorkflow, FailedToUpdateWorkflow
Expand Down Expand Up @@ -385,6 +393,8 @@ async def get_agent_tasks(
workflow_run_id: Annotated[str | None, Query()] = None,
current_org: Organization = Depends(org_auth_service.get_current_org),
only_standalone_tasks: bool = Query(False),
sort: OrderBy = Query(OrderBy.created_at),
order: SortDirection = Query(SortDirection.desc),
) -> Response:
"""
Get all tasks.
Expand All @@ -393,6 +403,8 @@ async def get_agent_tasks(
:param task_status: Task status filter
:param workflow_run_id: Workflow run id filter
:param only_standalone_tasks: Only standalone tasks, tasks which are part of a workflow run will be filtered out
:param order: Direction to sort by, ascending or descending
:param sort: Column to sort by, created_at or modified_at
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
get_agent_task endpoint.
"""
Expand All @@ -409,6 +421,8 @@ async def get_agent_tasks(
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
only_standalone_tasks=only_standalone_tasks,
order=order,
order_by_column=sort,
)
return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])

Expand Down
10 changes: 10 additions & 0 deletions skyvern/forge/sdk/schemas/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,13 @@ def from_task(task: Task) -> TaskOutput:

class CreateTaskResponse(BaseModel):
task_id: str


class OrderBy(StrEnum):
created_at = "created_at"
modified_at = "modified_at"


class SortDirection(StrEnum):
asc = "asc"
desc = "desc"
Loading