Skip to content

Commit

Permalink
Include websocket tests by default
Browse files Browse the repository at this point in the history
Signed-off-by: Nijat K <[email protected]>
  • Loading branch information
NeejWeej committed Nov 26, 2024
1 parent d70dce8 commit bd24ec8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 73 deletions.
20 changes: 12 additions & 8 deletions cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ void WebsocketEndpointManager::shutdownEndpoint(const std::string& endpoint_id)
m_endpoints.erase(endpoint_it);
std::stringstream ss;
ss << "No more connections for endpoint={" << endpoint_id << "} Shutting down...";
m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::CLOSED, ss.str());
std::string msg = ss.str();
m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::CLOSED, msg);
}

void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id,
Expand Down Expand Up @@ -135,9 +136,10 @@ void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id,
// should only happen if persist is False
if ( !payload.empty() )
endpoint -> send(payload);

m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::ACTIVE,
"Connected successfully for endpoint={" + endpoint_id +"}");
std::stringstream ss;
ss << "Connected successfully for endpoint={" << endpoint_id << "}";
std::string msg = ss.str();
m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::ACTIVE, msg);
// We remove the caller id, if it was the only one, then we shut down the endpoint
if( !persist )
removeEndpointForCallerId(endpoint_id, is_consumer, validated_id);
Expand Down Expand Up @@ -170,8 +172,9 @@ void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id,
stored_endpoint -> setOnSendFail(
[ this, endpoint_id ]( const std::string& s ) {
std::stringstream ss;
ss << "Error: " << s << " for " << endpoint_id;
m_mgr -> pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, ss.str() );
ss << "Error: " << s << " for endpoint={" << endpoint_id << "}";
std::string msg = ss.str();
m_mgr -> pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, msg );
}
);
stored_endpoint -> run();
Expand Down Expand Up @@ -214,10 +217,11 @@ void WebsocketEndpointManager::handleEndpointFailure(const std::string& endpoint

std::stringstream ss;
ss << "Connection Failure for endpoint={" << endpoint_id << "} Due to: " << reason;
std::string msg = ss.str();
if ( status_type == ClientStatusType::CLOSED || status_type == ClientStatusType::ACTIVE )
m_mgr -> pushStatus(StatusLevel::INFO, status_type, ss.str());
m_mgr -> pushStatus(StatusLevel::INFO, status_type, msg);
else{
m_mgr -> pushStatus(StatusLevel::ERROR, status_type, ss.str());
m_mgr -> pushStatus(StatusLevel::ERROR, status_type, msg);
}
};

Expand Down
6 changes: 4 additions & 2 deletions csp/adapters/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@
from csp.lib import _websocketadapterimpl

from .dynamic_adapter_utils import AdapterInfo
from .websocket_types import ActionType, ConnectionRequest, WebsocketHeaderUpdate, WebsocketStatus # noqa
from .websocket_types import ActionType, ConnectionRequest, WebsocketHeaderUpdate, WebsocketStatus

# InternalConnectionRequest,
_ = (
ActionType,
BytesMessageProtoMapper,
DateTimeType,
JSONTextMessageMapper,
RawBytesMessageMapper,
RawTextMessageMapper,
WebsocketStatus,
)
T = TypeVar("T")

Expand Down Expand Up @@ -577,7 +579,7 @@ def update_headers(self, x: ts[List[WebsocketHeaderUpdate]]):

def status(self, push_mode=csp.PushMode.NON_COLLAPSING):
ts_type = Status
return status_adapter_def(self, ts_type, push_mode=push_mode)
return status_adapter_def(self, ts_type, push_mode)

def _create(self, engine, memo):
"""method needs to return the wrapped c++ adapter manager"""
Expand Down
125 changes: 62 additions & 63 deletions csp/tests/adapters/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,76 @@
import pytest
import pytz
import threading
import tornado.ioloop
import tornado.web
import tornado.websocket
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import List, Optional, Type

import csp
from csp import ts

