Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
Signed-off-by: Vedant Mahabaleshwarkar <[email protected]>
  • Loading branch information
VedantMahabaleshwarkar committed Jul 21, 2023
1 parent 5d2e8fd commit 000cce7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 67 deletions.
27 changes: 14 additions & 13 deletions src/main/java/com/ibm/watson/modelmesh/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ interface Metrics extends AutoCloseable {
boolean isPerModelMetricsEnabled();

boolean isEnabled();

void logTimingMetricSince(Metric metric, long prevTime, boolean isNano);

void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId);
Expand Down Expand Up @@ -164,14 +165,14 @@ final class PrometheusMetrics implements Metrics {
private final CollectorRegistry registry;
private final NettyServer metricServer;
private final boolean shortNames;
private final boolean enablePerModelMetrics;
private final boolean perModelMetricsEnabled;
private final EnumMap<Metric, Collector> metricsMap = new EnumMap<>(Metric.class);

public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMetricParams) throws Exception {
int port = 2112;
boolean shortNames = true;
boolean https = true;
boolean enablePerModelMetrics = true;
boolean perModelMetricsEnabled= true;
String memMetrics = "all"; // default to all
for (Entry<String, String> ent : params.entrySet()) {
switch (ent.getKey()) {
Expand All @@ -183,7 +184,7 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
}
break;
case "per_model_metrics":
enablePerModelMetrics = "true".equalsIgnoreCase(ent.getValue());
perModelMetricsEnabled= "true".equalsIgnoreCase(ent.getValue());
break;
case "fq_names":
shortNames = !"true".equalsIgnoreCase(ent.getValue());
Expand All @@ -203,7 +204,7 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
throw new Exception("Unrecognized metrics config parameter: " + ent.getKey());
}
}
this.enablePerModelMetrics = enablePerModelMetrics;
this.perModelMetricsEnabled= perModelMetricsEnabled;

