Skip to content

Commit

Permalink
Merge pull request #184 from runpod/cli-stream-fix
Browse files Browse the repository at this point in the history
Cli stream fix
  • Loading branch information
justinmerrell authored Oct 28, 2023
2 parents 1ba2d85 + d3012ae commit 2190a85
Show file tree
Hide file tree
Showing 13 changed files with 102 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CD-publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ jobs:
"runpod-workers/worker-controlnet",
"runpod-workers/worker-blip",
"runpod-workers/worker-deforum",
runpod-workers/mock-worker,
]

runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Added

- BETA: CLI DevEx functionality to create development projects.
- `test_output` can be passed in as an arg to compare the results of `test_input`
- Generator/Streaming handlers supported with local testing

## Release 1.3.0 (10/12/23)

Expand Down
4 changes: 4 additions & 0 deletions runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def _set_config_args(config) -> dict:
if config["rp_args"]["test_input"]:
config["rp_args"]["test_input"] = json.loads(config["rp_args"]["test_input"])

# Parse the test output from JSON
if config["rp_args"].get("test_output", None):
config["rp_args"]["test_output"] = json.loads(config["rp_args"]["test_output"])

# Set the log level
if config["rp_args"]["rp_log_level"]:
log.set_level(config["rp_args"]["rp_log_level"])
Expand Down
11 changes: 9 additions & 2 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

from .rp_job import run_job
from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
from .worker_state import Jobs
from .rp_ping import Heartbeat
from ...version import __version__ as runpod_version
Expand Down Expand Up @@ -125,7 +126,13 @@ async def _debug_run(self, job: TestJob):
# Set the current job ID.
job_list.add_job(job.id)

job_results = await run_job(self.config["handler"], job.__dict__)
if is_generator(self.config["handler"]):
generator_output = run_job_generator(self.config["handler"], job.__dict__)
job_results = {"output": []}
async for stream_output in generator_output:
job_results["output"].append(stream_output["output"])
else:
job_results = await run_job(self.config["handler"], job.__dict__)

job_results["id"] = job.id

Expand Down
8 changes: 8 additions & 0 deletions runpod/serverless/modules/rp_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Retrieve handler info. """

import inspect
from typing import Callable

def is_generator(handler: Callable) -> bool:
"""Check if handler is a generator function. """
return inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler)
8 changes: 8 additions & 0 deletions runpod/serverless/modules/rp_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,13 @@ async def run_local(config: Dict[str, Any]) -> None:
log.info(f"Job {local_job['id']} completed successfully.")
log.info(f"Job result: {job_result}")

# Compare to sample output, if provided
if config['rp_args'].get('test_output', None):
log.info("test_output set, comparing output to test_output.")
if job_result != config['rp_args']['test_output']:
log.error("Job output does not match test_output.")
sys.exit(1)
log.info("Job output matches test_output.")

log.info("Local testing complete, exiting.")
sys.exit(0)
4 changes: 4 additions & 0 deletions runpod/serverless/modules/rp_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def start_ping(self, test=False):
'''
Sends heartbeat pings to the Runpod server.
'''
if os.environ.get('RUNPOD_AI_API_KEY') is None:
log.debug("Not deployed on RunPod serverless, pings will not be sent.")
return

if os.environ.get('RUNPOD_POD_ID') is None:
log.info("Not running on RunPod, pings will not be sent.")
return
Expand Down
7 changes: 3 additions & 4 deletions runpod/serverless/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
"""
import os
import asyncio
import inspect
from typing import Dict, Any

import aiohttp

from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.rp_scale import JobScaler
from .modules import rp_local
from .modules.rp_handler import is_generator
from .modules.rp_ping import Heartbeat
from .modules.rp_job import run_job, run_job_generator
from .modules.rp_http import send_result, stream_result
Expand Down Expand Up @@ -46,11 +46,10 @@ def _is_local(config) -> bool:


async def _process_job(job, session, job_scaler, config):
if inspect.isgeneratorfunction(config["handler"]) \
or inspect.isasyncgenfunction(config["handler"]):
if is_generator(config["handler"]):
generator_output = run_job_generator(config["handler"], job)

