diff --git a/README.md b/README.md
index e8213851..4e8ff142 100644
--- a/README.md
+++ b/README.md
@@ -180,7 +180,7 @@ LitServe supports multiple advanced state-of-the-art features.
 | Automatic schema validation | ✅ |
 | Handle timeouts | ✅ |
 | Handle disconnects | ✅ |
-| Streaming | in progress... |
+| Streaming | ✅ |
 
 > [!NOTE]
 > Our goal is not to jump on every hype train, but instead support features that scale
diff --git a/src/litserve/api.py b/src/litserve/api.py
index 3f41898e..0ae7386e 100644
--- a/src/litserve/api.py
+++ b/src/litserve/api.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from abc import ABC, abstractmethod
 
 
-def no_batch_unbatch_message(obj, data):
+def no_batch_unbatch_message_no_stream(obj, data):
     return f"""
     You set `max_batch_size > 1`, but the default implementation for batch() and unbatch() only supports
     PyTorch tensors or NumPy ndarrays, while we found {type(data)}.
@@ -30,7 +31,26 @@ def unbatch(self, output):
     """
 
 
+def no_batch_unbatch_message_stream(obj, data):
+    return f"""
+    You set `max_batch_size > 1`, but the default implementation for batch() and unbatch() only supports
+    PyTorch tensors or NumPy ndarrays, while we found {type(data)}.
+    Please implement these two methods in {obj.__class__.__name__}.
+
+    Example:
+
+    def batch(self, inputs):
+        return np.stack(inputs)
+
+    def unbatch(self, output):
+        for out in output:
+            yield list(out)
+    """
+
+
 class LitAPI(ABC):
+    _stream: bool = False
+
     @abstractmethod
     def setup(self, devices):
         """Setup the model so it can be called in `predict`."""
@@ -53,7 +73,12 @@ def batch(self, inputs):
             import numpy
 
             return numpy.stack(inputs)
-        raise NotImplementedError(no_batch_unbatch_message(self, inputs))
+
+        if self.stream:
+            message = no_batch_unbatch_message_stream(self, inputs)
+        else:
+            message = no_batch_unbatch_message_no_stream(self, inputs)
+        raise NotImplementedError(message)
 
     @abstractmethod
     def predict(self, x):
@@ -63,8 +88,16 @@ def predict(self, x):
     def unbatch(self, output):
         """Convert a batched output to a list of outputs."""
         if hasattr(output, "__torch_function__") or output.__class__.__name__ == "ndarray":
-            return list(output)
-        raise NotImplementedError(no_batch_unbatch_message(self, output))
+            if self._stream:
+                yield from list(output)
+            else:
+                return list(output)
+
+        if self.stream:
+            message = no_batch_unbatch_message_stream(self, output)
+        else:
+            message = no_batch_unbatch_message_no_stream(self, output)
+        raise NotImplementedError(message)
 
     @abstractmethod
     def encode_response(self, output):
@@ -74,3 +107,66 @@ def encode_response(self, output):
 
         """
         pass
+
+    @property
+    def stream(self):
+        return self._stream
+
+    @stream.setter
+    def stream(self, value):
+        self._stream = value
+
+    def sanitize(self, max_batch_size: int):
+        if (
+            self.stream
+            and max_batch_size > 1
+            and not all([
+                inspect.isgeneratorfunction(self.predict),
+                inspect.isgeneratorfunction(self.encode_response),
+                inspect.isgeneratorfunction(self.unbatch),
+            ])
+        ):
+            raise ValueError(
+                """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
+                `lit_api.unbatch` must generate values using `yield`.
+
+                Example:
+
+                def predict(self, inputs):
+                    ...
+                    for i in range(max_token_length):
+                        yield prediction
+
+                def encode_response(self, outputs):
+                    for output in outputs:
+                        encoded_output = ...
+                        yield encoded_output
+
+                def unbatch(self, outputs):
+                    for output in outputs:
+                        unbatched_output = ...
+                        yield unbatched_output
+                """
+            )
+
+        if self.stream and not all([
+            inspect.isgeneratorfunction(self.predict),
+            inspect.isgeneratorfunction(self.encode_response),
+        ]):
+            raise ValueError(
+                """When `stream=True` both `lit_api.predict` and
+                `lit_api.encode_response` must generate values using `yield`.
+
+                Example:
+
+                def predict(self, inputs):
+                    ...
+                    for i in range(max_token_length):
+                        yield prediction
+
+                def encode_response(self, outputs):
+                    for output in outputs:
+                        encoded_output = ...
+                        yield encoded_output
+                """
+            )
diff --git a/src/litserve/server.py b/src/litserve/server.py
index a9eb95e9..c54dbd69 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -153,11 +153,50 @@ def run_streaming_loop(lit_api, request_queue: Queue, request_buffer):
             pipe_s.send((pickle.dumps(e), LitAPIStatus.ERROR))
 
 
+def run_batched_streaming_loop(lit_api, request_queue: Queue, request_buffer, max_batch_size, batch_timeout):
+    while True:
+        batches = collate_requests(
+            lit_api,
+            request_queue,
+            request_buffer,
+            max_batch_size,
+            batch_timeout,
+        )
+        if not batches:
+            continue
+
+        inputs, pipes = zip(*batches)
+
+        try:
+            x = lit_api.batch(inputs)
+            y_iter = lit_api.predict(x)
+            unbatched_iter = lit_api.unbatch(y_iter)
+            y_enc_iter = lit_api.encode_response(unbatched_iter)
+
+            # y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
+            for y_batch in y_enc_iter:
+                for y_enc, pipe_s in zip(y_batch, pipes):
+                    with contextlib.suppress(BrokenPipeError):
+                        pipe_s.send((y_enc, LitAPIStatus.OK))
+
+            for pipe_s in pipes:
+                pipe_s.send(("", LitAPIStatus.FINISH_STREAMING))
+        except Exception as e:
+            logging.exception(e)
+            err = pickle.dumps(e)
+            for pipe_s in pipes:
+                pipe_s.send((err, LitAPIStatus.ERROR))
+
+
 def inference_worker(lit_api, device, worker_id, request_queue, request_buffer, max_batch_size, batch_timeout, stream):
     lit_api.setup(device=device)
     if stream:
-        run_streaming_loop(lit_api, request_queue, request_buffer)
+        if max_batch_size > 1:
+            run_batched_streaming_loop(lit_api, request_queue, request_buffer, max_batch_size, batch_timeout)
+        else:
+            run_streaming_loop(lit_api, request_queue, request_buffer)
         return
+
     if max_batch_size > 1:
         run_batched_loop(lit_api, request_queue, request_buffer, max_batch_size, batch_timeout)
     else:
@@ -227,7 +266,7 @@ async def lifespan(app: FastAPI):
 
 
 class LitServer:
-    # TODO: add support for accelerator="auto", devices="auto"
+    # TODO: add support for devices="auto"
     def __init__(
         self,
         lit_api: LitAPI,
@@ -244,31 +283,8 @@ def __init__(
         if max_batch_size <= 0:
             raise ValueError("max_batch_size must be greater than 0")
 
-        if stream and max_batch_size > 1:
-            raise ValueError("streaming is not supported with automatic batching at this time.")
-
-        if stream and not all([
-            inspect.isgeneratorfunction(lit_api.predict),
-            inspect.isgeneratorfunction(lit_api.encode_response),
-        ]):
-            raise ValueError(
-                """When `stream=True` both `lit_api.predict` and
-                `lit_api.encode_response` must generate values using `yield`.
-
-                Example:
-
-                def predict(self, inputs):
-                    ...
-                    for i in range(max_token_length):
-                        yield prediction
-
-                def encode_response(self, outputs):
-                    for output in outputs:
-                        encoded_output = ...
-                        yield encoded_output
-                """
-            )
-
+        lit_api.stream = stream
+        lit_api.sanitize(max_batch_size)
         self.app = FastAPI(lifespan=lifespan)
         self.app.lit_api = lit_api
         self.app.workers_per_device = workers_per_device
diff --git a/tests/conftest.py b/tests/conftest.py
index 7713d556..20aa31ee 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -53,6 +53,34 @@ def encode_response(self, output: Generator) -> Generator:
         yield out.lower()
 
 
+class SimpleBatchedStreamAPI(LitAPI):
+    def setup(self, device) -> None:
+        self.sentence = "LitServe is streaming output"
+
+    def decode_request(self, request: Request) -> str:
+        return request["prompt"]
+
+    def batch(self, inputs):
+        return inputs
+
+    def predict(self, x) -> Generator:
+        n = len(x)
+        output = self.sentence.split()
+        responses = [x]
+        for out in output:
+            responses.append([out] * n)
+        yield from responses
+
+    def encode_response(self, output: Generator) -> Generator:
+        delay = 0.01  # delay for testing timeouts
+        for out in output:
+            time.sleep(delay)
+            yield [e.lower() for e in out]
+
+    def unbatch(self, output):
+        yield from output
+
+
 @pytest.fixture()
 def simple_litapi():
     return SimpleLitAPI()
@@ -63,6 +91,11 @@ def simple_stream_api():
     return SimpleStreamAPI()
 
 
+@pytest.fixture()
+def simple_batched_stream_api():
+    return SimpleBatchedStreamAPI()
+
+
 @pytest.fixture()
 def lit_server(simple_litapi):
     return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py
index 727b12e9..b8cee76d 100644
--- a/tests/test_lit_server.py
+++ b/tests/test_lit_server.py
@@ -28,7 +28,13 @@
 from unittest.mock import patch, MagicMock
 
 from litserve.connector import _Connector
-from litserve.server import inference_worker, run_single_loop, run_streaming_loop
+from litserve.server import (
+    inference_worker,
+    run_single_loop,
+    run_streaming_loop,
+    LitAPIStatus,
+    run_batched_streaming_loop,
+)
 from litserve.server import LitServer
 
 import pytest
@@ -146,6 +152,22 @@ async def test_stream(simple_stream_api):
     assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match."
 
 
+@pytest.mark.asyncio()
+async def test_batched_stream_server(simple_batched_stream_api):
+    server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
+    expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "")
+    expected_output2 = "World LitServe is streaming output".lower().replace(" ", "")
+
+    async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
+        resp1 = ac.post("/stream-predict", json={"prompt": "Hello"}, timeout=10)
+        resp2 = ac.post("/stream-predict", json={"prompt": "World"}, timeout=10)
+        resp1, resp2 = await asyncio.gather(resp1, resp2)
+        assert resp1.status_code == 200, "Check if server is running and the request format is valid."
+        assert resp2.status_code == 200, "Check if server is running and the request format is valid."
+        assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match."
+        assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match."
+
+
 class FakeStreamPipe:
     def __init__(self, num_streamed_outputs):
         self.num_streamed_outputs = num_streamed_outputs
@@ -187,6 +209,58 @@ def fake_encode(output):
     fake_stream_api.encode_response.assert_called_once()
 
 
+class FakeBatchedStreamPipe:
+    def __init__(self, num_streamed_outputs):
+        self.num_streamed_outputs = num_streamed_outputs
+        self.count = 0
+
+    def send(self, args):
+        response, status = args
+        if status == LitAPIStatus.FINISH_STREAMING:
+            raise StopIteration("interrupt iteration")
+        if status == LitAPIStatus.ERROR and b"interrupt iteration" in response:
+            assert self.count == self.num_streamed_outputs, (
+                f"Loop count must have incremented for " f"{self.num_streamed_outputs} times."
+            )
+            raise StopIteration("finish streaming")
+
+        assert (
+            response == f"{self.count}"
+        ), f"streaming loop generates number from 0 to 9 which is sent via Pipe. {args}"
+        self.count += 1
+
+
+def test_batched_streaming_loop(loop_args):
+    num_streamed_outputs = 10
+
+    def fake_predict(inputs: list):
+        n = len(inputs)
+        for i in range(num_streamed_outputs):
+            yield [{"output": f"{i}"}] * n
+
+    def fake_encode(output_iter):
+        assert inspect.isgenerator(output_iter), "predict function must be a generator when `stream=True`"
+        for outputs in output_iter:
+            yield [output["output"] for output in outputs]
+
+    fake_stream_api = MagicMock()
+    fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
+    fake_stream_api.batch = MagicMock(side_effect=lambda inputs: inputs)
+    fake_stream_api.predict = MagicMock(side_effect=fake_predict)
+    fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
+    fake_stream_api.unbatch = MagicMock(side_effect=lambda inputs: inputs)
+
+    _, requests_queue, request_buffer = loop_args
+    request_buffer = Manager().dict()
+    request_buffer[1] = {"prompt": "Hello"}, FakeBatchedStreamPipe(num_streamed_outputs)
+    request_buffer[2] = {"prompt": "World"}, FakeBatchedStreamPipe(num_streamed_outputs)
+
+    with pytest.raises(StopIteration, match="finish streaming"):
+        run_batched_streaming_loop(fake_stream_api, requests_queue, request_buffer, max_batch_size=2, batch_timeout=2)
+    fake_stream_api.predict.assert_called_once_with(("Hello", "World"))
+    fake_stream_api.encode_response.assert_called_once()
+
+
 def test_litapi_with_stream(simple_litapi):
     with pytest.raises(
         ValueError,
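For reference, a minimal usage sketch of the batched streaming path introduced by this change. The `EchoStreamAPI` class below is hypothetical (not part of the diff), modeled on `SimpleBatchedStreamAPI` in `tests/conftest.py`, and it assumes the public `litserve` entry points (`ls.LitAPI`, `ls.LitServer`, `LitServer.run`). With `stream=True` and `max_batch_size > 1`, requests to `/stream-predict` are batched together and streamed back to each client.

```python
# Usage sketch (assumption: names follow the patterns shown in tests/conftest.py).
import litserve as ls


class EchoStreamAPI(ls.LitAPI):
    def setup(self, device):
        self.sentence = "LitServe is streaming output"

    def decode_request(self, request):
        return request["prompt"]

    def batch(self, inputs):
        # Keep the batch as a plain list of prompts.
        return inputs

    def predict(self, x):
        # Yield one token per step for every request in the batch.
        n = len(x)
        for token in self.sentence.split():
            yield [token] * n

    def unbatch(self, output):
        # Pass each batched step through; each item in a step maps to one client.
        yield from output

    def encode_response(self, output):
        # Lower-case every token before it is streamed back.
        for step in output:
            yield [token.lower() for token in step]


if __name__ == "__main__":
    # With stream=True and max_batch_size > 1, the inference worker uses
    # run_batched_streaming_loop instead of run_streaming_loop.
    server = ls.LitServer(EchoStreamAPI(), stream=True, max_batch_size=4, batch_timeout=0.5)
    server.run(port=8000)
```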