registry = new CollectorRegistry();
for (Metric m : Metric.values()) {
Expand Down Expand Up @@ -237,12 +238,12 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet

if (m == API_REQUEST_TIME || m == API_REQUEST_COUNT || m == INVOKE_MODEL_TIME
|| m == INVOKE_MODEL_COUNT || m == REQUEST_PAYLOAD_SIZE || m == RESPONSE_PAYLOAD_SIZE) {
if (this.enablePerModelMetrics) {
if (this.perModelMetricsEnabled) {
builder.labelNames("method", "code", "modelId", "vModelId");
} else {
builder.labelNames("method", "code");
}
} else if (this.enablePerModelMetrics && m.type != GAUGE && m.type != COUNTER && m.type != COUNTER_WITH_HISTO) {
} else if (this.perModelMetricsEnabled && m.type != GAUGE && m.type != COUNTER && m.type != COUNTER_WITH_HISTO) {
builder.labelNames("modelId", "vModelId");
}
Collector collector = builder.name(m.promName).help(m.description).create();
Expand Down Expand Up @@ -352,7 +353,7 @@ public void close() {

@Override
public boolean isPerModelMetricsEnabled() {
return enablePerModelMetrics;
return perModelMetricsEnabled;
}

@Override
Expand All @@ -368,7 +369,7 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) {
if (enablePerModelMetrics && modelId != null) {
if (perModelMetricsEnabled && modelId != null) {
((Histogram) metricsMap.get(metric)).labels(modelId, "").observe(isNano ? elapsed / M : elapsed);
} else {
((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed);
Expand All @@ -377,7 +378,7 @@ public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano,

@Override
public void logSizeEventMetric(Metric metric, long value, String modelId) {
if (enablePerModelMetrics) {
if (perModelMetricsEnabled) {
((Histogram) metricsMap.get(metric)).labels(modelId, "").observe(value * metric.newMultiplier);
} else {
((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier);
Expand All @@ -404,16 +405,16 @@ public void logRequestMetrics(boolean external, String name, long elapsedNanos,
.get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME);
int idx = shortNames ? name.indexOf('/') : -1;
String methodName = idx == -1 ? name : name.substring(idx + 1);
if (enablePerModelMetrics && vModelId == null) {
if (perModelMetricsEnabled&& vModelId == null) {
vModelId = "";
}
if (enablePerModelMetrics) {
if (perModelMetricsEnabled) {
timingHisto.labels(methodName, code.name(), modelId, vModelId).observe(elapsedMillis);
} else {
timingHisto.labels(methodName, code.name()).observe(elapsedMillis);
}
if (reqPayloadSize != -1) {
if (enablePerModelMetrics) {
if (perModelMetricsEnabled) {
((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE))
.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize);
} else {
Expand All @@ -422,7 +423,7 @@ public void logRequestMetrics(boolean external, String name, long elapsedNanos,
}
}
if (respPayloadSize != -1) {
if (enablePerModelMetrics) {
if (perModelMetricsEnabled) {
((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE))
.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize);
} else {
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -3315,7 +3315,7 @@ protected Map<String, ServiceInstanceInfo> getMap(Object[] arr) {
static final String KNOWN_SIZE_CXT_KEY = "tas.known_size";
static final String UNBALANCED_KEY = "mmesh.unbalanced";
static final String DEST_INST_ID_KEY = "tas.dest_iid";
static final String VMODELID = "vmodelid";
static final String VMODEL_ID = "vmodelid";

// these are the possible values for the tas.internal context parameter
// it won't be set on requests from outside of the cluster, and will
Expand Down Expand Up @@ -3431,7 +3431,7 @@ protected Object invokeModel(final String modelId, final Method method,
}

final String tasInternal = contextMap.get(TAS_INTERNAL_CXT_KEY);
String vModelId = contextMap.getOrDefault(VMODELID, "");
String vModelId = contextMap.getOrDefault(VMODEL_ID, "");
// Set the external request flag if it's not a tasInternal call or if
// tasInternal == INTERNAL_REQ. The latter is a new ensureLoaded
// invocation originating from within the cluster.
Expand Down Expand Up @@ -4437,7 +4437,7 @@ private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
if (contextMap == null) {
vModelId = "";
} else {
vModelId = contextMap.get(VMODELID);
vModelId = contextMap.get(VMODEL_ID);
}

// The future-waiting timeouts should not be needed, request threads are interrupted when their
Expand Down
97 changes: 54 additions & 43 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.ibm.watson.litelinks.server.ReleaseAfterResponse;
import com.ibm.watson.litelinks.server.ServerRequestThread;
import com.ibm.watson.modelmesh.DataplaneApiConfig.RpcConfig;
import com.ibm.watson.modelmesh.GrpcSupport.InterruptingListener;
import com.ibm.watson.modelmesh.ModelMesh.ExtendedStatusInfo;
import com.ibm.watson.modelmesh.api.DeleteVModelRequest;
import com.ibm.watson.modelmesh.api.DeleteVModelResponse;
Expand Down Expand Up @@ -68,6 +69,7 @@
import io.grpc.ServerInterceptors;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status.Code;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.GrpcSslContexts;
Expand Down Expand Up @@ -345,7 +347,7 @@ protected static void setUnbalancedLitelinksContextParam() {
}

protected static void setvModelIdLiteLinksContextParam(String vModelId) {
ThreadContext.addContextEntry(ModelMesh.VMODELID, vModelId);
ThreadContext.addContextEntry(ModelMesh.VMODEL_ID, vModelId);
}

// ----------------- concrete model management methods
Expand Down Expand Up @@ -434,38 +436,39 @@ public void ensureLoaded(EnsureLoadedRequest request, StreamObserver<ModelStatus
// Returned ModelResponse will be released once the request thread exits so
// must be retained before transferring.
// non-private to avoid synthetic method access
ModelResponse callModel(String originalModelId, boolean isVModel, String methodName, String grpcBalancedHeader,
ModelResponse callModel(String modelId, String vModelId, String methodName, String grpcBalancedHeader,
Metadata headers, ByteBuf data) throws Exception {
boolean unbalanced = grpcBalancedHeader == null ? UNBALANCED_DEFAULT : !"true".equals(grpcBalancedHeader);
if (!isVModel) {
if (!modelId.isBlank()) {
if (unbalanced) {
setUnbalancedLitelinksContextParam();
}
return delegate.callModel(originalModelId, methodName, headers, data);
}
String vModelId = originalModelId;
if (delegate.metrics.isEnabled()) {
setvModelIdLiteLinksContextParam(originalModelId);
}
boolean first = true;
while (true) {
String modelId = vmm().resolveVModelId(vModelId, originalModelId);
if (unbalanced) {
setUnbalancedLitelinksContextParam();
}
try {
return delegate.callModel(modelId, methodName, headers, data);
} catch (ModelNotFoundException mnfe) {
if (!first) throw mnfe;
} catch (Exception e) {
logger.error("Exception invoking " + methodName + " method of resolved model " + modelId + " of vmodel "
+ vModelId + ": " + e.getClass().getSimpleName() + ": " + e.getMessage());
throw e;
return delegate.callModel(modelId, methodName, headers, data);
} else if (!vModelId.isBlank()) {
boolean first = true;
while (true) {
String resolvedModelId = vmm().resolveVModelId(vModelId, vModelId);
if (unbalanced) {
setUnbalancedLitelinksContextParam();
}
try {
return delegate.callModel(resolvedModelId, methodName, headers, data);
} catch (ModelNotFoundException mnfe) {
if (!first) throw mnfe;
} catch (Exception e) {
logger.error("Exception invoking " + methodName + " method of resolved model " + modelId + " of vmodel "
+ vModelId + ": " + e.getClass().getSimpleName() + ": " + e.getMessage());
throw e;
}
// try again
first = false;
data.readerIndex(0); // rewind buffer
}
// try again
first = false;
data.readerIndex(0); // rewind buffer
} else {
throw statusException(DATA_LOSS,
"no valid modelid or vmodelid found for request");
}

}

// -----
Expand Down Expand Up @@ -715,7 +718,9 @@ public void onHalfClose() {
io.grpc.Status status = INTERNAL;
String modelId = null;
String requestId = null;
String resolvedModelId = null;
ModelResponse response = null;
Boolean isSingleModelRequest = null;
try (InterruptingListener cancelListener = newInterruptingListener()) {
if (logHeaders != null) {
logHeaders.addToMDC(headers); // MDC cleared in finally block
Expand All @@ -730,10 +735,20 @@ public void onHalfClose() {
modelId = validateModelId(midIt.next(), isVModel);
if (!midIt.hasNext()) {
// single model case (most common)
response = callModel(modelId, isVModel, methodName,
isSingleModelRequest = true;
if (isVModel && delegate.metrics.isEnabled()) {
setvModelIdLiteLinksContextParam(modelId);
resolvedModelId = vmm().resolveVModelId(modelId, modelId);
response = callModel("", modelId, methodName,
balancedMetaVal, headers, reqMessage).retain();
} else {
response = callModel(modelId, "", methodName,
balancedMetaVal, headers, reqMessage).retain();
}

} else {
// multi-model case (specialized)
isSingleModelRequest = false;
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
List<String> idList = new ArrayList<>();
idList.add(modelId);
Expand Down Expand Up @@ -789,26 +804,19 @@ public void onHalfClose() {
call.close(status, emptyMeta());
Metrics metrics = delegate.metrics;
if (metrics.isEnabled()) {
Iterator<String> midIt = modelIds.iterator();
while (midIt.hasNext()) {
if (isSingleModelRequest && metrics.isPerModelMetricsEnabled() && modelId!=null) {
if (isVModel) {
String mId = null;
String vmId = midIt.next();
try {
mId = vmm().resolveVModelId(midIt.next(), mId);
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, mId, vmId);
}
catch (Exception e) {
logger.error("Could not resolve model id for vModelId" + vmId, e);
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, "", vmId);
}
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, resolvedModelId, modelId);
} else {
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, midIt.next(), "");
status.getCode(), reqSize, respSize, modelId, "");
}
}
else {
metrics.logRequestMetrics(true, methodName, nanoTime() - startNanos,
status.getCode(), reqSize, respSize, "", "");
}
}
}
}
Expand Down Expand Up @@ -969,7 +977,10 @@ protected ModelResponse applyParallelMultiModel(List<String> modelIds, boolean i
logHeaders.addToMDC(headers);
}
// need to pass slices of the buffers for threadsafety
return callModel(modelId, isVModel, methodName, balancedMetaVal, headers, data.slice());
if (isVModel) {
return callModel("", modelId, methodName, balancedMetaVal, headers, data.slice());
}
return callModel(modelId, "", methodName, balancedMetaVal, headers, data.slice());
} catch (ModelNotFoundException mnfe) {
logger.warn("model " + modelId + " not found (from supplied list of " + total + ")");
if (!requireAll) {
Expand Down
12 changes: 4 additions & 8 deletions src/main/java/com/ibm/watson/modelmesh/VModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
/**
* This class contains logic related to VModels (virtual or versioned models)
*/
public final class VModelManager implements Closeable {
public final class VModelManager implements AutoCloseable {

private static final Logger logger = LoggerFactory.getLogger(VModelManager.class);

Expand Down Expand Up @@ -126,13 +126,9 @@ public ListenableFuture<Boolean> start() {
}

@Override
public void close() {
try {
vModelTable.close();
targetScaleupExecutor.shutdown();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
public void close() throws Exception {
vModelTable.close();
targetScaleupExecutor.shutdown();
}

/**
Expand Down

0 comments on commit 000cce7

Please sign in to comment.