Skip to content

Commit

Permalink
Merge pull request #1 from njhill/modelmetrics
Browse files Browse the repository at this point in the history
Update per-model metric changes
  • Loading branch information
VedantMahabaleshwarkar authored Aug 29, 2023
2 parents 72e1c8e + 37eed31 commit 7f7ae87
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 106 deletions.
37 changes: 20 additions & 17 deletions src/main/java/com/ibm/watson/modelmesh/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.ibm.watson.modelmesh;

import com.google.common.base.Strings;
import com.ibm.watson.prometheus.Counter;
import com.ibm.watson.prometheus.Gauge;
import com.ibm.watson.prometheus.Histogram;
Expand Down Expand Up @@ -54,7 +55,6 @@
import static com.ibm.watson.modelmesh.Metric.MetricType.*;
import static com.ibm.watson.modelmesh.ModelMesh.M;
import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_CUSTOM_ENV_VAR;
import static com.ibm.watson.modelmesh.ModelMeshEnvVars.MMESH_METRICS_ENV_VAR;
import static java.util.concurrent.TimeUnit.*;

/**
Expand Down Expand Up @@ -172,7 +172,7 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
int port = 2112;
boolean shortNames = true;
boolean https = true;
boolean perModelMetricsEnabled= true;
boolean perModelMetricsEnabled = true;
String memMetrics = "all"; // default to all
for (Entry<String, String> ent : params.entrySet()) {
switch (ent.getKey()) {
Expand Down Expand Up @@ -204,7 +204,7 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet
throw new Exception("Unrecognized metrics config parameter: " + ent.getKey());
}
}
this.perModelMetricsEnabled= perModelMetricsEnabled;
this.perModelMetricsEnabled = perModelMetricsEnabled;

registry = new CollectorRegistry();
for (Metric m : Metric.values()) {
Expand Down Expand Up @@ -273,6 +273,7 @@ public PrometheusMetrics(Map<String, String> params, Map<String, String> infoMet

this.metricServer = new NettyServer(registry, port, https);
this.shortNames = shortNames;

logger.info("Will expose " + (https ? "https" : "http") + " Prometheus metrics on port " + port
+ " using " + (shortNames ? "short" : "fully-qualified") + " method names");

Expand Down Expand Up @@ -369,19 +370,21 @@ public void logTimingMetricSince(Metric metric, long prevTime, boolean isNano) {

@Override
public void logTimingMetricDuration(Metric metric, long elapsed, boolean isNano, String modelId) {
Histogram histogram = (Histogram) metricsMap.get(metric);
if (perModelMetricsEnabled && modelId != null) {
((Histogram) metricsMap.get(metric)).labels(modelId, "").observe(isNano ? elapsed / M : elapsed);
histogram.labels(modelId, "").observe(isNano ? elapsed / M : elapsed);
} else {
((Histogram) metricsMap.get(metric)).observe(isNano ? elapsed / M : elapsed);
histogram.observe(isNano ? elapsed / M : elapsed);
}
}

@Override
public void logSizeEventMetric(Metric metric, long value, String modelId) {
Histogram histogram = (Histogram) metricsMap.get(metric);
if (perModelMetricsEnabled) {
((Histogram) metricsMap.get(metric)).labels(modelId, "").observe(value * metric.newMultiplier);
histogram.labels(modelId, "").observe(value * metric.newMultiplier);
} else {
((Histogram) metricsMap.get(metric)).observe(value * metric.newMultiplier);
histogram.observe(value * metric.newMultiplier);
}
}

Expand All @@ -403,32 +406,32 @@ public void logRequestMetrics(boolean external, String name, long elapsedNanos,
final long elapsedMillis = elapsedNanos / M;
final Histogram timingHisto = (Histogram) metricsMap
.get(external ? API_REQUEST_TIME : INVOKE_MODEL_TIME);

int idx = shortNames ? name.indexOf('/') : -1;
String methodName = idx == -1 ? name : name.substring(idx + 1);
if (perModelMetricsEnabled&& vModelId == null) {
vModelId = "";
if (perModelMetricsEnabled) {
modelId = Strings.nullToEmpty(modelId);
vModelId = Strings.nullToEmpty(vModelId);
}
if (perModelMetricsEnabled) {
timingHisto.labels(methodName, code.name(), modelId, vModelId).observe(elapsedMillis);
} else {
timingHisto.labels(methodName, code.name()).observe(elapsedMillis);
}
if (reqPayloadSize != -1) {
Histogram reqPayloadHisto = (Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE);
if (perModelMetricsEnabled) {
((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE))
.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize);
reqPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(reqPayloadSize);
} else {
((Histogram) metricsMap.get(REQUEST_PAYLOAD_SIZE))
.labels(methodName, code.name()).observe(reqPayloadSize);
reqPayloadHisto.labels(methodName, code.name()).observe(reqPayloadSize);
}
}
if (respPayloadSize != -1) {
Histogram respPayloadHisto = (Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE);
if (perModelMetricsEnabled) {
((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE))
.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize);
respPayloadHisto.labels(methodName, code.name(), modelId, vModelId).observe(respPayloadSize);
} else {
((Histogram) metricsMap.get(RESPONSE_PAYLOAD_SIZE))
.labels(methodName, code.name()).observe(respPayloadSize);
respPayloadHisto.labels(methodName, code.name()).observe(respPayloadSize);
}
}
}
Expand Down
37 changes: 14 additions & 23 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -3348,8 +3348,8 @@ public StatusInfo internalOperation(String modelId, boolean returnStatus, boolea
List<String> excludeInstances)
throws ModelNotFoundException, ModelLoadException, ModelNotHereException, InternalException {
try {
return (StatusInfo) invokeModel(modelId, null,
internalOpRemoteMeth, returnStatus, load, sync, lastUsed, excludeInstances); // <-- "args"
return (StatusInfo) invokeModel(modelId, null, internalOpRemoteMeth,
returnStatus, load, sync, lastUsed, excludeInstances); // <-- "args"
} catch (ModelNotFoundException | ModelLoadException | ModelNotHereException | InternalException e) {
throw e;
} catch (TException e) {
Expand Down Expand Up @@ -3417,8 +3417,8 @@ public StatusInfo internalOperation(String modelId, boolean returnStatus, boolea
* @throws TException
*/
@SuppressWarnings("unchecked")
protected Object invokeModel(final String modelId, final Method method,
final Method remoteMeth, final Object... args) throws ModelNotFoundException, ModelNotHereException, ModelLoadException, TException {
protected Object invokeModel(final String modelId, final Method method, final Method remoteMeth,
final Object... args) throws ModelNotFoundException, ModelNotHereException, ModelLoadException, TException {

//verify parameter values
if (modelId == null || modelId.isEmpty()) {
Expand All @@ -3431,7 +3431,7 @@ protected Object invokeModel(final String modelId, final Method method,
}

final String tasInternal = contextMap.get(TAS_INTERNAL_CXT_KEY);
String vModelId = contextMap.getOrDefault(VMODEL_ID, "");
final String vModelId = contextMap.getOrDefault(VMODEL_ID, "");
// Set the external request flag if it's not a tasInternal call or if
// tasInternal == INTERNAL_REQ. The latter is a new ensureLoaded
// invocation originating from within the cluster.
Expand Down Expand Up @@ -3504,7 +3504,7 @@ protected Object invokeModel(final String modelId, final Method method,
throw new ModelNotHereException(instanceId, modelId);
}
try {
return invokeLocalModel(ce, method, args, modelId);
return invokeLocalModel(ce, method, args, vModelId);
} catch (ModelLoadException mle) {
mr = registry.get(modelId);
if (mr == null || !mr.loadFailedInInstance(instanceId)) {
Expand Down Expand Up @@ -3718,7 +3718,7 @@ protected Object invokeModel(final String modelId, final Method method,
localInvokesInFlight.incrementAndGet();
}
try {
Object result = invokeLocalModel(cacheEntry, method, args, modelId);
Object result = invokeLocalModel(cacheEntry, method, args, vModelId);
return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result;
} finally {
if (!favourSelfForHits) {
Expand Down Expand Up @@ -3938,7 +3938,7 @@ else if (mr.getInstanceIds().containsKey(instanceId)) {

// invoke model
try {
Object result = invokeLocalModel(cacheEntry, method, args, modelId);
Object result = invokeLocalModel(cacheEntry, method, args, vModelId);
return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result;
} catch (ModelNotHereException e) {
if (loadTargetFilter != null) loadTargetFilter.remove(instanceId);
Expand Down Expand Up @@ -4123,6 +4123,7 @@ private Map<String, Long> filterIfReadOnly(Map<String, Long> instId) {
* instances inside and some out, and a request has been sent from outside the
* cluster to an instance inside (since it may land on an unintended instance in
* that case).
*
* @throws ModelNotHereException if the specified destination instance isn't found
*/
protected Object forwardInvokeModel(String destId, String modelId, Method remoteMeth, Object... args)
Expand Down Expand Up @@ -4404,17 +4405,17 @@ protected Object invokeRemoteModel(BaseModelMeshService.Iface client, Method met
return remoteMeth.invoke(client, ObjectArrays.concat(modelId, args));
}

protected Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args, String modelId)
protected Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args, String vModelId)
throws InterruptedException, TException {
Object result = invokeLocalModel(ce, method, args);
final Object result = _invokeLocalModel(ce, method, args, vModelId);
// if this is an ensure-loaded request, check-for and trigger a "chained" load if necessary
if (method == null) {
triggerChainedLoadIfNecessary(modelId, result, args, ce.getWeight(), null);
triggerChainedLoadIfNecessary(ce.modelId, result, args, ce.getWeight(), null);
}
return result;
}

private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
private Object _invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args, String vModelId)
throws InterruptedException, TException {

if (method == null) {
Expand All @@ -4429,17 +4430,7 @@ private Object invokeLocalModel(CacheEntry<?> ce, Method method, Object[] args)
long now = currentTimeMillis();
ce.upgradePriority(now + 3600_000L, now + 7200_000L); // (2 hours in future)
}
Map<String, String> contextMap = ThreadContext.getCurrentContext();
String vModelId = null;
// We might arrive here from a path where the original call was with a modelid.
// Hence, it is possible to arrive here with a null contextMap because the vModelId was never set
// To avoid catching a null pointer exception we just sanity check instead.
if (contextMap == null) {
vModelId = "";
} else {
vModelId = contextMap.get(VMODEL_ID);
}


// The future-waiting timeouts should not be needed, request threads are interrupted when their
// timeouts/deadlines expire, and the model loading thread that it waits for has its own timeout.
// But we still set a large one as a safeguard (there can be pathalogical cases where model-loading
Expand Down
Loading

0 comments on commit 7f7ae87

Please sign in to comment.