Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cli stream fix #184

Merged
merged 15 commits into from
Oct 28, 2023
1 change: 1 addition & 0 deletions .github/workflows/CD-publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ jobs:
"runpod-workers/worker-controlnet",
"runpod-workers/worker-blip",
"runpod-workers/worker-deforum",
runpod-workers/mock-worker,
]

runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Added

- BETA: CLI DevEx functionality to create development projects.
- `test_output` can be passed in as an arg to compare the results of `test_input`
- Generator/Streaming handlers supported with local testing

## Release 1.3.0 (10/12/23)

Expand Down
4 changes: 4 additions & 0 deletions runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def _set_config_args(config) -> dict:
if config["rp_args"]["test_input"]:
config["rp_args"]["test_input"] = json.loads(config["rp_args"]["test_input"])

# Parse the test output from JSON
if config["rp_args"].get("test_output", None):
config["rp_args"]["test_output"] = json.loads(config["rp_args"]["test_output"])

# Set the log level
if config["rp_args"]["rp_log_level"]:
log.set_level(config["rp_args"]["rp_log_level"])
Expand Down
11 changes: 9 additions & 2 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

from .rp_job import run_job
from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
from .worker_state import Jobs
from .rp_ping import Heartbeat
from ...version import __version__ as runpod_version
Expand Down Expand Up @@ -125,7 +126,13 @@ async def _debug_run(self, job: TestJob):
# Set the current job ID.
job_list.add_job(job.id)

job_results = await run_job(self.config["handler"], job.__dict__)
if is_generator(self.config["handler"]):
generator_output = run_job_generator(self.config["handler"], job.__dict__)
job_results = {"output": []}
async for stream_output in generator_output:
job_results["output"].append(stream_output["output"])
else:
job_results = await run_job(self.config["handler"], job.__dict__)

job_results["id"] = job.id

Expand Down
8 changes: 8 additions & 0 deletions runpod/serverless/modules/rp_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Retrieve handler info. """

import inspect
from typing import Callable

def is_generator(handler: Callable) -> bool:
"""Check if handler is a generator function. """
return inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler)
8 changes: 8 additions & 0 deletions runpod/serverless/modules/rp_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,13 @@ async def run_local(config: Dict[str, Any]) -> None:
log.info(f"Job {local_job['id']} completed successfully.")
log.info(f"Job result: {job_result}")

# Compare to sample output, if provided
if config['rp_args'].get('test_output', None):
log.info("test_output set, comparing output to test_output.")
if job_result != config['rp_args']['test_output']:
log.error("Job output does not match test_output.")
sys.exit(1)
log.info("Job output matches test_output.")

log.info("Local testing complete, exiting.")
sys.exit(0)
4 changes: 4 additions & 0 deletions runpod/serverless/modules/rp_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def start_ping(self, test=False):
'''
Sends heartbeat pings to the Runpod server.
'''
if os.environ.get('RUNPOD_AI_API_KEY') is None:
log.debug("Not deployed on RunPod serverless, pings will not be sent.")
return

if os.environ.get('RUNPOD_POD_ID') is None:
log.info("Not running on RunPod, pings will not be sent.")
return
Expand Down
7 changes: 3 additions & 4 deletions runpod/serverless/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
"""
import os
import asyncio
import inspect
from typing import Dict, Any

import aiohttp

from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.rp_scale import JobScaler
from .modules import rp_local
from .modules.rp_handler import is_generator
from .modules.rp_ping import Heartbeat
from .modules.rp_job import run_job, run_job_generator
from .modules.rp_http import send_result, stream_result
Expand Down Expand Up @@ -46,11 +46,10 @@ def _is_local(config) -> bool:


async def _process_job(job, session, job_scaler, config):
if inspect.isgeneratorfunction(config["handler"]) \
or inspect.isasyncgenfunction(config["handler"]):
if is_generator(config["handler"]):
generator_output = run_job_generator(config["handler"], job)

log.debug("Handler is a generator, streaming results.")

