Skip to content

Commit 9ae7886

Browse files
streaming implementation working
1 parent a033b08 commit 9ae7886

File tree

2 files changed

+60
-22
lines changed

2 files changed

+60
-22
lines changed

src/rb-api/rb/api/routes/cli.py

+44-22
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,92 @@
11
import inspect
2-
from typing import Callable, Optional
2+
from typing import Callable, Generator, Optional
33

44
import typer
55
from fastapi import APIRouter, HTTPException
6+
from fastapi.responses import StreamingResponse
67
from makefun import with_signature
78
from rb.api.models import CommandResult
89
from rb.lib.stdout import Capturing # type: ignore
10+
from rb.lib.stdout import capture_stdout_as_generator
911

1012
from rescuebox.main import app as rescuebox_app
1113

1214
cli_router = APIRouter()
1315

1416

1517
def static_endpoint(callback: Callable, *args, **kwargs) -> CommandResult:
18+
"""Execute a CLI command and return the result synchronously"""
1619
with Capturing() as stdout:
1720
try:
1821
result = callback(*args, **kwargs)
19-
success = True
20-
error = None
22+
return CommandResult(result=result, stdout=stdout, success=True, error=None)
2123
except Exception as e:
22-
result = None
23-
success = False
24-
error = f"Typer CLI aborted {e}"
25-
return CommandResult(result=result, stdout=stdout, success=success, error=error)
24+
return CommandResult(
25+
result=None,
26+
stdout=stdout,
27+
success=False,
28+
error=f"Typer CLI aborted {e}",
29+
)
30+
31+
32+
def streaming_endpoint(callback: Callable, *args, **kwargs) -> Generator:
33+
"""Execute a CLI command and stream the results"""
34+
line_buffer = []
35+
for line in capture_stdout_as_generator(callback, *args, **kwargs):
36+
line_buffer.append(line)
37+
result = CommandResult(
38+
result=line, stdout=line_buffer, success=True, error=None
39+
)
40+
yield result.model_dump_json()
2641

2742

2843
def command_callback(command: typer.models.CommandInfo):
44+
"""Create a FastAPI endpoint handler for a Typer CLI command"""
2945
# Get the original callback signature
3046
original_signature = inspect.signature(command.callback)
3147

32-
# Modify the signature to include `streaming`
48+
# Add streaming parameter to signature
3349
new_params = list(original_signature.parameters.values())
34-
new_params.append(
35-
inspect.Parameter(
36-
"streaming",
37-
inspect.Parameter.KEYWORD_ONLY,
38-
default=False,
39-
annotation=Optional[bool],
40-
)
50+
streaming_param = inspect.Parameter(
51+
"streaming",
52+
inspect.Parameter.KEYWORD_ONLY,
53+
default=False,
54+
annotation=Optional[bool],
4155
)
56+
new_params.append(streaming_param)
4257
new_signature = original_signature.replace(parameters=new_params)
4358

44-
# Create a new function with the modified signature
4559
@with_signature(new_signature)
4660
def wrapper(*args, **kwargs):
47-
# Extract additional parameters
48-
# TODO(Jagath): Implement streaming
49-
streaming = kwargs.pop("streaming", False) # noqa: F841
61+
streaming = kwargs.pop("streaming", False)
62+
63+
if streaming:
64+
return StreamingResponse(
65+
streaming_endpoint(command.callback, *args, **kwargs)
66+
)
5067

51-
# Call the static endpoint with the wrapped callback and arguments
5268
result = static_endpoint(command.callback, *args, **kwargs)
5369
if not result.success:
54-
# Return the last 10 lines of stdout if there's an error
5570
raise HTTPException(
5671
status_code=400,
57-
detail={"error": result.error, "stdout": result.stdout[-10:]},
72+
detail={
73+
"error": result.error,
74+
"stdout": result.stdout[-10:], # Last 10 lines of output
75+
},
5876
)
5977
return result
6078

79+
# Preserve original function metadata
6180
wrapper.__name__ = command.callback.__name__
6281
wrapper.__doc__ = command.callback.__doc__
6382

6483
return wrapper
6584

6685

86+
# Register routes for each plugin command
6787
for plugin in rescuebox_app.registered_groups:
6888
router = APIRouter()
89+
6990
for command in plugin.typer_instance.registered_commands:
7091
router.add_api_route(
7192
f"/{command.callback.__name__}/",
@@ -74,4 +95,5 @@ def wrapper(*args, **kwargs):
7495
name=command.callback.__name__,
7596
response_model=CommandResult,
7697
)
98+
7799
cli_router.include_router(router, prefix=f"/{plugin.name}", tags=[plugin.name])

src/rb-lib/rb/lib/stdout.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import contextlib
2+
import io
13
import sys
24
from io import StringIO
5+
from typing import Callable, Generator
36

47

58
class Capturing(list):
@@ -12,3 +15,16 @@ def __exit__(self, *args):
1215
self.extend(self._stringio.getvalue().splitlines())
1316
del self._stringio # free up some memory
1417
sys.stdout = self._stdout
18+
19+
20+
def capture_stdout_as_generator(
21+
func: Callable, *args, **kwargs
22+
) -> Generator[str, None, None]:
23+
stdout_buffer = io.StringIO()
24+
with contextlib.redirect_stdout(stdout_buffer):
25+
func(*args, **kwargs)
26+
27+
# Go to the start of the buffer and yield each line
28+
stdout_buffer.seek(0)
29+
for line in stdout_buffer:
30+
yield line

0 commit comments

Comments
 (0)