From ad1f2ee4b10932b138d96eb9c4103a21551e5d1f Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 22 Apr 2024 18:14:23 +0800 Subject: [PATCH] Fix security plugin not initialized issue Signed-off-by: zane-neo --- .../ml/task/MLPredictTaskRunner.java | 53 +++++++++++++++++-- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 92e05a5ba9..acd64f31d4 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -15,7 +15,9 @@ import java.time.Instant; import java.util.UUID; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchSecurityException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.get.GetRequest; @@ -29,6 +31,8 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.breaker.MLCircuitBreakerService; @@ -55,6 +59,7 @@ import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -108,15 +113,41 @@ protected String getTransportActionName() { @Override protected TransportResponseHandler getResponseHandler(ActionListener listener) { - return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); + return new RetryableActionListenerResponseHandler<>(listener, MLTaskResponse::new); } - @Override - public void dispatchTask( + public static final class RetryableActionListenerResponseHandler extends ActionListenerResponseHandler { + + private Runnable runnable; + private int retryCount = 0; + + public RetryableActionListenerResponseHandler(ActionListener listener, Writeable.Reader reader) { + super(listener, reader); + } + public RetryableActionListenerResponseHandler(ActionListener listener, Writeable.Reader reader, int retryCount) { + super(listener, reader); + this.retryCount = retryCount; + } + + @Override + public void handleException(TransportException exp) { + log.debug("Failed to execute ML predict task and started retry, current retry count is: {}", retryCount, exp); + if (runnable == null || retryCount >= 3) { + super.handleException(exp); + return; + } + if (ExceptionUtils.indexOfThrowable(exp, OpenSearchSecurityException.class) != -1) { + runnable.run(); + } + } + } + + public void retryableDispatchTask( FunctionName functionName, MLPredictionTaskRequest request, TransportService transportService, - ActionListener listener + ActionListener listener, + int retryCount ) { String modelId = request.getModelId(); try { @@ -128,7 +159,9 @@ public void dispatchTask( } else { log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId()); request.setDispatchTask(false); - transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); + RetryableActionListenerResponseHandler handler = new RetryableActionListenerResponseHandler<>(listener, MLTaskResponse::new, retryCount); + if (handler.runnable == null) handler.runnable = () -> retryableDispatchTask(functionName, request, transportService, listener, retryCount + 1); + transportService.sendRequest(node, getTransportActionName(), request, handler); } }, e -> { listener.onFailure(e); }); String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); @@ -152,6 +185,16 @@ public void dispatchTask( } } + @Override + public void dispatchTask( + FunctionName functionName, + MLPredictionTaskRequest request, + TransportService transportService, + ActionListener listener + ) { + retryableDispatchTask(functionName, request, transportService, listener, 0); + } + /** * Start prediction task * @param request MLPredictionTaskRequest