Skip to content

Commit

Permalink
Support forecast tasks in profile API; enable index field modifications
Browse files Browse the repository at this point in the history
Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Sep 27, 2024
1 parent cbe1f33 commit 519640f
Show file tree
Hide file tree
Showing 50 changed files with 1,283 additions and 169 deletions.
5 changes: 1 addition & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -696,15 +696,11 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.ad.transport.ADHCImputeNodeResponse',
'org.opensearch.ad.transport.GetAnomalyDetectorTransportAction',
'org.opensearch.timeseries.transport.BooleanNodeResponse',
'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao',
'org.opensearch.timeseries.transport.JobRequest',
'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
'org.opensearch.timeseries.ml.Inferencer',
'org.opensearch.timeseries.transport.SingleStreamResultRequest',
'org.opensearch.timeseries.transport.BooleanResponse',
'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
Expand Down Expand Up @@ -732,6 +728,7 @@ List<String> jacocoExclusions = [
'org.opensearch.timeseries.util.TimeUtil',
'org.opensearch.ad.transport.ADHCImputeTransportAction',
'org.opensearch.timeseries.ml.RealTimeInferencer',
'org.opensearch.timeseries.util.ExpiringValue',
]


Expand Down
8 changes: 6 additions & 2 deletions src/main/java/org/opensearch/ad/ml/ADRealTimeInferencer.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME;

import java.time.Clock;

