Skip to content

Commit

Permalink
Adjust to analysis depending suffix for state doc ids
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitris-athanasiou committed Dec 12, 2019
1 parent ff284c5 commit c66bc6b
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,7 @@ public boolean persistsState() {

@Override
public String getStateDocId(String jobId) {
// The state doc id prefix is same as for regression
return jobId + "_regression_state#1";
return jobId + "_classification_state#1";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,6 @@ public void testGetStateDocId() {
Classification classification = createRandom();
assertThat(classification.persistsState(), is(true));
String randomId = randomAlphaOfLength(10);
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.junit.After;

import java.util.ArrayList;
Expand Down Expand Up @@ -96,7 +95,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
Expand Down Expand Up @@ -137,7 +136,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
Expand Down Expand Up @@ -198,7 +197,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
Expand Down Expand Up @@ -452,11 +451,7 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
}
}

private static void assertModelStatePersisted(String jobId) {
String docId = jobId + "_regression_state#1";
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(docId))
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
protected String stateDocId() {
return jobId + "_classification_state#1";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,11 @@ protected static Set<String> getTrainingRowsIds(String index) {
assertThat(trainingRowsIds.isEmpty(), is(false));
return trainingRowsIds;
}

protected static void assertModelStatePersisted(String stateDocId) {
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(stateDocId))
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.junit.After;

import java.util.Arrays;
Expand Down Expand Up @@ -82,7 +80,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
Expand Down Expand Up @@ -119,7 +117,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
Expand Down Expand Up @@ -171,7 +169,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
Expand Down Expand Up @@ -233,7 +231,7 @@ public void testStopAndRestart() throws Exception {

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
}

Expand Down Expand Up @@ -324,11 +322,7 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
return resultsObject;
}

private static void assertModelStatePersisted(String jobId) {
String docId = jobId + "_regression_state#1";
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(docId))
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
protected String stateDocId() {
return jobId + "_regression_state#1";
}
}

0 comments on commit c66bc6b

Please sign in to comment.