job_result = {'output': []}
async for stream_output in generator_output:
if 'error' in stream_output:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_serverless/test_modules/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ def test_run(self):

self.assertTrue(mock_ping.called)

# Test with generator handler
def generator_handler(job):
del job
yield {"result": "success"}

generator_worker_api = rp_fastapi.WorkerAPI(handler=generator_handler)
generator_run_return = asyncio.run(generator_worker_api._debug_run(job_object)) # pylint: disable=protected-access
assert generator_run_return == {
"id": "test_job_id",
"output": [{"result": "success"}]
}

loop.close()
33 changes: 33 additions & 0 deletions tests/test_serverless/test_modules/test_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
""" Unit tests for the handler module.
"""
import unittest

from runpod.serverless.modules.rp_handler import is_generator


class TestIsGenerator(unittest.TestCase):
"""Tests for the is_generator function."""

def test_regular_function(self):
"""Test that a regular function is not a generator."""
def regular_func():
return "I'm a regular function!"
self.assertFalse(is_generator(regular_func))

def test_generator_function(self):
"""Test that a generator function is a generator."""
def generator_func():
yield "I'm a generator function!"
self.assertTrue(is_generator(generator_func))

def test_async_function(self):
"""Test that an async function is not a generator."""
async def async_func():
return "I'm an async function!"
self.assertFalse(is_generator(async_func))

def test_async_generator_function(self):
"""Test that an async generator function is a generator."""
async def async_gen_func():
yield "I'm an async generator function!"
self.assertTrue(is_generator(async_gen_func))
12 changes: 10 additions & 2 deletions tests/test_serverless/test_modules/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class TestRunLocal(IsolatedAsyncioTestCase):
''' Tests for run_local function '''

@patch("runpod.serverless.modules.rp_local.run_job", return_value={})
@patch("runpod.serverless.modules.rp_local.run_job", return_value={"result": "success"})
@patch("builtins.open", new_callable=mock_open, read_data='{"input": "test"}')
async def test_run_local_with_test_input(self, mock_file, mock_run):
'''
Expand All @@ -21,12 +21,20 @@ async def test_run_local_with_test_input(self, mock_file, mock_run):
"test_input": {
"input": "test",
"id": "test_id"
},
"test_output": {
"result": "success"
}
}
}
with self.assertRaises(SystemExit) as sys_exit:
await rp_local.run_local(config)
self.assertEqual(sys_exit.exception.code, 0)
self.assertEqual(sys_exit.exception.code, 0)

config["rp_args"]["test_output"] = {"result": "fail"}
with self.assertRaises(SystemExit) as sys_exit:
await rp_local.run_local(config)
self.assertEqual(sys_exit.exception.code, 1)

assert mock_file.called is False
assert mock_run.called
Expand Down
7 changes: 7 additions & 0 deletions tests/test_serverless/test_modules/test_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def test_start_ping(self, mock_get_return):
'''
Tests that the start_ping function works correctly
'''
# No RUNPOD_AI_API_KEY case
with patch("threading.Thread.start") as mock_thread_start:
rp_ping.Heartbeat().start_ping(test=True)
assert mock_thread_start.call_count == 0

os.environ["RUNPOD_AI_API_KEY"] = "test_key"

# No RUNPOD_POD_ID case
with patch("threading.Thread.start") as mock_thread_start:
rp_ping.Heartbeat().start_ping(test=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_serverless/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_local_api(self):
'''
Test local FastAPI setup.
'''

known_args = argparse.Namespace()
known_args.rp_log_level = None
known_args.rp_debugger = None
Expand Down Expand Up @@ -126,6 +125,7 @@ def test_worker_bad_local(self):
known_args.rp_api_concurrency = 1
known_args.rp_api_host = "localhost"
known_args.test_input = '{"test": "test"}'
known_args.test_output = '{"test": "test"}'

with patch("argparse.ArgumentParser.parse_known_args") as mock_parse_known_args, \
self.assertRaises(SystemExit):
Expand Down
Loading