Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configurable write timeout for websocket server #1554

Merged
merged 11 commits into from
Jul 15, 2021
5 changes: 3 additions & 2 deletions rskj-core/src/main/java/co/rsk/RskContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -1595,12 +1595,13 @@ private Web3WebSocketServer getWeb3WebSocketServer() {
new BlockchainBranchComparator(getBlockStore())
)
);
RskJsonRpcHandler jsonRpcHandler = new RskJsonRpcHandler(emitter, jsonRpcSerializer);
RskWebSocketJsonRpcHandler jsonRpcHandler = new RskWebSocketJsonRpcHandler(emitter, jsonRpcSerializer);
fedejinich marked this conversation as resolved.
Show resolved Hide resolved
web3WebSocketServer = new Web3WebSocketServer(
rskSystemProperties.rpcWebSocketBindAddress(),
rskSystemProperties.rpcWebSocketPort(),
jsonRpcHandler,
getJsonRpcWeb3ServerHandler()
getJsonRpcWeb3ServerHandler(),
rskSystemProperties.rpcWebSocketServerWriteTimeoutSeconds()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import co.rsk.rpc.modules.RskJsonRpcRequestVisitor;
import co.rsk.rpc.modules.eth.subscribe.EthSubscribeRequest;
import co.rsk.rpc.modules.eth.subscribe.EthUnsubscribeRequest;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandler.Sharable;
Expand All @@ -48,25 +49,25 @@
*/

@Sharable
public class RskJsonRpcHandler
public class RskWebSocketJsonRpcHandler
extends SimpleChannelInboundHandler<ByteBufHolder>
implements RskJsonRpcRequestVisitor {
private static final Logger LOGGER = LoggerFactory.getLogger(RskJsonRpcHandler.class);
private static final Logger LOGGER = LoggerFactory.getLogger(RskWebSocketJsonRpcHandler.class);

private final EthSubscriptionNotificationEmitter emitter;
private final JsonRpcSerializer serializer;

public RskJsonRpcHandler(EthSubscriptionNotificationEmitter emitter, JsonRpcSerializer serializer) {
public RskWebSocketJsonRpcHandler(EthSubscriptionNotificationEmitter emitter, JsonRpcSerializer serializer) {
this.emitter = emitter;
this.serializer = serializer;
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBufHolder msg) {
try {
RskJsonRpcRequest request = serializer.deserializeRequest(
new ByteBufInputStream(msg.copy().content())
);
ByteBuf content = msg.copy().content();

try (ByteBufInputStream source = new ByteBufInputStream(content)){
RskJsonRpcRequest request = serializer.deserializeRequest(source);

// TODO(mc) we should support the ModuleDescription method filters
JsonRpcResultOrError resultOrError = request.accept(this, ctx);
Expand All @@ -75,10 +76,13 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBufHolder msg) {
return;
} catch (IOException e) {
LOGGER.trace("Not a known or valid JsonRpcRequest", e);

// We need to release this resource, netty only takes care about 'ByteBufHolder msg'
content.release(content.refCnt());
}

// delegate to the next handler if the message can't be matched to a known JSON-RPC request
ctx.fireChannelRead(msg.retain());
ctx.fireChannelRead(msg);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package co.rsk.rpc.netty;

import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.timeout.WriteTimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RskWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(RskWebSocketServerProtocolHandler.class);
public static final String WRITE_TIMEOUT_REASON = "Exceeded write timout";
public static final int NORMAL_CLOSE_WEBSOCKET_STATUS = 1000;

public RskWebSocketServerProtocolHandler(String websocketPath) {
super(websocketPath);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if(cause instanceof WriteTimeoutException) {
ctx.writeAndFlush(new CloseWebSocketFrame(NORMAL_CLOSE_WEBSOCKET_STATUS, WRITE_TIMEOUT_REASON)).addListener(ChannelFutureListener.CLOSE);
LOGGER.error("Write timeout exceeded, closing web socket channel", cause);
} else {
super.exceptionCaught(ctx, cause);
}
}
}
22 changes: 14 additions & 8 deletions rskj-core/src/main/java/co/rsk/rpc/netty/Web3WebSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,40 @@
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

public class Web3WebSocketServer implements InternalService {
private static final Logger logger = LoggerFactory.getLogger(Web3WebSocketServer.class);
private static final int HTTP_MAX_CONTENT_LENGTH = 1024 * 1024 * 5;

private final InetAddress host;
private final int port;
private final RskJsonRpcHandler jsonRpcHandler;
private final RskWebSocketJsonRpcHandler webSocketJsonRpcHandler;
private final JsonRpcWeb3ServerHandler web3ServerHandler;
private final EventLoopGroup bossGroup;
private final EventLoopGroup workerGroup;
private @Nullable ChannelFuture webSocketChannel;
private final int serverWriteTimeoutSeconds;

public Web3WebSocketServer(
InetAddress host,
int port,
RskJsonRpcHandler jsonRpcHandler,
JsonRpcWeb3ServerHandler web3ServerHandler) {
RskWebSocketJsonRpcHandler webSocketJsonRpcHandler,
JsonRpcWeb3ServerHandler web3ServerHandler,
int serverWriteTimeoutSeconds) {
this.host = host;
this.port = port;
this.jsonRpcHandler = jsonRpcHandler;
this.webSocketJsonRpcHandler = webSocketJsonRpcHandler;
this.web3ServerHandler = web3ServerHandler;
this.bossGroup = new NioEventLoopGroup();
this.workerGroup = new NioEventLoopGroup();
this.serverWriteTimeoutSeconds = serverWriteTimeoutSeconds;
}

@Override
Expand All @@ -70,9 +75,10 @@ public void start() {
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new HttpServerCodec());
p.addLast(new HttpObjectAggregator(1024 * 1024 * 5));
p.addLast(new WebSocketServerProtocolHandler("/websocket"));
p.addLast(jsonRpcHandler);
p.addLast(new HttpObjectAggregator(HTTP_MAX_CONTENT_LENGTH));
p.addLast(new WriteTimeoutHandler(serverWriteTimeoutSeconds, TimeUnit.SECONDS));
p.addLast(new RskWebSocketServerProtocolHandler("/websocket"));
p.addLast(webSocketJsonRpcHandler);
p.addLast(web3ServerHandler);
p.addLast(new Web3ResultWebSocketResponseHandler());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public abstract class SystemProperties {
private static final String PROPERTY_RPC_WEBSOCKET_ENABLED = "rpc.providers.web.ws.enabled";
private static final String PROPERTY_RPC_WEBSOCKET_ADDRESS = "rpc.providers.web.ws.bind_address";
private static final String PROPERTY_RPC_WEBSOCKET_PORT = "rpc.providers.web.ws.port";
private static final String PROPERTY_RPC_WEBSOCKET_SERVER_WRITE_TIMEOUT_SECONDS = "rpc.providers.web.ws.server_write_timeout_seconds";

public static final String PROPERTY_PUBLIC_IP = "public.ip";
public static final String PROPERTY_BIND_ADDRESS = "bind_address";
Expand Down Expand Up @@ -612,6 +613,10 @@ public int rpcWebSocketPort() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_PORT);
}

public int rpcWebSocketServerWriteTimeoutSeconds() {
return configFromFiles.getInt(PROPERTY_RPC_WEBSOCKET_SERVER_WRITE_TIMEOUT_SECONDS);
}

public InetAddress rpcHttpBindAddress() {
return getWebBindAddress(PROPERTY_RPC_HTTP_ADDRESS);
}
Expand Down
1 change: 1 addition & 0 deletions rskj-core/src/main/resources/expected.conf
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ rpc = {
enabled = <enabled>
bind_address = <bind_address>
port = <port>
server_write_timeout_seconds = <timeout>
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions rskj-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ rpc {
enabled = false
bind_address = localhost
port = 4445
# Shuts down the server when it's not able to write a response after a certain period (expressed in seconds)
server_write_timeout_seconds = 30
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.*;

public class RskJsonRpcHandlerTest {
public class RskWebSocketJsonRpcHandlerTest {
private static final SubscriptionId SAMPLE_SUBSCRIPTION_ID = new SubscriptionId("0x3075");
private static final EthSubscribeRequest SAMPLE_SUBSCRIBE_REQUEST = new EthSubscribeRequest(
JsonRpcVersion.V2_0,
Expand All @@ -47,15 +47,15 @@ public class RskJsonRpcHandlerTest {

);

private RskJsonRpcHandler handler;
private RskWebSocketJsonRpcHandler handler;
private EthSubscriptionNotificationEmitter emitter;
private JsonRpcSerializer serializer;

@Before
public void setUp() {
emitter = mock(EthSubscriptionNotificationEmitter.class);
serializer = mock(JsonRpcSerializer.class);
handler = new RskJsonRpcHandler(emitter, serializer);
handler = new RskWebSocketJsonRpcHandler(emitter, serializer);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package co.rsk.rpc.netty;

import co.rsk.config.TestSystemProperties;
import co.rsk.rpc.JacksonBasedRpcSerializer;
import co.rsk.rpc.ModuleDescription;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -46,14 +47,14 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

public class Web3WebSocketServerTest {

private static JsonNodeFactory JSON_NODE_FACTORY = JsonNodeFactory.instance;
private static ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final int DEFAULT_WRITE_TIMEOUT_SECONDS = 30;

private ExecutorService wsExecutor;

Expand All @@ -68,13 +69,24 @@ public void smokeTest() throws Exception {
String mockResult = "output";
when(web3Mock.web3_sha3(anyString())).thenReturn(mockResult);

int randomPort = 9998;//new ServerSocket(0).getLocalPort();
int randomPort = 9998;

fedejinich marked this conversation as resolved.
Show resolved Hide resolved
TestSystemProperties testSystemProperties = new TestSystemProperties();

List<ModuleDescription> filteredModules = Collections.singletonList(new ModuleDescription("web3", "1.0", true, Collections.emptyList(), Collections.emptyList()));
RskJsonRpcHandler handler = new RskJsonRpcHandler(null, new JacksonBasedRpcSerializer());
RskWebSocketJsonRpcHandler handler = new RskWebSocketJsonRpcHandler(null, new JacksonBasedRpcSerializer());
JsonRpcWeb3ServerHandler serverHandler = new JsonRpcWeb3ServerHandler(web3Mock, filteredModules);
int serverWriteTimeoutSeconds = testSystemProperties.rpcWebSocketServerWriteTimeoutSeconds();

assertEquals(DEFAULT_WRITE_TIMEOUT_SECONDS, serverWriteTimeoutSeconds);

Web3WebSocketServer websocketServer = new Web3WebSocketServer(InetAddress.getLoopbackAddress(), randomPort, handler, serverHandler);
Web3WebSocketServer websocketServer = new Web3WebSocketServer(
InetAddress.getLoopbackAddress(),
randomPort,
handler,
serverHandler,
serverWriteTimeoutSeconds
);
websocketServer.start();

OkHttpClient wsClient = new OkHttpClient();
Expand Down