Skip to content

Commit

Permalink
Fix security plugin not initialized issue
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 22, 2024
1 parent 7722020 commit ad1f2ee
Showing 1 changed file with 48 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import java.time.Instant;
import java.util.UUID;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.get.GetRequest;
Expand All @@ -29,6 +31,8 @@
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
Expand All @@ -55,6 +59,7 @@
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -108,15 +113,41 @@ protected String getTransportActionName() {

@Override
protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
return new RetryableActionListenerResponseHandler<>(listener, MLTaskResponse::new);
}

@Override
public void dispatchTask(
public static final class RetryableActionListenerResponseHandler<T extends TransportResponse> extends ActionListenerResponseHandler<T> {

private Runnable runnable;
private int retryCount = 0;

public RetryableActionListenerResponseHandler(ActionListener listener, Writeable.Reader<T> reader) {
super(listener, reader);
}
public RetryableActionListenerResponseHandler(ActionListener listener, Writeable.Reader reader, int retryCount) {
super(listener, reader);
this.retryCount = retryCount;
}

@Override
public void handleException(TransportException exp) {
log.debug("Failed to execute ML predict task and started retry, current retry count is: {}", retryCount, exp);
if (runnable == null || retryCount >= 3) {
super.handleException(exp);
return;
}
if (ExceptionUtils.indexOfThrowable(exp, OpenSearchSecurityException.class) != -1) {
runnable.run();
}
}
}

public void retryableDispatchTask(
FunctionName functionName,
MLPredictionTaskRequest request,
TransportService transportService,
ActionListener<MLTaskResponse> listener
ActionListener<MLTaskResponse> listener,
int retryCount
) {
String modelId = request.getModelId();
try {
Expand All @@ -128,7 +159,9 @@ public void dispatchTask(
} else {
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
request.setDispatchTask(false);
transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener));
RetryableActionListenerResponseHandler<MLTaskResponse> handler = new RetryableActionListenerResponseHandler<>(listener, MLTaskResponse::new, retryCount);
if (handler.runnable == null) handler.runnable = () -> retryableDispatchTask(functionName, request, transportService, listener, retryCount + 1);
transportService.sendRequest(node, getTransportActionName(), request, handler);
}
}, e -> { listener.onFailure(e); });
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
Expand All @@ -152,6 +185,16 @@ public void dispatchTask(
}
}

@Override
public void dispatchTask(
FunctionName functionName,
MLPredictionTaskRequest request,
TransportService transportService,
ActionListener<MLTaskResponse> listener
) {
retryableDispatchTask(functionName, request, transportService, listener, 0);
}

/**
* Start prediction task
* @param request MLPredictionTaskRequest
Expand Down

0 comments on commit ad1f2ee

Please sign in to comment.