Skip to content

Commit

Permalink
Live terminal output (#5396)
Browse files Browse the repository at this point in the history
* Add /logs/raw and /logs/subscribe for getting logs on frontend
Hijacks stderr/stdout to send all output data to the client on flush

* Use existing send sync method

* Fix get_logs should return string

* Fix bug

* pass no server

* fix tests

* Fix output flush on linux
  • Loading branch information
pythongosssss authored Nov 9, 2024
1 parent dd5b57e commit 6ee066a
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 17 deletions.
29 changes: 27 additions & 2 deletions api_server/routes/internal/internal_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger

class InternalRoutes:
Expand All @@ -11,14 +12,17 @@ class InternalRoutes:
Check README.md for more information.
'''
def __init__(self):

def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)

def setup_routes(self):
@self.routes.get('/files')
Expand All @@ -34,7 +38,28 @@ async def list_files(request):

@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))

@self.routes.get('/logs/raw')
async def get_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
})

@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)

return web.Response(status=200)


@self.routes.get('/folder_paths')
async def get_folder_paths(request):
Expand Down
47 changes: 47 additions & 0 deletions api_server/services/terminal_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from app.logger import on_flush
import os


class TerminalService:
def __init__(self, server):
self.server = server
self.cols = None
self.rows = None
self.subscriptions = set()
on_flush(self.send_messages)

def update_size(self):
sz = os.get_terminal_size()
changed = False
if sz.columns != self.cols:
self.cols = sz.columns
changed = True

if sz.lines != self.rows:
self.rows = sz.lines
changed = True

if changed:
return {"cols": self.cols, "rows": self.rows}

return None

def subscribe(self, client_id):
self.subscriptions.add(client_id)

def unsubscribe(self, client_id):
self.subscriptions.discard(client_id)

def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return

new_size = self.update_size()

for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected
self.unsubscribe(client_id)
continue

self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
64 changes: 53 additions & 11 deletions app/logger.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,73 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading

logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stdout_interceptor = None
stderr_interceptor = None


class LogInterceptor(io.TextIOWrapper):
def __init__(self, stream, *args, **kwargs):
buffer = stream.buffer
encoding = stream.encoding
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
self._lock = threading.Lock()
self._flush_callbacks = []
self._logs_since_flush = []

def write(self, data):
entry = {"t": datetime.now().isoformat(), "m": data}
with self._lock:
self._logs_since_flush.append(entry)

# Simple handling for cr to overwrite the last output if it isnt a full line
# else logs just get full of progress messages
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
logs.pop()
logs.append(entry)
super().write(data)

def flush(self):
super().flush()
for cb in self._flush_callbacks:
cb(self._logs_since_flush)
self._logs_since_flush = []

def on_flush(self, callback):
self._flush_callbacks.append(callback)


def get_logs():
return "\n".join([formatter.format(x) for x in logs])
return logs


def on_flush(callback):
if stdout_interceptor is not None:
stdout_interceptor.on_flush(callback)
if stderr_interceptor is not None:
stderr_interceptor.on_flush(callback)

def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return

# Override output streams and log to buffer
logs = deque(maxlen=capacity)

global stdout_interceptor
global stderr_interceptor
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)

# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)
2 changes: 1 addition & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(self, loop):
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
Expand Down
6 changes: 3 additions & 3 deletions tests-unit/server/routes/internal_routes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@pytest.fixture
def internal_routes():
return InternalRoutes()
return InternalRoutes(None)

@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
Expand Down Expand Up @@ -102,7 +102,7 @@ async def test_file_service_initialization():
# Create a mock instance
mock_file_service_instance = MagicMock(spec=FileService)
MockFileService.return_value = mock_file_service_instance
internal_routes = InternalRoutes()
internal_routes = InternalRoutes(None)

# Check if FileService was initialized with the correct parameters
MockFileService.assert_called_once_with({
Expand All @@ -112,4 +112,4 @@ async def test_file_service_initialization():
})

# Verify that the file_service attribute of InternalRoutes is set
assert internal_routes.file_service == mock_file_service_instance
assert internal_routes.file_service == mock_file_service_instance

0 comments on commit 6ee066a

Please sign in to comment.