diff --git a/.github/workflows/CD-publish_to_pypi.yml b/.github/workflows/CD-publish_to_pypi.yml index d00b533d..356f449b 100644 --- a/.github/workflows/CD-publish_to_pypi.yml +++ b/.github/workflows/CD-publish_to_pypi.yml @@ -69,6 +69,7 @@ jobs: "runpod-workers/worker-controlnet", "runpod-workers/worker-blip", "runpod-workers/worker-deforum", + runpod-workers/mock-worker, ] runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dea7cf7..3d08acf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ ### Added - BETA: CLI DevEx functionality to create development projects. +- `test_output` can be passed in as an arg to compare the results of `test_input` +- Generator/Streaming handlers supported with local testing ## Release 1.3.0 (10/12/23) diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py index dfd86a7b..abf60633 100644 --- a/runpod/serverless/__init__.py +++ b/runpod/serverless/__init__.py @@ -66,6 +66,10 @@ def _set_config_args(config) -> dict: if config["rp_args"]["test_input"]: config["rp_args"]["test_input"] = json.loads(config["rp_args"]["test_input"]) + # Parse the test output from JSON + if config["rp_args"].get("test_output", None): + config["rp_args"]["test_output"] = json.loads(config["rp_args"]["test_output"]) + # Set the log level if config["rp_args"]["rp_log_level"]: log.set_level(config["rp_args"]["rp_log_level"]) diff --git a/runpod/serverless/modules/rp_fastapi.py b/runpod/serverless/modules/rp_fastapi.py index 1ba83f44..aa4b1a8f 100644 --- a/runpod/serverless/modules/rp_fastapi.py +++ b/runpod/serverless/modules/rp_fastapi.py @@ -9,7 +9,8 @@ from fastapi.encoders import jsonable_encoder from pydantic import BaseModel -from .rp_job import run_job +from .rp_handler import is_generator +from .rp_job import run_job, run_job_generator from .worker_state import Jobs from .rp_ping import Heartbeat from ...version import __version__ as runpod_version @@ -125,7 +126,13 @@ async def _debug_run(self, job: TestJob): # Set the current job ID. job_list.add_job(job.id) - job_results = await run_job(self.config["handler"], job.__dict__) + if is_generator(self.config["handler"]): + generator_output = run_job_generator(self.config["handler"], job.__dict__) + job_results = {"output": []} + async for stream_output in generator_output: + job_results["output"].append(stream_output["output"]) + else: + job_results = await run_job(self.config["handler"], job.__dict__) job_results["id"] = job.id diff --git a/runpod/serverless/modules/rp_handler.py b/runpod/serverless/modules/rp_handler.py new file mode 100644 index 00000000..40f62128 --- /dev/null +++ b/runpod/serverless/modules/rp_handler.py @@ -0,0 +1,8 @@ +"""Retrieve handler info. """ + +import inspect +from typing import Callable + +def is_generator(handler: Callable) -> bool: + """Check if handler is a generator function. """ + return inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler) diff --git a/runpod/serverless/modules/rp_local.py b/runpod/serverless/modules/rp_local.py index 5716531d..7db3c69c 100644 --- a/runpod/serverless/modules/rp_local.py +++ b/runpod/serverless/modules/rp_local.py @@ -47,5 +47,13 @@ async def run_local(config: Dict[str, Any]) -> None: log.info(f"Job {local_job['id']} completed successfully.") log.info(f"Job result: {job_result}") + # Compare to sample output, if provided + if config['rp_args'].get('test_output', None): + log.info("test_output set, comparing output to test_output.") + if job_result != config['rp_args']['test_output']: + log.error("Job output does not match test_output.") + sys.exit(1) + log.info("Job output matches test_output.") + log.info("Local testing complete, exiting.") sys.exit(0) diff --git a/runpod/serverless/modules/rp_ping.py b/runpod/serverless/modules/rp_ping.py index 5d1a6e96..9e26f6d8 100644 --- a/runpod/serverless/modules/rp_ping.py +++ b/runpod/serverless/modules/rp_ping.py @@ -52,6 +52,10 @@ def start_ping(self, test=False): ''' Sends heartbeat pings to the Runpod server. ''' + if os.environ.get('RUNPOD_AI_API_KEY') is None: + log.debug("Not deployed on RunPod serverless, pings will not be sent.") + return + if os.environ.get('RUNPOD_POD_ID') is None: log.info("Not running on RunPod, pings will not be sent.") return diff --git a/runpod/serverless/worker.py b/runpod/serverless/worker.py index 82d6883f..f8694863 100644 --- a/runpod/serverless/worker.py +++ b/runpod/serverless/worker.py @@ -4,7 +4,6 @@ """ import os import asyncio -import inspect from typing import Dict, Any import aiohttp @@ -12,6 +11,7 @@ from runpod.serverless.modules.rp_logger import RunPodLogger from runpod.serverless.modules.rp_scale import JobScaler from .modules import rp_local +from .modules.rp_handler import is_generator from .modules.rp_ping import Heartbeat from .modules.rp_job import run_job, run_job_generator from .modules.rp_http import send_result, stream_result @@ -46,11 +46,10 @@ def _is_local(config) -> bool: async def _process_job(job, session, job_scaler, config): - if inspect.isgeneratorfunction(config["handler"]) \ - or inspect.isasyncgenfunction(config["handler"]): + if is_generator(config["handler"]): generator_output = run_job_generator(config["handler"], job) - log.debug("Handler is a generator, streaming results.") + job_result = {'output': []} async for stream_output in generator_output: if 'error' in stream_output: diff --git a/tests/test_serverless/test_modules/test_fastapi.py b/tests/test_serverless/test_modules/test_fastapi.py index e85bf99b..d8dedc80 100644 --- a/tests/test_serverless/test_modules/test_fastapi.py +++ b/tests/test_serverless/test_modules/test_fastapi.py @@ -87,4 +87,16 @@ def test_run(self): self.assertTrue(mock_ping.called) + # Test with generator handler + def generator_handler(job): + del job + yield {"result": "success"} + + generator_worker_api = rp_fastapi.WorkerAPI(handler=generator_handler) + generator_run_return = asyncio.run(generator_worker_api._debug_run(job_object)) # pylint: disable=protected-access + assert generator_run_return == { + "id": "test_job_id", + "output": [{"result": "success"}] + } + loop.close() diff --git a/tests/test_serverless/test_modules/test_handler.py b/tests/test_serverless/test_modules/test_handler.py new file mode 100644 index 00000000..d055a337 --- /dev/null +++ b/tests/test_serverless/test_modules/test_handler.py @@ -0,0 +1,33 @@ +""" Unit tests for the handler module. +""" +import unittest + +from runpod.serverless.modules.rp_handler import is_generator + + +class TestIsGenerator(unittest.TestCase): + """Tests for the is_generator function.""" + + def test_regular_function(self): + """Test that a regular function is not a generator.""" + def regular_func(): + return "I'm a regular function!" + self.assertFalse(is_generator(regular_func)) + + def test_generator_function(self): + """Test that a generator function is a generator.""" + def generator_func(): + yield "I'm a generator function!" + self.assertTrue(is_generator(generator_func)) + + def test_async_function(self): + """Test that an async function is not a generator.""" + async def async_func(): + return "I'm an async function!" + self.assertFalse(is_generator(async_func)) + + def test_async_generator_function(self): + """Test that an async generator function is a generator.""" + async def async_gen_func(): + yield "I'm an async generator function!" + self.assertTrue(is_generator(async_gen_func)) diff --git a/tests/test_serverless/test_modules/test_local.py b/tests/test_serverless/test_modules/test_local.py index f242b4b5..cba0ffcf 100644 --- a/tests/test_serverless/test_modules/test_local.py +++ b/tests/test_serverless/test_modules/test_local.py @@ -9,7 +9,7 @@ class TestRunLocal(IsolatedAsyncioTestCase): ''' Tests for run_local function ''' - @patch("runpod.serverless.modules.rp_local.run_job", return_value={}) + @patch("runpod.serverless.modules.rp_local.run_job", return_value={"result": "success"}) @patch("builtins.open", new_callable=mock_open, read_data='{"input": "test"}') async def test_run_local_with_test_input(self, mock_file, mock_run): ''' @@ -21,12 +21,20 @@ async def test_run_local_with_test_input(self, mock_file, mock_run): "test_input": { "input": "test", "id": "test_id" + }, + "test_output": { + "result": "success" } } } with self.assertRaises(SystemExit) as sys_exit: await rp_local.run_local(config) - self.assertEqual(sys_exit.exception.code, 0) + self.assertEqual(sys_exit.exception.code, 0) + + config["rp_args"]["test_output"] = {"result": "fail"} + with self.assertRaises(SystemExit) as sys_exit: + await rp_local.run_local(config) + self.assertEqual(sys_exit.exception.code, 1) assert mock_file.called is False assert mock_run.called diff --git a/tests/test_serverless/test_modules/test_ping.py b/tests/test_serverless/test_modules/test_ping.py index 3e57560c..751d45ea 100644 --- a/tests/test_serverless/test_modules/test_ping.py +++ b/tests/test_serverless/test_modules/test_ping.py @@ -46,6 +46,13 @@ def test_start_ping(self, mock_get_return): ''' Tests that the start_ping function works correctly ''' + # No RUNPOD_AI_API_KEY case + with patch("threading.Thread.start") as mock_thread_start: + rp_ping.Heartbeat().start_ping(test=True) + assert mock_thread_start.call_count == 0 + + os.environ["RUNPOD_AI_API_KEY"] = "test_key" + # No RUNPOD_POD_ID case with patch("threading.Thread.start") as mock_thread_start: rp_ping.Heartbeat().start_ping(test=True) diff --git a/tests/test_serverless/test_worker.py b/tests/test_serverless/test_worker.py index a94a67e5..888022bf 100644 --- a/tests/test_serverless/test_worker.py +++ b/tests/test_serverless/test_worker.py @@ -75,7 +75,6 @@ def test_local_api(self): ''' Test local FastAPI setup. ''' - known_args = argparse.Namespace() known_args.rp_log_level = None known_args.rp_debugger = None @@ -126,6 +125,7 @@ def test_worker_bad_local(self): known_args.rp_api_concurrency = 1 known_args.rp_api_host = "localhost" known_args.test_input = '{"test": "test"}' + known_args.test_output = '{"test": "test"}' with patch("argparse.ArgumentParser.parse_known_args") as mock_parse_known_args, \ self.assertRaises(SystemExit):