Skip to content

Commit

Permalink
change code for all successfully responses case
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 26, 2024
1 parent 5e6a7f6 commit a7a980c
Showing 1 changed file with 7 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.apache.http.HttpStatus;
import org.apache.logging.log4j.util.Strings;
Expand Down Expand Up @@ -139,52 +135,14 @@ private void processResponse(
}
}

// Only all requests successful case will be processed here.
private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
List<ModelTensors> modelTensors = new ArrayList<>();
TreeMap<Integer, ModelTensors> sortedMap = new TreeMap<>(tensorOutputs);
log.debug("Reordered tensor outputs size is {}", sortedMap.size());
if (tensorOutputs.size() == 1) {
// batch API case
ModelTensors singleTensor = tensorOutputs.get(0);
int status = singleTensor.getStatusCode();
if (status == HttpStatus.SC_OK) {
modelTensors.add(singleTensor);
actionListener.onResponse(modelTensors);
} else {
actionListener.onFailure(buildOpenSearchStatusException(singleTensor));
}
} else {
// non batch API. This is to follow the previously code logic. Previously when making multiple requests to remote model,
// either one fails, we will return a failure response.
OpenSearchStatusException openSearchStatusException = null;
for (Map.Entry<Integer, ModelTensors> entry : sortedMap.entrySet()) {
if (entry.getValue().getStatusCode() < HttpStatus.SC_OK
|| entry.getValue().getStatusCode() > HttpStatus.SC_MULTIPLE_CHOICES) {
openSearchStatusException = buildOpenSearchStatusException(entry.getValue());
break;
}
modelTensors.add(entry.getKey(), entry.getValue());
}
if (openSearchStatusException != null) {
actionListener.onFailure(openSearchStatusException);
} else {
actionListener.onResponse(modelTensors);
}
}
}

private OpenSearchStatusException buildOpenSearchStatusException(ModelTensors modelTensors) {
try {
return new OpenSearchStatusException(
AccessController
.doPrivileged(
(PrivilegedExceptionAction<String>) () -> GSON.toJson(modelTensors.getMlModelTensors().get(0).getDataAsMap())
),
RestStatus.fromCode(modelTensors.getStatusCode())
);
} catch (PrivilegedActionException e) {
return new OpenSearchStatusException(e.getMessage(), RestStatus.fromCode(statusCode));
ModelTensors[] modelTensors = new ModelTensors[tensorOutputs.size()];
log.debug("Reordered tensor outputs size is {}", tensorOutputs.size());
for (Map.Entry<Integer, ModelTensors> entry : tensorOutputs.entrySet()) {
modelTensors[entry.getKey()] = entry.getValue();
}
actionListener.onResponse(Arrays.asList(modelTensors));
}

protected class MLResponseSubscriber implements Subscriber<ByteBuffer> {
Expand Down

0 comments on commit a7a980c

Please sign in to comment.