Skip to content

Commit

Permalink
feat: Add CSRF token to terminal interpreter for CSWSH
Browse files Browse the repository at this point in the history
Apply synchronizer token pattern to terminal interpreter for CSWSH protection
  • Loading branch information
seung-00 committed Nov 3, 2024
1 parent 3575a3c commit 3a03a12
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,12 @@ public void createTerminalDashboard(String noteId, String paragraphId, String ho
Jinjava jinjava = new Jinjava();
HashMap<String, Object> jinjaParams = new HashMap();
Date now = new Date();
String terminalServerUrl = generateOrigin(hostIp, port) +
"?noteId=" + noteId + "&paragraphId=" + paragraphId + "&t=" + now.getTime();
String terminalServerUrl = generateOrigin(hostIp, port) + "/terminal-ui" +
"?noteId=" + noteId + "&paragraphId=" + paragraphId + "&t=" + now.getTime();
jinjaParams.put("HOST_NAME", hostName);
jinjaParams.put("HOST_IP", hostIp);
jinjaParams.put("TERMINAL_SERVER_URL", terminalServerUrl);

String terminalDashboardTemplate = jinjava.render(template, jinjaParams);

LOGGER.info(terminalDashboardTemplate);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.apache.zeppelin.shell.terminal;

import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.zeppelin.shell.terminal.websocket.TerminalSocket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TerminalCsrfTokenManager {
private static TerminalCsrfTokenManager instance;

private static final Logger LOGGER = LoggerFactory.getLogger(TerminalSocket.class);

private final Map<String, String> csrfTokens = new ConcurrentHashMap<>();


public static synchronized TerminalCsrfTokenManager getInstance(){
if (instance == null) {
instance = new TerminalCsrfTokenManager();
}
return instance;
}

public String generateToken(String noteId, String paragraphId) {
String key = formatId(noteId, paragraphId);
return csrfTokens.computeIfAbsent(key, k -> UUID.randomUUID().toString());
}

public boolean validateToken(String noteId, String paragraphId, String token) {
if (token == null) {
LOGGER.warn("Received null CSRF token for validation");
return false;
}

String storedToken = csrfTokens.get(formatId(noteId, paragraphId));
return Objects.equals(storedToken, token);
}

private String formatId(String noteId, String paragraphId) {
return noteId + "@" + paragraphId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public class TerminalManager {
// NoteId@ParagraphId -> InterpreterContext
private HashMap<String, InterpreterContext> noteParagraphId2IntpContext;

private String csrfToken;

private static TerminalManager instance;

public static synchronized TerminalManager getInstance(){
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.apache.zeppelin.shell.terminal;

import com.google.common.io.Resources;
import com.hubspot.jinjava.Jinjava;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TerminalServlet extends HttpServlet {

private Jinjava jinjava = new Jinjava();

@Override
protected void doGet(
HttpServletRequest request,
HttpServletResponse response
)
throws ServletException, IOException {
URL urlTemplate = Resources.getResource("ui_templates/terminal-ui.jinja");
String template = Resources.toString(urlTemplate, StandardCharsets.UTF_8);

String noteId = request.getParameter("noteId");
String paragraphId = request.getParameter("paragraphId");

String csrfToken = TerminalCsrfTokenManager.getInstance().generateToken(noteId, paragraphId);

Map<String, Object> context = new HashMap<>();
context.put("CSRF_TOKEN", csrfToken);

String renderedTemplate = jinjava.render(template, context);

response.setContentType("text/html; charset=UTF-8");
response.setStatus(HttpServletResponse.SC_OK);
response.getWriter().write(renderedTemplate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ public void run() {
connector.setPort(port);
jettyServer.addConnector(connector);

ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS);
context.setContextPath("/terminal/");
ServletContextHandler terminalSocketContext = new ServletContextHandler(ServletContextHandler.SESSIONS);
terminalSocketContext.setContextPath("/terminal/");

// We look for a file, as ClassLoader.getResource() is not
// designed to look for directories (we resolve the directory later)
ClassLoader clazz = TerminalThread.class.getClassLoader();
URL url = clazz.getResource("html");
URL url = clazz.getResource("static");
if (url == null) {
throw new RuntimeException("Unable to find resource directory");
}
Expand All @@ -68,14 +68,17 @@ public void run() {
String webRootUri = url.toExternalForm();
LOGGER.info("WebRoot is " + webRootUri);
// debug
// webRootUri = "/home/hadoop/zeppelin-current/interpreter/sh";
resourceHandler.setResourceBase(webRootUri);

HandlerCollection handlers = new HandlerCollection(context, resourceHandler);
ServletContextHandler terminalWebContext = new ServletContextHandler(ServletContextHandler.SESSIONS);
terminalWebContext.setContextPath("/terminal-ui/");
terminalWebContext.addServlet(TerminalServlet.class, "/");

HandlerCollection handlers = new HandlerCollection(terminalSocketContext, terminalWebContext, resourceHandler);
jettyServer.setHandler(handlers);

try {
ServerContainer container = WebSocketServerContainerInitializer.configureContext(context);
ServerContainer container = WebSocketServerContainerInitializer.configureContext(terminalSocketContext);
container.addEndpoint(
ServerEndpointConfig.Builder.create(TerminalSocket.class, "/")
.configurator(new TerminalSessionConfigurator(allwedOrigin))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import org.apache.zeppelin.shell.terminal.TerminalCsrfTokenManager;
import org.apache.zeppelin.shell.terminal.TerminalManager;
import org.apache.zeppelin.shell.terminal.service.TerminalService;
import org.slf4j.Logger;
Expand Down Expand Up @@ -54,7 +56,7 @@ public void onWebSocketConnect(Session sess) {
}

@OnMessage
public void onWebSocketText(String message) {
public void onWebSocketText(String message, Session sess) throws IOException {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Received TEXT message: " + message);
}
Expand All @@ -74,6 +76,12 @@ public void onWebSocketText(String message) {
terminalService.onTerminalReady();
this.noteId = messageMap.get("noteId");
this.paragraphId = messageMap.get("paragraphId");
String csrfToken = messageMap.get("csrfToken");
if (!isValidCsrfToken(csrfToken)) {
LOGGER.error("Invalid CSRF token: " + csrfToken);
sess.close(new CloseReason(CloseReason.CloseCodes.CANNOT_ACCEPT, "Invalid CSRF Token"));
return;
}
terminalManager.onWebSocketConnect(noteId, paragraphId);
break;
case "TERMINAL_COMMAND":
Expand Down Expand Up @@ -108,4 +116,8 @@ private Map<String, String> getMessageMap(String message) {
new TypeToken<Map<String, String>>(){}.getType());
return map;
}

private boolean isValidCsrfToken(String csrfToken) {
return TerminalCsrfTokenManager.getInstance().validateToken(noteId, paragraphId, csrfToken);
}
}
32 changes: 0 additions & 32 deletions shell/src/main/resources/html/index.html

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ let app = {
},
onTerminalReady() {
// alert("TERMINAL_READY");
const csrfToken = document.querySelector('meta[name="csrf-token"]').content;

ws.send(action("TERMINAL_READY", {
noteId, paragraphId
noteId, paragraphId, csrfToken
}));
}
};
Expand Down
35 changes: 35 additions & 0 deletions shell/src/main/resources/ui_templates/terminal-ui.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#}
<html>
<header>
<meta name="csrf-token" content="{{ CSRF_TOKEN }}">
<style>
* {
margin: 0;
padding: 0;
}
html, body {
height: 100%;
}
</style>
<script src="/hterm_all.js"></script>
<script src="/index.js"></script>
</header>
<body>
<div id="terminal" style="position:relative; width:100%; height:100%"></div>
</body>
</html>

0 comments on commit 3a03a12

Please sign in to comment.