log.debug("Handler is a generator, streaming results.")

job_result = {'output': []}
async for stream_output in generator_output:
if 'error' in stream_output:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_serverless/test_modules/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ def test_run(self):

self.assertTrue(mock_ping.called)

# Test with generator handler
def generator_handler(job):
del job
yield {"result": "success"}

generator_worker_api = rp_fastapi.WorkerAPI(handler=generator_handler)
generator_run_return = asyncio.run(generator_worker_api._debug_run(job_object)) # pylint: disable=protected-access
assert generator_run_return == {
"id": "test_job_id",
"output": [{"result": "success"}]
}

loop.close()
33 changes: 33 additions & 0 deletions tests/test_serverless/test_modules/test_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
""" Unit tests for the handler module.
"""
import unittest

from runpod.serverless.modules.rp_handler import is_generator


class TestIsGenerator(unittest.TestCase):
"""Tests for the is_generator function."""

def test_regular_function(self):
"""Test that a regular function is not a generator."""
def regular_func():
return "I'm a regular function!"
self.assertFalse(is_generator(regular_func))

def test_generator_function(self):
"""Test that a generator function is a generator."""
def generator_func():
yield "I'm a generator function!"
self.assertTrue(is_generator(generator_func))

def test_async_function(self):
"""Test that an async function is not a generator."""
async def async_func():
return "I'm an async function!"
self.assertFalse(is_generator(async_func))

def test_async_generator_function(self):
"""Test that an async generator function is a generator."""
async def async_gen_func():
yield "I'm an async generator function!"
self.assertTrue(is_generator(async_gen_func))
12 changes: 10 additions & 2 deletions tests/test_serverless/test_modules/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class TestRunLocal(IsolatedAsyncioTestCase):
''' Tests for run_local function '''

@patch("runpod.serverless.modules.rp_local.run_job", return_value={})
@patch("runpod.serverless.modules.rp_local.run_job", return_value={"result": "success"})
@patch("builtins.open", new_callable=mock_open, read_data='{"input": "test"}')
async def test_run_local_with_test_input(self, mock_file, mock_run):
'''
Expand All @@ -21,12 +21,20 @@ async def test_run_local_with_test_input(self, mock_file, mock_run):
"test_input": {
"input": "test",
"id": "test_id"
},
"test_output": {
"result": "success"
}
}
}
with self.assertRaises(SystemExit) as sys_exit:
await rp_local.run_local(config)
self.assertEqual(sys_exit.exception.code, 0)
self.assertEqual(sys_exit.exception.code, 0)

config["rp_args"]["test_output"] = {"result": "fail"}
with self.assertRaises(SystemExit) as sys_exit:
await rp_local.run_local(config)
self.assertEqual(sys_exit.exception.code, 1)

assert mock_file.called is False
assert mock_run.called
Expand Down
7 changes: 7 additions & 0 deletions tests/test_serverless/test_modules/test_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def test_start_ping(self, mock_get_return):
'''
Tests that the start_ping function works correctly
'''
# No RUNPOD_AI_API_KEY case
with patch("threading.Thread.start") as mock_thread_start:
rp_ping.Heartbeat().start_ping(test=True)
assert mock_thread_start.call_count == 0

os.environ["RUNPOD_AI_API_KEY"] = "test_key"

# No RUNPOD_POD_ID case
with patch("threading.Thread.start") as mock_thread_start:
rp_ping.Heartbeat().start_ping(test=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_serverless/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_local_api(self):
'''
Test local FastAPI setup.
'''

known_args = argparse.Namespace()
known_args.rp_log_level = None
known_args.rp_debugger = None
Expand Down Expand Up @@ -126,6 +125,7 @@ def test_worker_bad_local(self):
known_args.rp_api_concurrency = 1
known_args.rp_api_host = "localhost"
known_args.test_input = '{"test": "test"}'
known_args.test_output = '{"test": "test"}'

with patch("argparse.ArgumentParser.parse_known_args") as mock_parse_known_args, \
self.assertRaises(SystemExit):
Expand Down

0 comments on commit 2190a85

Please sign in to comment.