Support for job take batch
deanq committed Sep 25, 2024
1 parent 43372e2 commit 2a6a5cd
Showing 3 changed files with 56 additions and 80 deletions.
2 changes: 0 additions & 2 deletions runpod/serverless/modules/rp_fastapi.py
@@ -1,7 +1,5 @@
 """ Used to launch the FastAPI web server when worker is running in API mode. """
 
-# pylint: disable=too-few-public-methods, line-too-long
-
 import os
 import threading
 import uuid
1 change: 0 additions & 1 deletion runpod/serverless/modules/rp_http.py
@@ -48,7 +48,6 @@ async def _transmit(client_session: ClientSession, url, job_data):
         await client_response.text()
 
 
-# pylint: disable=too-many-arguments, disable=line-too-long
 async def _handle_result(
     session: ClientSession, job_data, job, url_template, log_message, is_stream=False
 ):
133 changes: 56 additions & 77 deletions runpod/serverless/modules/rp_job.py
@@ -9,7 +9,7 @@
 import json
 import os
 import traceback
-from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List
 
 from runpod.http_client import ClientSession
 from runpod.serverless.modules.rp_logger import RunPodLogger
@@ -24,7 +24,7 @@
 job_list = JobsQueue()
 
 
-def _job_get_url():
+def _job_get_url(batch_size: int = 1):
     """
     Prepare the URL for making a 'get' request to the serverless API (sls).
@@ -34,89 +34,68 @@ def _job_get_url():
     Returns:
         str: The prepared URL for the 'get' request to the serverless API.
     """
-    return JOB_GET_URL + f"&job_in_progress={job_in_progress}"
+    job_in_progress = "1" if job_list.get_job_count() else "0"
+
+    if batch_size > 1:
+        job_take_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
+        job_take_url += f"&batch_size={batch_size}&batch_strategy=LMove"
+    else:
+        job_take_url = JOB_GET_URL
+
+    return job_take_url + f"&job_in_progress={job_in_progress}"
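
For orientation, here is a minimal, self-contained sketch of the URL logic this hunk introduces. Everything in it except the /job-take/ → /job-take-batch/ swap and the appended query parameters is an assumption; the placeholder base URL in particular stands in for the real JOB_GET_URL, which is assembled elsewhere in the SDK.

# Hypothetical stand-in; the real JOB_GET_URL is built elsewhere in the SDK
# and already carries a query string (hence the leading "&" on the additions).
JOB_GET_URL = "https://api.runpod.ai/v2/<endpoint>/job-take/<worker>?existing=params"

def job_get_url_sketch(batch_size: int = 1, jobs_in_flight: int = 0) -> str:
    # Tell the queue whether this worker already has work in flight.
    job_in_progress = "1" if jobs_in_flight else "0"
    if batch_size > 1:
        # Batch mode swaps the path segment and appends the batch parameters.
        url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
        url += f"&batch_size={batch_size}&batch_strategy=LMove"
    else:
        url = JOB_GET_URL
    return url + f"&job_in_progress={job_in_progress}"

print(job_get_url_sketch())              # .../job-take/...&job_in_progress=0
print(job_get_url_sketch(batch_size=4))  # .../job-take-batch/...&batch_size=4&batch_strategy=LMove&job_in_progress=0

The batch_strategy=LMove value reads like a Redis LMOVE-style atomic pop on the queue side, but that is an inference from the parameter name, not something this diff confirms.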


-async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]]:
+async def get_job(
+    session: ClientSession, num_jobs: int = 1
+) -> Optional[List[Dict[str, Any]]]:
     """
-    Get the job from the queue.
-    Will continue trying to get a job until one is available.
+    Get a job from the job-take API.
+
+    `num_jobs = 1` will query the legacy singular job-take API.
+    `num_jobs > 1` will query the batch job-take API.
 
     Args:
-        session (ClientSession): The async http client to use for the request.
-        retry (bool): Whether to retry if no job is available.
+        session (ClientSession): The aiohttp ClientSession to use for the request.
+        num_jobs (int): The number of jobs to get.
     """
-    next_job = None
-
-    while next_job is None:
-        try:
-            async with session.get(_job_get_url()) as response:
-                if response.status == 204:
-                    log.debug("No content, no job to process.")
-                    if retry is False:
-                        break
-                    continue
-
-                if response.status == 400:
-                    log.debug(
-                        "Received 400 status, expected when FlashBoot is enabled."
-                    )
-                    if retry is False:
-                        break
-                    continue
-
-                if response.status != 200:
-                    log.error(f"Failed to get job, status code: {response.status}")
-                    if retry is False:
-                        break
-                    continue
-
-                received_request = await response.json()
-                log.debug(f"Request Received | {received_request}")
-
-                # Check if the job is valid
-                job_id = received_request.get("id", None)
-                job_input = received_request.get("input", None)
-
-                if None in [job_id, job_input]:
-                    missing_fields = []
-                    if job_id is None:
-                        missing_fields.append("id")
-                    if job_input is None:
-                        missing_fields.append("input")
-
-                    log.error(f"Job has missing field(s): {', '.join(missing_fields)}.")
-                else:
-                    next_job = received_request
-
-        except asyncio.TimeoutError:
-            log.debug("Timeout error, retrying.")
-            if retry is False:
-                break
-
-        except Exception as err:  # pylint: disable=broad-except
-            err_type = type(err).__name__
-            err_message = str(err)
-            err_traceback = traceback.format_exc()
-            log.error(
-                f"Failed to get job. | Error Type: {err_type} | Error Message: {err_message}"
-            )
-            log.error(f"Traceback: {err_traceback}")
-
-        if next_job is None:
-            log.debug("No job available, waiting for the next one.")
-            if retry is False:
-                break
-
-            await asyncio.sleep(1)
-        else:
-            job_list.add_job(next_job["id"])
-            log.debug("Request ID added.", next_job["id"])
-
-    return next_job
+    try:
+        async with session.get(_job_get_url(num_jobs)) as response:
+            if response.status == 204:
+                log.debug("No content, no job to process.")
+                return
+
+            if response.status == 400:
+                log.debug("Received 400 status, expected when FlashBoot is enabled.")
+                return
+
+            if response.status != 200:
+                log.error(f"Failed to get job, status code: {response.status}")
+                return
+
+            jobs = await response.json()
+            log.debug(f"Request Received | {jobs}")
+
+            # legacy job-take API
+            if isinstance(jobs, dict):
+                if "id" not in jobs or "input" not in jobs:
+                    raise Exception("Job has missing field(s): id or input.")
+                return [jobs]
+
+            # batch job-take API
+            if isinstance(jobs, list):
+                return jobs
+
+    except asyncio.TimeoutError:
+        log.debug("Timeout error, retrying.")
+
+    except Exception as error:
+        log.error(
+            f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
+        )
+
+    # empty
+    return []


 async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
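
Because get_job's contract changes from a single job dict to an optional list, callers have to adapt. A minimal usage sketch under assumed names (poll_once and handler are illustrative, not part of this commit):

# Illustrative caller of the new get_job contract; not part of this commit.
import asyncio

from runpod.http_client import ClientSession
from runpod.serverless.modules.rp_job import get_job, run_job

async def poll_once(session: ClientSession, handler, num_jobs: int = 4):
    # get_job now returns None on the early-exit paths (204, 400, non-200,
    # timeout, unexpected error) and [] on fall-through, so treat anything
    # falsy as "no work" and back off briefly before the next poll.
    jobs = await get_job(session, num_jobs=num_jobs)
    if not jobs:
        await asyncio.sleep(1)
        return
    for job in jobs:
        await run_job(handler, job)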
