From b3dfb381b152edf77dbd9c88d67db1910f274ab1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 16 Nov 2023 14:39:14 -0800 Subject: [PATCH] feat: Add vModelId to PayloadProcessor Payload Motivation Currently the payloads passed to PayloadProcessors only contain the modelId, which in the case of vModels will be a "resolved" modelId corresponding to a particular model revision (in particular this will be true when used in KServe modelmesh-serving). It would be useful to include the vModelId too. Modifications Add a vModelId field to the Payload class and correspondingly update built-in PayloadProcessor implementations where applicable. It may be null if the request was directed at a concrete modelId rather than a vModelId. Result Both modelId and vModelId are available to PayloadProcessors Signed-off-by: Nick Hill --- .../ibm/watson/modelmesh/ModelMeshApi.java | 8 +-- .../payload/MatchingPayloadProcessor.java | 49 ++++++++++++------- .../ibm/watson/modelmesh/payload/Payload.java | 20 ++++++++ .../payload/RemotePayloadProcessor.java | 21 ++++---- 4 files changed, 67 insertions(+), 31 deletions(-) diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java index 715c0efe..0c437a17 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java @@ -767,7 +767,7 @@ public void onHalfClose() { } finally { if (payloadProcessor != null) { processPayload(reqMessage.readerIndex(reqReaderIndex), - requestId, resolvedModelId, methodName, headers, null, true); + requestId, resolvedModelId, vModelId, methodName, headers, null, true); } else { releaseReqMessage(); } @@ -803,7 +803,7 @@ public void onHalfClose() { data = response.data.readerIndex(respReaderIndex); metadata = response.metadata; } - processPayload(data, requestId, resolvedModelId, methodName, metadata, status, releaseResponse); + processPayload(data, requestId, resolvedModelId, vModelId, methodName, metadata, status, releaseResponse); } else if (releaseResponse && response != null) { response.release(); } @@ -829,7 +829,7 @@ public void onHalfClose() { * @param status null for requests, non-null for responses * @param takeOwnership whether the processor should take ownership */ - private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName, + private void processPayload(ByteBuf data, String payloadId, String vModelId, String modelId, String methodName, Metadata metadata, io.grpc.Status status, boolean takeOwnership) { Payload payload = null; try { @@ -837,7 +837,7 @@ private void processPayload(ByteBuf data, String payloadId, String modelId, Stri if (!takeOwnership) { ReferenceCountUtil.retain(data); } - payload = new Payload(payloadId, modelId, methodName, metadata, data, status); + payload = new Payload(payloadId, modelId, vModelId, methodName, metadata, data, status); if (payloadProcessor.process(payload)) { data = null; // ownership transferred } diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java index 45402423..8e2a2501 100644 --- a/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java +++ b/src/main/java/com/ibm/watson/modelmesh/payload/MatchingPayloadProcessor.java @@ -17,6 +17,7 @@ package com.ibm.watson.modelmesh.payload; import java.io.IOException; +import java.util.Objects; /** * A {@link PayloadProcessor} that processes {@link Payload}s only if they match with given model ID or method name. @@ -29,10 +30,13 @@ public class MatchingPayloadProcessor implements PayloadProcessor { private final String modelId; - MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId) { + private final String vModelId; + + MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId, String vModelId) { this.delegate = delegate; this.methodName = methodName; this.modelId = modelId; + this.vModelId = vModelId; } @Override @@ -42,40 +46,49 @@ public String getName() { @Override public boolean process(Payload payload) { - boolean processed = false; - boolean methodMatches = true; - if (this.methodName != null) { - methodMatches = payload.getMethod() != null && this.methodName.equals(payload.getMethod()); - } + boolean methodMatches = this.methodName == null || Objects.equals(this.methodName, payload.getMethod()); if (methodMatches) { - boolean modelIdMatches = true; - if (this.modelId != null) { - modelIdMatches = this.modelId.equals(payload.getModelId()); - } + boolean modelIdMatches = this.modelId == null || this.modelId.equals(payload.getModelId()); if (modelIdMatches) { - processed = delegate.process(payload); + boolean vModelIdMatches = this.vModelId == null || this.vModelId.equals(payload.getVModelId()); + if (vModelIdMatches) { + return delegate.process(payload); + } } } - return processed; + return false; } public static MatchingPayloadProcessor from(String modelId, String method, PayloadProcessor processor) { + return from(modelId, null, method, processor); + } + + public static MatchingPayloadProcessor from(String modelId, String vModelId, + String method, PayloadProcessor processor) { if (modelId != null) { - if (modelId.length() > 0) { + if (!modelId.isEmpty()) { modelId = modelId.replaceFirst("/", ""); - if (modelId.length() == 0 || modelId.equals("*")) { + if (modelId.isEmpty() || modelId.equals("*")) { modelId = null; } } else { modelId = null; } } - if (method != null) { - if (method.length() == 0 || method.equals("*")) { - method = null; + if (vModelId != null) { + if (!vModelId.isEmpty()) { + vModelId = vModelId.replaceFirst("/", ""); + if (vModelId.isEmpty() || vModelId.equals("*")) { + vModelId = null; + } + } else { + vModelId = null; } } - return new MatchingPayloadProcessor(processor, method, modelId); + if (method != null && (method.isEmpty() || method.equals("*"))) { + method = null; + } + return new MatchingPayloadProcessor(processor, method, modelId, vModelId); } @Override diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java b/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java index 9eed4367..6dcafd17 100644 --- a/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java +++ b/src/main/java/com/ibm/watson/modelmesh/payload/Payload.java @@ -39,6 +39,8 @@ public enum Kind { private final String modelId; + private final String vModelId; + private final String method; private final Metadata metadata; @@ -48,10 +50,17 @@ public enum Kind { // null for requests, non-null for responses private final Status status; + public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String method, @Nullable Metadata metadata, @Nullable ByteBuf data, @Nullable Status status) { + this(id, modelId, null, method, metadata, data, status); + } + + public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String vModelId, @Nullable String method, + @Nullable Metadata metadata, @Nullable ByteBuf data, @Nullable Status status) { this.id = id; this.modelId = modelId; + this.vModelId = vModelId; this.method = method; this.metadata = metadata; this.data = data; @@ -68,6 +77,16 @@ public String getModelId() { return modelId; } + @CheckForNull + public String getVModelId() { + return vModelId; + } + + @Nonnull + public String getVModelIdOrModelId() { + return vModelId != null ? vModelId : modelId; + } + @CheckForNull public String getMethod() { return method; @@ -101,6 +120,7 @@ public void release() { public String toString() { return "Payload{" + "id='" + id + '\'' + + ", vModelId=" + (vModelId != null ? ('\'' + vModelId + '\'') : "null") + ", modelId='" + modelId + '\'' + ", method='" + method + '\'' + ", status=" + (status == null ? "request" : String.valueOf(status)) + diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java index 401fba2d..23c2fba1 100644 --- a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java +++ b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java @@ -57,14 +57,10 @@ public boolean process(Payload payload) { private static PayloadContent prepareContentBody(Payload payload) { String id = payload.getId(); String modelId = payload.getModelId(); + String vModelId = payload.getVModelId(); String kind = payload.getKind().toString().toLowerCase(); ByteBuf byteBuf = payload.getData(); - String data; - if (byteBuf != null) { - data = encodeBinaryToString(byteBuf); - } else { - data = ""; - } + String data = byteBuf != null ? encodeBinaryToString(byteBuf) : ""; Metadata metadata = payload.getMetadata(); Map metadataMap = new HashMap<>(); if (metadata != null) { @@ -79,7 +75,7 @@ private static PayloadContent prepareContentBody(Payload payload) { } } String status = payload.getStatus() != null ? payload.getStatus().getCode().toString() : ""; - return new PayloadContent(id, modelId, data, kind, status, metadataMap); + return new PayloadContent(id, modelId, vModelId, data, kind, status, metadataMap); } private static String encodeBinaryToString(ByteBuf byteBuf) { @@ -116,15 +112,17 @@ private static class PayloadContent { private final String id; private final String modelid; + private final String vModelId; private final String data; private final String kind; private final String status; private final Map metadata; - private PayloadContent(String id, String modelid, String data, String kind, String status, - Map metadata) { + private PayloadContent(String id, String modelid, String vModelId, String data, String kind, + String status, Map metadata) { this.id = id; this.modelid = modelid; + this.vModelId = vModelId; this.data = data; this.kind = kind; this.status = status; @@ -143,6 +141,10 @@ public String getModelid() { return modelid; } + public String getvModelId() { + return vModelId; + } + public String getData() { return data; } @@ -160,6 +162,7 @@ public String toString() { return "PayloadContent{" + "id='" + id + '\'' + ", modelid='" + modelid + '\'' + + ", vModelId=" + (vModelId != null ? ('\'' + vModelId + '\'') : "null") + ", data='" + data + '\'' + ", kind='" + kind + '\'' + ", status='" + status + '\'' +