Skip to content

Commit

Permalink
[ML] DF Regression: fix custom results_field and prediction_field_nam…
Browse files Browse the repository at this point in the history
…e not considered in eval config (#48599) (#48714)

* consider results_field and prediction_field_name in eval congif

* temp typescript fix for evaluatePanel not being used yet

* fix typescript error. clarify comment

* Disable create button if dependentVariable not set

* ensure advancedEditor json valid at every change

* update reducer test
  • Loading branch information
alvarezmelissa87 authored Oct 20, 2019
1 parent a614e89 commit d218e77
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ interface RegressionAnalysis {
regression: {
dependent_variable: string;
training_percent?: number;
prediction_field_name?: string;
};
}

Expand Down Expand Up @@ -81,6 +82,15 @@ export const getDependentVar = (analysis: AnalysisConfig) => {
return depVar;
};

export const getPredictionFieldName = (analysis: AnalysisConfig) => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
predictionFieldName = analysis.regression.prediction_field_name;
}
return predictionFieldName;
};

export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
Expand Down Expand Up @@ -189,26 +199,30 @@ export const loadEvalData = async ({
isTraining,
index,
dependentVariable,
resultsField,
predictionFieldName,
}: {
isTraining: boolean;
index: string;
dependentVariable: string;
resultsField: string;
predictionFieldName?: string;
}) => {
const results: LoadEvaluateResult = { success: false, eval: null, error: null };
const defaultPredictionField = `${dependentVariable}_prediction`;
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;

const query = { term: { [`${resultsField}.is_training`]: { value: isTraining } } };

const config = {
index,
query: {
term: {
'ml.is_training': {
value: isTraining,
},
},
},
query,
evaluation: {
regression: {
actual_field: dependentVariable,
predicted_field: `ml.${dependentVariable}_prediction`,
predicted_field: predictedField,
metrics: {
r_squared: {},
mean_squared_error: {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
export {
getAnalysisType,
getDependentVar,
getPredictionFieldName,
isOutlierAnalysis,
refreshAnalyticsList$,
useRefreshAnalyticsList,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,15 @@ export const EvaluatePanel: FC<Props> = ({ jobId, index, dependentVariable }) =>
const loadData = async () => {
setIsLoadingGeneralization(true);
setIsLoadingTraining(true);

const genErrorEval = await loadEvalData({ isTraining: false, index, dependentVariable });
// TODO: resultsField and predictionFieldName will need to be properly passed to this function
// once the results view is in use.
const genErrorEval = await loadEvalData({
isTraining: false,
index,
dependentVariable,
resultsField: 'ml',
predictionFieldName: undefined,
});

if (genErrorEval.success === true && genErrorEval.eval) {
const { meanSquaredError, rSquared } = getValuesFromResponse(genErrorEval.eval);
Expand All @@ -58,8 +65,15 @@ export const EvaluatePanel: FC<Props> = ({ jobId, index, dependentVariable }) =>
error: genErrorEval.error,
});
}

const trainingErrorEval = await loadEvalData({ isTraining: true, index, dependentVariable });
// TODO: resultsField and predictionFieldName will need to be properly passed to this function
// once the results view is in use.
const trainingErrorEval = await loadEvalData({
isTraining: true,
index,
dependentVariable,
resultsField: 'ml',
predictionFieldName: undefined,
});

if (trainingErrorEval.success === true && trainingErrorEval.eval) {
const { meanSquaredError, rSquared } = getValuesFromResponse(trainingErrorEval.eval);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ import { DataFrameAnalyticsListRow } from './common';
import { ExpandedRowDetailsPane, SectionConfig } from './expanded_row_details_pane';
import { ExpandedRowJsonPane } from './expanded_row_json_pane';
import { ProgressBar } from './progress_bar';
import { getDependentVar, getValuesFromResponse, loadEvalData, Eval } from '../../../../common';
import {
getDependentVar,
getPredictionFieldName,
getValuesFromResponse,
loadEvalData,
Eval,
} from '../../../../common';
import { isCompletedAnalyticsJob } from './common';
import { isRegressionAnalysis } from '../../../../common/analytics';
// import { ExpandedRowMessagesPane } from './expanded_row_messages_pane';
Expand Down Expand Up @@ -60,6 +66,9 @@ export const ExpandedRow: FC<Props> = ({ item }) => {
const [isLoadingGeneralization, setIsLoadingGeneralization] = useState<boolean>(false);
const index = idx(item, _ => _.config.dest.index) as string;
const dependentVariable = getDependentVar(item.config.analysis);
const predictionFieldName = getPredictionFieldName(item.config.analysis);
// default is 'ml'
const resultsField = item.config.dest.results_field;
const jobIsCompleted = isCompletedAnalyticsJob(item.stats);
const isRegressionJob = isRegressionAnalysis(item.config.analysis);

Expand All @@ -71,6 +80,8 @@ export const ExpandedRow: FC<Props> = ({ item }) => {
isTraining: false,
index,
dependentVariable,
resultsField,
predictionFieldName,
});

if (genErrorEval.success === true && genErrorEval.eval) {
Expand All @@ -94,6 +105,8 @@ export const ExpandedRow: FC<Props> = ({ item }) => {
isTraining: true,
index,
dependentVariable,
resultsField,
predictionFieldName,
});

if (trainingErrorEval.success === true && trainingErrorEval.eval) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const getMockState = (index: SourceIndex) =>
jobConfig: {
source: { index },
dest: { index: 'the-destination-index' },
analysis: {},
},
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import { validateIndexPattern } from 'ui/index_patterns';
import { isValidIndexName } from '../../../../../../common/util/es_utils';

import { Action, ACTION } from './actions';
import { getInitialState, getJobConfigFromFormState, State } from './state';
import { getInitialState, getJobConfigFromFormState, State, JOB_TYPES } from './state';
import { isJobIdValid } from '../../../../../../common/util/job_utils';
import { maxLengthValidator } from '../../../../../../common/util/validators';
import { JOB_ID_MAX_LENGTH } from '../../../../../../common/constants/validation';
import { getDependentVar, isRegressionAnalysis } from '../../../../common/analytics';

const getSourceIndexString = (state: State) => {
const { jobConfig } = state;
Expand All @@ -34,7 +35,7 @@ const getSourceIndexString = (state: State) => {
};

export const validateAdvancedEditor = (state: State): State => {
const { jobIdEmpty, jobIdValid, jobIdExists, createIndexPattern } = state.form;
const { jobIdEmpty, jobIdValid, jobIdExists, jobType, createIndexPattern } = state.form;
const { jobConfig } = state;

state.advancedEditorMessages = [];
Expand Down Expand Up @@ -64,6 +65,12 @@ export const validateAdvancedEditor = (state: State): State => {
name => destinationIndexName === name
);

let dependentVariableEmpty = false;
if (isRegressionAnalysis(jobConfig.analysis)) {
const dependentVariableName = getDependentVar(jobConfig.analysis) || '';
dependentVariableEmpty = jobType === JOB_TYPES.REGRESSION && dependentVariableName === '';
}

if (sourceIndexNameEmpty) {
state.advancedEditorMessages.push({
error: i18n.translate(
Expand Down Expand Up @@ -108,6 +115,18 @@ export const validateAdvancedEditor = (state: State): State => {
});
}

if (dependentVariableEmpty) {
state.advancedEditorMessages.push({
error: i18n.translate(
'xpack.ml.dataframe.analytics.create.advancedEditorMessage.dependentVariableEmpty',
{
defaultMessage: 'The dependent variable field must not be empty.',
}
),
message: '',
});
}

state.isValid =
!jobIdEmpty &&
jobIdValid &&
Expand All @@ -116,6 +135,7 @@ export const validateAdvancedEditor = (state: State): State => {
sourceIndexNameValid &&
!destinationIndexNameEmpty &&
destinationIndexNameValid &&
!dependentVariableEmpty &&
(!destinationIndexPatternTitleExists || !createIndexPattern);

return state;
Expand All @@ -126,14 +146,18 @@ const validateForm = (state: State): State => {
jobIdEmpty,
jobIdValid,
jobIdExists,
jobType,
sourceIndexNameEmpty,
sourceIndexNameValid,
destinationIndexNameEmpty,
destinationIndexNameValid,
destinationIndexPatternTitleExists,
createIndexPattern,
dependentVariable,
} = state.form;

const dependentVariableEmpty = jobType === JOB_TYPES.REGRESSION && dependentVariable === '';

state.isValid =
!jobIdEmpty &&
jobIdValid &&
Expand All @@ -142,6 +166,7 @@ const validateForm = (state: State): State => {
sourceIndexNameValid &&
!destinationIndexNameEmpty &&
destinationIndexNameValid &&
!dependentVariableEmpty &&
(!destinationIndexPatternTitleExists || !createIndexPattern);

return state;
Expand Down

0 comments on commit d218e77

Please sign in to comment.