if os.environ.get("CSP_TEST_WEBSOCKET"):
import tornado.ioloop
import tornado.web
import tornado.websocket

from csp.adapters.websocket import (
ActionType,
ConnectionRequest,
JSONTextMessageMapper,
RawTextMessageMapper,
Status,
WebsocketAdapterManager,
WebsocketHeaderUpdate,
WebsocketStatus,
)

class EchoWebsocketHandler(tornado.websocket.WebSocketHandler):
def on_message(self, msg):
# Carve-out to allow inspecting the headers
if msg == "header1":
msg = self.request.headers.get(msg, "")
elif not isinstance(msg, str) and msg.decode("utf-8") == "header1":
# Need this for bytes
msg = self.request.headers.get("header1", "")
return self.write_message(msg)

@contextmanager
def create_tornado_server(port: int):
"""Base context manager for creating a Tornado server in a thread"""
ready_event = threading.Event()
io_loop = None
app = None
io_thread = None

def run_io_loop():
nonlocal io_loop, app
io_loop = tornado.ioloop.IOLoop()
io_loop.make_current()
app = tornado.web.Application([(r"/", EchoWebsocketHandler)])
app.listen(port)
ready_event.set()
io_loop.start()

io_thread = threading.Thread(target=run_io_loop)
io_thread.start()
ready_event.wait()

try:
yield io_loop, app, io_thread
finally:
io_loop.add_callback(io_loop.stop)
if io_thread:
io_thread.join(timeout=5)
if io_thread.is_alive():
raise RuntimeError("IOLoop failed to stop")

@contextmanager
def tornado_server(port: int = 8001):
"""Simplified context manager that uses the base implementation"""
with create_tornado_server(port) as (_io_loop, _app, _io_thread):
yield
from csp.adapters.websocket import (
ActionType,
ConnectionRequest,
JSONTextMessageMapper,
RawTextMessageMapper,
Status,
WebsocketAdapterManager,
WebsocketHeaderUpdate,
WebsocketStatus,
)


class EchoWebsocketHandler(tornado.websocket.WebSocketHandler):
def on_message(self, msg):
# Carve-out to allow inspecting the headers
if msg == "header1":
msg = self.request.headers.get(msg, "")
elif not isinstance(msg, str) and msg.decode("utf-8") == "header1":
# Need this for bytes
msg = self.request.headers.get("header1", "")
return self.write_message(msg)


@contextmanager
def create_tornado_server(port: int):
"""Base context manager for creating a Tornado server in a thread"""
ready_event = threading.Event()
io_loop = None
app = None
io_thread = None

def run_io_loop():
nonlocal io_loop, app
io_loop = tornado.ioloop.IOLoop()
io_loop.make_current()
app = tornado.web.Application([(r"/", EchoWebsocketHandler)])
app.listen(port)
ready_event.set()
io_loop.start()

io_thread = threading.Thread(target=run_io_loop)
io_thread.start()
ready_event.wait()

try:
yield io_loop, app, io_thread
finally:
io_loop.add_callback(io_loop.stop)
if io_thread:
io_thread.join(timeout=5)
if io_thread.is_alive():
raise RuntimeError("IOLoop failed to stop")


@contextmanager
def tornado_server(port: int = 8001):
"""Simplified context manager that uses the base implementation"""
with create_tornado_server(port) as (_io_loop, _app, _io_thread):
yield


@pytest.mark.skipif(os.environ.get("CSP_TEST_WEBSOCKET") is None, reason="'CSP_TEST_WEBSOCKET' env variable is not set")
class TestWebsocket:
@pytest.fixture(scope="class", autouse=True)
def setup_tornado(self, request):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ develop = [
"sqlalchemy", # db
"threadpoolctl", # test_random
"tornado", # profiler, perspective, websocket
"python-rapidjson", # websocket
# type checking
"pydantic>=2",
]
Expand Down

0 comments on commit bd24ec8

Please sign in to comment.