Blocking job take call means 5-sec debounce no longer needed (#366)
Fix: the 5-second debounce was causing unnecessary delays in serverless workers.

Refactored rp_job.get_job to behave correctly across pause and unpause conditions, with more debug logging.
Refactored rp_scale.JobScaler to handle shutdowns gracefully, cleaning up hanging tasks and connections, with better debug logging.
Removed the unnecessarily long asyncio.sleep calls rp_scale.JobScaler made before its get_job calls, since get_job now blocks while waiting for work.
Improved worker_state's JobsProgress and JobsQueue to log timestamped entries when jobs are added or removed.
Moved the worker loop from worker.run_worker into rp_scale.JobScaler, where it belongs, and simplified the entrypoint to job_scaler.start() (see the sketch below).
Fixed non-errors being logged as errors in the tracer.
Updated the unit tests to reflect these changes.
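
For context, the new entrypoint reduces to a JobScaler that owns its own event loop. A minimal sketch (the handler and config values below are illustrative examples, not part of this commit; JobScaler and its config dict are taken from the diff):

    from runpod.serverless.modules import rp_scale

    def handler(job):
        # Hypothetical handler, for illustration only.
        return {"echo": job["input"]}

    config = {
        "handler": handler,
        # Optional: returning the value unchanged keeps concurrency fixed.
        "concurrency_modifier": lambda current: current,
    }

    # start() registers SIGTERM/SIGINT handlers, then drives asyncio.run(self.run()),
    # which owns the aiohttp session and the job-take/job-run tasks.
    job_scaler = rp_scale.JobScaler(config)
    job_scaler.start()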
deanq authored Oct 12, 2024
1 parent 5d1cec6 commit 5a6b911
Showing 7 changed files with 316 additions and 269 deletions.
77 changes: 39 additions & 38 deletions runpod/serverless/modules/rp_job.py
@@ -2,7 +2,6 @@
Job related helpers.
"""

import asyncio
import inspect
import json
import os
@@ -60,43 +59,45 @@ async def get_job(
session (ClientSession): The aiohttp ClientSession to use for the request.
num_jobs (int): The number of jobs to get.
"""
try:
async with session.get(_job_get_url(num_jobs)) as response:
if response.status == 204:
log.debug("No content, no job to process.")
return

if response.status == 400:
log.debug("Received 400 status, expected when FlashBoot is enabled.")
return

if response.status != 200:
log.error(f"Failed to get job, status code: {response.status}")
return

jobs = await response.json()
log.debug(f"Request Received | {jobs}")

# legacy job-take API
if isinstance(jobs, dict):
if "id" not in jobs or "input" not in jobs:
raise Exception("Job has missing field(s): id or input.")
return [jobs]

# batch job-take API
if isinstance(jobs, list):
return jobs

except asyncio.TimeoutError:
log.debug("Timeout error, retrying.")

except Exception as error:
log.error(
f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
)

# empty
return []
async with session.get(_job_get_url(num_jobs)) as response:
log.debug(f"- Response: {type(response).__name__} {response.status}")

if response.status == 204:
log.debug("- No content, no job to process.")
return

if response.status == 400:
log.debug("- Received 400 status, expected when FlashBoot is enabled.")
return

try:
response.raise_for_status()
except Exception:
log.error(f"- Failed to get job, status code: {response.status}")
return

# Verify if the content type is JSON
if response.content_type != "application/json":
log.error(f"- Unexpected content type: {response.content_type}")
return

# Check if there is a non-empty content to parse
if response.content_length == 0:
log.debug("- No content to parse.")
return

jobs = await response.json()
log.debug(f"- Received Job(s)")

# legacy job-take API
if isinstance(jobs, dict):
if "id" not in jobs or "input" not in jobs:
raise Exception("Job has missing field(s): id or input.")
return [jobs]

# batch job-take API
if isinstance(jobs, list):
return jobs


async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dict:
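A quick way to exercise the refactored get_job in isolation, as a hedged sketch (the driver script is illustrative; get_job and AsyncClientSession are the real names from this repo, and get_job still needs the usual RunPod environment variables to resolve the job-take URL):

    import asyncio

    from runpod.http_client import AsyncClientSession
    from runpod.serverless.modules.rp_job import get_job

    async def take_once():
        # The job-take request now blocks server-side until work arrives,
        # which is why the old 5-second debounce is no longer needed.
        async with AsyncClientSession() as session:
            jobs = await get_job(session, num_jobs=2)  # returns None on 204/400
            for job in jobs or []:
                print(job["id"])

    asyncio.run(take_once())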
146 changes: 112 additions & 34 deletions runpod/serverless/modules/rp_scale.py
@@ -4,9 +4,10 @@
"""

import asyncio
import signal
from typing import Any, Dict

from ...http_client import ClientSession
from ...http_client import AsyncClientSession, ClientSession
from .rp_job import get_job, handle_job
from .rp_logger import RunPodLogger
from .worker_state import JobsQueue, JobsProgress
@@ -36,26 +37,91 @@ class JobScaler:
Job Scaler. This class is responsible for scaling the number of concurrent requests.
"""

def __init__(self, concurrency_modifier: Any):
def __init__(self, config: Dict[str, Any]):
concurrency_modifier = config.get("concurrency_modifier")
if concurrency_modifier is None:
self.concurrency_modifier = _default_concurrency_modifier
else:
self.concurrency_modifier = concurrency_modifier

self._shutdown_event = asyncio.Event()
self.current_concurrency = 1
self._is_alive = True
self.config = config

def start(self):
"""
This is required for the worker to be able to shut down gracefully
when the user sends a SIGTERM or SIGINT signal. This is typically
the case when the worker is running in a container.
"""
try:
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.handle_shutdown)
signal.signal(signal.SIGINT, self.handle_shutdown)
except ValueError:
log.warning("Signal handling is only supported in the main thread.")

# Start the main loop
# Run forever until the worker is signalled to shut down.
asyncio.run(self.run())

def handle_shutdown(self, signum, frame):
"""
Called when the worker is signalled to shut down.

This function is called when the worker receives a signal to shut down, such as SIGTERM or SIGINT. It sets the shutdown event, which will cause the worker to exit its main loop and shut down gracefully.

Args:
signum: The signal number that was received.
frame: The current stack frame.
"""
log.debug(f"Received shutdown signal: {signum}.")
self.kill_worker()

async def run(self):
# Create an async session that will be closed when the worker is killed.

async with AsyncClientSession() as session:
# Create tasks for getting and running jobs.
jobtake_task = asyncio.create_task(self.get_jobs(session))
jobrun_task = asyncio.create_task(self.run_jobs(session))

tasks = [jobtake_task, jobrun_task]

try:
# Concurrently run both tasks and wait for both to finish.
await asyncio.gather(*tasks)
except asyncio.CancelledError: # worker is killed
log.debug("Worker tasks cancelled.")
self.kill_worker()
finally:
# Handle the task cancellation gracefully
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
await self.cleanup() # Ensure resources are cleaned up

async def cleanup(self):
# Perform any necessary cleanup here, such as closing connections
log.debug("Cleaning up resources before shutdown.")
# TODO: stop heartbeat or close any open connections
await asyncio.sleep(0) # Give a chance for other tasks to run (optional)
log.debug("Cleanup complete.")

def is_alive(self):
"""
Return whether the worker is alive or not.
"""
return self._is_alive
return not self._shutdown_event.is_set()

def kill_worker(self):
"""
Signal the worker to shut down.
"""
self._is_alive = False
self._shutdown_event.set()

async def get_jobs(self, session: ClientSession):
"""
@@ -66,38 +132,50 @@ async def get_jobs(self, session: ClientSession):
Adds jobs to the JobsQueue
"""
while self.is_alive():
log.debug(f"Jobs in progress: {job_progress.get_job_count()}")

try:
self.current_concurrency = self.concurrency_modifier(
self.current_concurrency
)
log.debug(f"Concurrency set to: {self.current_concurrency}")

jobs_needed = self.current_concurrency - job_progress.get_job_count()
if not jobs_needed: # zero or less
log.debug("Queue is full. Retrying soon.")
continue
log.debug(f"JobScaler.get_jobs | Jobs in progress: {job_progress.get_job_count()}")

acquired_jobs = await get_job(session, jobs_needed)
if not acquired_jobs:
log.debug("No jobs acquired.")
continue
self.current_concurrency = self.concurrency_modifier(
self.current_concurrency
)
log.debug(f"JobScaler.get_jobs | Concurrency set to: {self.current_concurrency}")

for job in acquired_jobs:
await job_list.add_job(job)

log.info(f"Jobs in queue: {job_list.get_job_count()}")
jobs_needed = self.current_concurrency - job_progress.get_job_count()
if jobs_needed <= 0:
log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
await asyncio.sleep(0.1) # avoid busy-looping while the queue is full
continue

try:
# Hold the blocking job-take connection open for up to 30 seconds
acquired_jobs = await asyncio.wait_for(
get_job(session, jobs_needed), timeout=30
)
except asyncio.CancelledError:
log.debug("JobScaler.get_jobs | Request was cancelled.")
continue
except asyncio.TimeoutError: # raised by asyncio.wait_for (alias of TimeoutError on Python 3.11+)
log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
continue
except TypeError as error:
log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.")
continue
except Exception as error:
log.error(
f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
)
continue

finally:
await asyncio.sleep(5) # yield control back to the event loop
if not acquired_jobs:
log.debug("JobScaler.get_jobs | No jobs acquired.")
await asyncio.sleep(0)
continue

async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
for job in acquired_jobs:
await job_list.add_job(job)

log.info(f"Jobs in queue: {job_list.get_job_count()}")

async def run_jobs(self, session: ClientSession):
"""
Retrieve jobs from the jobs queue and process them concurrently.
@@ -111,7 +189,7 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
job = await job_list.get_job()

# Create a new task for each job and add it to the task list
task = asyncio.create_task(self.handle_job(session, config, job))
task = asyncio.create_task(self.handle_job(session, job))
tasks.append(task)

# Wait for any job to finish
@@ -131,19 +209,19 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
# Ensure all remaining tasks finish before stopping
await asyncio.gather(*tasks)

async def handle_job(self, session: ClientSession, config: Dict[str, Any], job):
async def handle_job(self, session: ClientSession, job: dict):
"""
Process an individual job. This function is run concurrently for multiple jobs.
"""
log.debug(f"Processing job: {job}")
log.debug(f"JobScaler.handle_job | {job}")
job_progress.add(job)

try:
await handle_job(session, config, job)
await handle_job(session, self.config, job)

if config.get("refresh_worker", False):
if self.config.get("refresh_worker", False):
self.kill_worker()

except Exception as err:
log.error(f"Error handling job: {err}", job["id"])
raise err
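The start()/run()/handle_shutdown trio above follows a standard asyncio graceful-shutdown pattern. A self-contained sketch of the same idea, with hypothetical names, assuming the worker loops poll a shared event:

    import asyncio
    import signal

    shutdown = asyncio.Event()

    def handle_shutdown(signum, frame):
        # As in JobScaler.handle_shutdown: flip the event and let loops drain.
        shutdown.set()

    async def take_loop():
        while not shutdown.is_set():
            await asyncio.sleep(0.1)  # stand-in for the blocking job-take call

    async def run_loop():
        while not shutdown.is_set():
            await asyncio.sleep(0.1)  # stand-in for job processing

    async def main():
        await asyncio.gather(take_loop(), run_loop())

    def start():
        signal.signal(signal.SIGTERM, handle_shutdown)
        signal.signal(signal.SIGINT, handle_shutdown)
        asyncio.run(main())

    if __name__ == "__main__":
        start()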
8 changes: 8 additions & 0 deletions runpod/serverless/modules/worker_state.py
@@ -8,6 +8,11 @@
from typing import Any, Dict, Optional
from asyncio import Queue

from .rp_logger import RunPodLogger


log = RunPodLogger()

REF_COUNT_ZERO = time.perf_counter() # Used for benchmarking with the debugger.

WORKER_ID = os.environ.get("RUNPOD_POD_ID", str(uuid.uuid4()))
@@ -87,6 +92,7 @@ def add(self, element: Any):
if not isinstance(element, Job):
raise TypeError("Only Job objects can be added to JobsProgress.")

log.debug(f"JobsProgress.add | {element}")
return super().add(element)

def remove(self, element: Any):
@@ -106,6 +112,7 @@ def remove(self, element: Any):
if not isinstance(element, Job):
raise TypeError("Only Job objects can be removed from JobsProgress.")

log.debug(f"JobsProgress.remove | {element}")
return super().remove(element)

def get(self, element: Any) -> Job:
@@ -155,6 +162,7 @@ async def add_job(self, job: dict):
If the queue is full, wait until a free
slot is available before adding item.
"""
log.debug(f"JobsQueue.add_job | {job}")
return await self.put(job)

async def get_job(self) -> dict:
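The queue side of worker_state is a thin wrapper over asyncio.Queue; a compressed sketch of the pattern (get_job's body is truncated in this diff, so the version below is an assumption):

    import asyncio

    class JobsQueue(asyncio.Queue):
        async def add_job(self, job: dict):
            # Debug line added by this commit; RunPodLogger timestamps each entry.
            print(f"JobsQueue.add_job | {job}")
            return await self.put(job)

        async def get_job(self) -> dict:
            # Assumed implementation: presumably just awaits get().
            return await self.get()

    async def demo():
        q = JobsQueue()
        await q.add_job({"id": "job-1", "input": {"prompt": "hi"}})
        print(await q.get_job())

    asyncio.run(demo())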
31 changes: 5 additions & 26 deletions runpod/serverless/worker.py
@@ -7,7 +7,6 @@
import os
from typing import Any, Dict

from runpod.http_client import AsyncClientSession
from runpod.serverless.modules import rp_logger, rp_local, rp_ping, rp_scale

log = rp_logger.RunPodLogger()
@@ -26,7 +25,7 @@ def _is_local(config) -> bool:


# ------------------------- Main Worker Running Loop ------------------------- #
async def run_worker(config: Dict[str, Any]) -> None:
def run_worker(config: Dict[str, Any]) -> None:
"""
Starts the worker loop for multi-processing.
Expand All @@ -39,29 +38,9 @@ async def run_worker(config: Dict[str, Any]) -> None:
# Start pinging RunPod to show that the worker is alive.
heartbeat.start_ping()

# Create an async session that will be closed when the worker is killed.
async with AsyncClientSession() as session:
# Create a JobScaler responsible for adjusting the concurrency
# of the worker based on the modifier callable.
job_scaler = rp_scale.JobScaler(
concurrency_modifier=config.get("concurrency_modifier", None)
)

# Create tasks for getting and running jobs.
jobtake_task = asyncio.create_task(job_scaler.get_jobs(session))
jobrun_task = asyncio.create_task(job_scaler.run_jobs(session, config))

tasks = [jobtake_task, jobrun_task]

try:
# Concurrently run both tasks and wait for both to finish.
await asyncio.gather(*tasks)
except asyncio.CancelledError: # worker is killed
# Handle the task cancellation gracefully
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
log.debug("Worker tasks cancelled.")
# Create a JobScaler responsible for adjusting the concurrency
job_scaler = rp_scale.JobScaler(config)
job_scaler.start()


def main(config: Dict[str, Any]) -> None:
@@ -74,4 +53,4 @@
asyncio.run(rp_local.run_local(config))

else:
asyncio.run(run_worker(config))
run_worker(config)
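
For SDK users nothing changes at the surface; a typical worker still starts the same way (standard RunPod serverless usage; the handler is a hypothetical example):

    import runpod

    def handler(job):
        # job["input"] carries the request payload.
        return {"output": job["input"]}

    # start() ultimately reaches main(), which now calls run_worker(config)
    # synchronously; run_worker hands control to JobScaler.start().
    runpod.serverless.start({"handler": handler})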