import org.opensearch.ad.caching.ADCacheProvider;
import org.opensearch.ad.caching.ADPriorityCache;
import org.opensearch.ad.indices.ADIndex;
Expand All @@ -32,7 +34,8 @@ public ADRealTimeInferencer(
ADColdStartWorker coldStartWorker,
ADSaveResultStrategy resultWriteWorker,
ADCacheProvider cache,
ThreadPool threadPool
ThreadPool threadPool,
Clock clock
) {
super(
modelManager,
Expand All @@ -43,7 +46,8 @@ public ADRealTimeInferencer(
resultWriteWorker,
cache,
threadPool,
AD_THREAD_POOL_NAME
AD_THREAD_POOL_NAME,
clock
);
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/ad/model/ADTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept
detector.getCustomResultIndexMinSize(),
detector.getCustomResultIndexMinAge(),
detector.getCustomResultIndexTTL(),
detector.getFlattenResultIndexMapping()
detector.getFlattenResultIndexMapping(),
detector.getLastBreakingUIChangeTime()
);
return new Builder()
.taskId(parsedTaskId)
Expand Down
17 changes: 14 additions & 3 deletions src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ public Integer getShingleSize(Integer customShingleSize) {
* @param customResultIndexMinAge custom result index lifecycle management min age condition
* @param customResultIndexTTL custom result index lifecycle management ttl
* @param flattenResultIndexMapping flag to indicate whether to flatten result index mapping or not
* @param lastBreakingUIChangeTime last update time to configuration that can break UI and we have
* to display updates from the changed time
*/
public AnomalyDetector(
String detectorId,
Expand Down Expand Up @@ -178,7 +180,8 @@ public AnomalyDetector(
Integer customResultIndexMinSize,
Integer customResultIndexMinAge,
Integer customResultIndexTTL,
Boolean flattenResultIndexMapping
Boolean flattenResultIndexMapping,
Instant lastBreakingUIChangeTime
) {
super(
detectorId,
Expand Down Expand Up @@ -206,7 +209,8 @@ public AnomalyDetector(
customResultIndexMinSize,
customResultIndexMinAge,
customResultIndexTTL,
flattenResultIndexMapping
flattenResultIndexMapping,
lastBreakingUIChangeTime
);

checkAndThrowValidationErrors(ValidationAspect.DETECTOR);
Expand Down Expand Up @@ -284,6 +288,7 @@ public AnomalyDetector(StreamInput input) throws IOException {
this.customResultIndexMinAge = input.readOptionalInt();
this.customResultIndexTTL = input.readOptionalInt();
this.flattenResultIndexMapping = input.readOptionalBoolean();
this.lastUIBreakingChangeTime = input.readOptionalInstant();
}

public XContentBuilder toXContent(XContentBuilder builder) throws IOException {
Expand Down Expand Up @@ -350,6 +355,7 @@ public void writeTo(StreamOutput output) throws IOException {
output.writeOptionalInt(customResultIndexMinAge);
output.writeOptionalInt(customResultIndexTTL);
output.writeOptionalBoolean(flattenResultIndexMapping);
output.writeOptionalInstant(lastUIBreakingChangeTime);
}

@Override
Expand Down Expand Up @@ -447,6 +453,7 @@ public static AnomalyDetector parse(
Integer customResultIndexMinAge = null;
Integer customResultIndexTTL = null;
Boolean flattenResultIndexMapping = null;
Instant lastBreakingUIChangeTime = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -584,6 +591,9 @@ public static AnomalyDetector parse(
case FLATTEN_RESULT_INDEX_MAPPING:
flattenResultIndexMapping = onlyParseBooleanValue(parser);
break;
case BREAKING_UI_CHANGE_TIME:
lastBreakingUIChangeTime = ParseUtils.toInstant(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -615,7 +625,8 @@ public static AnomalyDetector parse(
customResultIndexMinSize,
customResultIndexMinAge,
customResultIndexTTL,
flattenResultIndexMapping
flattenResultIndexMapping,
lastBreakingUIChangeTime
);
detector.setDetectionDateRange(detectionDateRange);
return detector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
: WriteRequest.RefreshPolicy.IMMEDIATE;
RestRequest.Method method = request.getHttpRequest().method();

if (method == RestRequest.Method.POST && detectorId != AnomalyDetector.NO_ID) {
// reset detector to empty string detectorId is only meant for updating detector
detectorId = AnomalyDetector.NO_ID;
}

IndexAnomalyDetectorRequest indexAnomalyDetectorRequest = new IndexAnomalyDetectorRequest(
detectorId,
seqNo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ protected AnomalyDetector copyConfig(User user, Config config) {
config.getCustomResultIndexMinSize(),
config.getCustomResultIndexMinAge(),
config.getCustomResultIndexTTL(),
config.getFlattenResultIndexMapping()
config.getFlattenResultIndexMapping(),
breakingUIChange ? Instant.now() : config.getLastBreakingUIChangeTime()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.util.ActionListenerExecutor;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -129,14 +128,12 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
return;
}
Config config = configOptional.get();
long windowDelayMillis = ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis();
int featureSize = config.getEnabledFeatureIds().size();
long dataEndMillis = nodeRequest.getRequest().getDataEndMillis();
long dataStartMillis = nodeRequest.getRequest().getDataStartMillis();
long executionEndTime = dataEndMillis + windowDelayMillis;
String taskId = nodeRequest.getRequest().getTaskId();
for (ModelState<ThresholdedRandomCutForest> modelState : cache.get().getAllModels(configId)) {
if (shouldProcessModelState(modelState, executionEndTime, clusterService, hashRing)) {
if (shouldProcessModelState(modelState, dataEndMillis, clusterService, hashRing)) {
double[] nanArray = new double[featureSize];
Arrays.fill(nanArray, Double.NaN);
adInferencer
Expand All @@ -163,8 +160,8 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
* Determines whether the model state should be processed based on various conditions.
*
* Conditions checked:
* - The model's last seen execution end time is not the minimum Instant value.
* - The current execution end time is greater than or equal to the model's last seen execution end time,
* - The model's last seen data end time is not the minimum Instant value. This means the model hasn't been initialized yet.
* - The current data end time is greater than the model's last seen data end time,
* indicating that the model state was updated in previous intervals.
* - The entity associated with the model state is present.
* - The owning node for real-time processing of the entity, with the same local version, is present in the hash ring.
Expand All @@ -175,14 +172,14 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
* concurrently (e.g., during tests when multiple threads may operate quickly).
*
* @param modelState The current state of the model.
* @param executionEndTime The end time of the current execution interval.
* @param dataEndTime The data end time of current interval.
* @param clusterService The service providing information about the current cluster node.
* @param hashRing The hash ring used to determine the owning node for real-time processing of entities.
* @return true if the model state should be processed; otherwise, false.
*/
private boolean shouldProcessModelState(
ModelState<ThresholdedRandomCutForest> modelState,
long executionEndTime,
long dataEndTime,
ClusterService clusterService,
HashRing hashRing
) {
Expand All @@ -194,8 +191,8 @@ private boolean shouldProcessModelState(
// Check if the model state conditions are met for processing
// We cannot use last used time as it will be updated whenever we update its priority in CacheBuffer.update when there is a
// PriorityCache.get.
return modelState.getLastSeenExecutionEndTime() != Instant.MIN
&& executionEndTime >= modelState.getLastSeenExecutionEndTime().toEpochMilli()
return modelState.getLastSeenDataEndTime() != Instant.MIN
&& dataEndTime > modelState.getLastSeenDataEndTime().toEpochMilli()
&& modelState.getEntity().isPresent()
&& owningNode.isPresent()
&& owningNode.get().getId().equals(clusterService.localNode().getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,18 @@ public class ForecastTaskProfileRunner implements TaskProfileRunner<ForecastTask

@Override
public void getTaskProfile(ForecastTask configLevelTask, ActionListener<ForecastTaskProfile> listener) {
// return null since forecasting have no in-memory task profiles as AD
listener.onResponse(null);
// return null in other fields since forecasting have no in-memory task profiles as AD
listener
.onResponse(
new ForecastTaskProfile(
configLevelTask,
null,
null,
null,
configLevelTask == null ? null : configLevelTask.getTaskId(),
null
)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME;

import java.time.Clock;

import org.opensearch.forecast.caching.ForecastCacheProvider;
import org.opensearch.forecast.caching.ForecastPriorityCache;
import org.opensearch.forecast.indices.ForecastIndex;
Expand All @@ -32,7 +34,8 @@ public ForecastRealTimeInferencer(
ForecastColdStartWorker coldStartWorker,
ForecastSaveResultStrategy resultWriteWorker,
ForecastCacheProvider cache,
ThreadPool threadPool
ThreadPool threadPool,
Clock clock
) {
super(
modelManager,
Expand All @@ -43,7 +46,8 @@ public ForecastRealTimeInferencer(
resultWriteWorker,
cache,
threadPool,
FORECAST_THREAD_POOL_NAME
FORECAST_THREAD_POOL_NAME,
clock
);
}

Expand Down
9 changes: 6 additions & 3 deletions src/main/java/org/opensearch/forecast/model/ForecastTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,8 @@ public static ForecastTask parse(XContentParser parser, String taskId) throws IO
forecaster.getCustomResultIndexMinSize(),
forecaster.getCustomResultIndexMinAge(),
forecaster.getCustomResultIndexTTL(),
forecaster.getFlattenResultIndexMapping()
forecaster.getFlattenResultIndexMapping(),
forecaster.getLastBreakingUIChangeTime()
);
return new Builder()
.taskId(parsedTaskId)
Expand Down Expand Up @@ -375,10 +376,12 @@ public static ForecastTask parse(XContentParser parser, String taskId) throws IO
@Generated
@Override
public boolean equals(Object other) {
if (this == other)
if (this == other) {
return true;
if (other == null || getClass() != other.getClass())
}
if (other == null || getClass() != other.getClass()) {
return false;
}
ForecastTask that = (ForecastTask) other;
return super.equals(that)
&& Objects.equal(getForecaster(), that.getForecaster())
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/org/opensearch/forecast/model/Forecaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ public Forecaster(
Integer customResultIndexMinSize,
Integer customResultIndexMinAge,
Integer customResultIndexTTL,
Boolean flattenResultIndexMapping
Boolean flattenResultIndexMapping,
Instant lastBreakingUIChangeTime
) {
super(
forecasterId,
Expand Down Expand Up @@ -163,7 +164,8 @@ public Forecaster(
customResultIndexMinSize,
customResultIndexMinAge,
customResultIndexTTL,
flattenResultIndexMapping
flattenResultIndexMapping,
lastBreakingUIChangeTime
);

checkAndThrowValidationErrors(ValidationAspect.FORECASTER);
Expand Down Expand Up @@ -306,6 +308,7 @@ public static Forecaster parse(
Integer customResultIndexMinAge = null;
Integer customResultIndexTTL = null;
Boolean flattenResultIndexMapping = null;
Instant lastBreakingUIChangeTime = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -437,6 +440,9 @@ public static Forecaster parse(
case FLATTEN_RESULT_INDEX_MAPPING:
flattenResultIndexMapping = parser.booleanValue();
break;
case BREAKING_UI_CHANGE_TIME:
lastBreakingUIChangeTime = ParseUtils.toInstant(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -468,7 +474,8 @@ public static Forecaster parse(
customResultIndexMinSize,
customResultIndexMinAge,
customResultIndexTTL,
flattenResultIndexMapping
flattenResultIndexMapping,
lastBreakingUIChangeTime
);
return forecaster;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
: WriteRequest.RefreshPolicy.IMMEDIATE;
RestRequest.Method method = request.getHttpRequest().method();

if (method == RestRequest.Method.POST && forecasterId != Config.NO_ID) {
// reset detector to empty string detectorId is only meant for updating detector
forecasterId = Config.NO_ID;
}

IndexForecasterRequest indexAnomalyDetectorRequest = new IndexForecasterRequest(
forecasterId,
seqNo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ protected Config copyConfig(User user, Config config) {
config.getCustomResultIndexMinSize(),
config.getCustomResultIndexMinAge(),
config.getCustomResultIndexTTL(),
config.getFlattenResultIndexMapping()
config.getFlattenResultIndexMapping(),
breakingUIChange ? Instant.now() : config.getLastBreakingUIChangeTime()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
adColdstartQueue,
adSaveResultStrategy,
adCacheProvider,
threadPool
threadPool,
getClock()
);

ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker(
Expand Down Expand Up @@ -1230,7 +1231,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
forecastColdstartQueue,
forecastSaveResultStrategy,
forecastCacheProvider,
threadPool
threadPool,
getClock()
);

ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public ModelState<RCFModelType> get(String modelId, Config config) {
// reset every 60 intervals
return new DoorKeeper(
TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION,
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ),
clock,
TimeSeriesSettings.CACHE_DOOR_KEEPER_COUNT_THRESHOLD
);
Expand Down
Loading

0 comments on commit 519640f

Please sign in to comment.