diff --git a/webui/react/src/components/CreateExperimentModal.tsx b/webui/react/src/components/CreateExperimentModal.tsx index 0ee31ceaaa16..acb03a79f599 100644 --- a/webui/react/src/components/CreateExperimentModal.tsx +++ b/webui/react/src/components/CreateExperimentModal.tsx @@ -17,6 +17,7 @@ interface Props { error?: string; onVisibleChange: (visible: boolean) => void; onConfigChange: (config: string) => void; + onCancel?: () => void; } const CreateExperimentModal: React.FC = ( @@ -48,6 +49,7 @@ const CreateExperimentModal: React.FC = ( }; const handleCancel = (): void => { + props.onCancel && props.onCancel(); onVisibleChange(false); }; return { - const maxLength = experimentConfig.searcher.max_length; +const getLengthFromStepCount = (config: RawJson, stepCount: number): [string, number] => { + // provide backward compat for step count + const batchesPerStep = config.batches_per_step || DEFAULT_BATCHES_PER_STEP; + return [ 'batches', stepCount * batchesPerStep ]; +}; + +const getTrialLength = (config: RawJson): [string, number] => { + const maxLength = config.searcher.max_length; if (!maxLength) { - // provide backward compat for max_steps - let maxSteps = experimentConfig.searcher.max_steps; - maxSteps = parseInt(maxSteps); - const batchesPerStep = experimentConfig.batches_per_step || DEFAULT_BATCHES_PER_STEP; - return [ 'batches', maxSteps * batchesPerStep ]; + return getLengthFromStepCount(config, config.searcher.max_steps); } return Object.entries(maxLength)[0] as [string, number]; }; @@ -48,6 +50,23 @@ const setTrialLength = (experimentConfig: RawJson, length: number): void => { delete experimentConfig.searcher.max_steps; }; +// Add opportunistic backward compatibility to old configs. +const upgradeConfig = (config: RawJson): void => { + if (config.searcher.max_steps) { + const [ key, count ] = getLengthFromStepCount(config, config.searcher.max_steps); + config.searcher.max_length = { [key]: count }; + delete config.searcher.max_steps; + } + if (typeof config.min_validation_period === 'number') { + const [ key, count ] = getLengthFromStepCount(config, config.min_validation_period); + config.searcher.min_validation_period = { [key]: count }; + } + if (typeof config.min_checkpoint_period === 'number') { + const [ key, count ] = getLengthFromStepCount(config, config.min_checkpoint_period); + config.searcher.min_checkpoint_period = { [key]: count }; + } +}; + const trialContinueConfig = (experimentConfig: RawJson, trialHparams: Record, trialId: number): RawJson => { @@ -95,26 +114,59 @@ const TrialDetailsComp: React.FC = () => { } }, [ ]); - const handleContFormVisibile = useCallback((isVisible: boolean) => { - return () => setContFormVisible(isVisible); - }, []); + const setFreshContinueConfig = useCallback(() => { + if (!experiment?.configRaw || !hparams) return; + const config = clone(experiment.configRaw); - const handleFormFinish = useCallback(({ description, maxLength }) => { - setContDescription(description); - setContMaxLength(parseInt(maxLength)); - }, [ ]); + const newDescription = `Continuation of trial ${trialId}, experiment` + + ` ${experimentId} (${experiment.configRaw.description})`; + setContDescription(newDescription); + const maxLength = getTrialLength(experiment.configRaw)[1]; + if (maxLength !== undefined) setContMaxLength(maxLength); - const handleFormCreate = useCallback(async () => { - if (!experiment || !hparams || !trialId) return; + config.description = newDescription; + if (maxLength) setTrialLength(config, maxLength); + upgradeConfig(config); + const newConfig = trialContinueConfig(config, hparams, trialId); + setContModalConfig(yaml.safeDump(newConfig)); + }, [ hparams, trialId, experiment?.configRaw, experimentId ]); + + const handleContFormOnCancel = useCallback(() => { + setContFormVisible(false); + setFreshContinueConfig(); + form.resetFields(); + }, [ setFreshContinueConfig, form ]); + + const updateStatesFromForm = useCallback(() => { + if (!hparams || !trialId) return; const formValues = form.getFieldsValue(); - handleFormFinish(formValues); try { - const expConfig = clone(experiment.configRaw); + const expConfig = yaml.safeLoad(contModalConfig) as RawJson; expConfig.description = formValues.description; setTrialLength(expConfig, parseInt(formValues.maxLength)); + const updateConfig = trialContinueConfig(expConfig, hparams, trialId); + setContModalConfig(yaml.safeDump(updateConfig)); + return updateConfig; + } catch (e) { + handleError({ + error: e, + message: 'Failed to parse experiment config', + publicMessage: 'Please check the experiment config. \ +If the problem persists please contact support.', + publicSubject: 'Failed to parse experiment config', + silent: false, + type: ErrorType.Deprecated, + }); + } + }, [ contModalConfig, form, hparams, trialId ]); + + const handleFormCreate = useCallback(async () => { + if (!experimentId) return; + const updatedConfig = updateStatesFromForm(); + try { const newExperiementId = await forkExperiment({ - experimentConfig: JSON.stringify(trialContinueConfig(expConfig, hparams, trialId)), - parentId: experiment.id, + experimentConfig: JSON.stringify(updatedConfig), + parentId: experimentId, }); routeAll(`/det/experiments/${newExperiementId}`); } catch (e) { @@ -125,14 +177,14 @@ const TrialDetailsComp: React.FC = () => { If the problem persists please contact support.', publicSubject: 'Failed to continue trial', silent: false, - type: ErrorType.Api, + type: ErrorType.Deprecated, }); - setContError(e.response.data.message); + setContError(e.response?.data?.message || e.message); setContModalVisible(true); } finally { setContFormVisible(false); } - }, [ handleFormFinish, form, experiment, hparams, trialId ]); + }, [ experimentId, updateStatesFromForm ]); const onConfigChange = useCallback( (config: string) => { setContModalConfig(config); @@ -140,13 +192,8 @@ If the problem persists please contact support.', }, []); useEffect(() => { - if (!experiment?.configRaw || !hparams) return; - const config = clone(experiment.configRaw); - if (contDescription) config.description = contDescription; - if (contMaxLength) setTrialLength(config, contMaxLength); - const newConfig = trialContinueConfig(config, hparams, trialId); try { - setContModalConfig(yaml.safeDump(newConfig)); + setFreshContinueConfig(); } catch (e) { handleError({ error: e, @@ -155,26 +202,21 @@ If the problem persists please contact support.', }); setContModalConfig('failed to load experiment config'); } - }, [ contDescription, contMaxLength, trialId, experiment?.configRaw, hparams ]); + }, [ setFreshContinueConfig ]); useEffect(() => { if (experimentId === undefined) return; getExperimentDetails({ id:experimentId }) .then(experiment => { setExperiment(experiment); - const newDescription = `Continuation of trial ${trialId}, experiment` + - ` ${experiment.id} (${experiment.config.description})`; - setContDescription(newDescription); - const maxLength = getTrialLength(experiment.configRaw)[1]; - if (maxLength !== undefined) setContMaxLength(maxLength); }); }, [ experimentId, trialId ]); const handleEditContConfig = useCallback(() => { - handleFormFinish(form.getFieldsValue()); + updateStatesFromForm(); setContFormVisible(false); setContModalVisible(true); - }, [ form, handleFormFinish ]); + }, [ updateStatesFromForm ]); if (isNaN(trialId)) { return ( @@ -226,6 +268,7 @@ If the problem persists please contact support.', parentId={experiment.id} title={`Continue Trial ${trialId}`} visible={contModalVisible} + onCancel={setFreshContinueConfig} onConfigChange={onConfigChange} onVisibleChange={setContModalVisible} /> @@ -239,7 +282,7 @@ If the problem persists please contact support.', }} title={`Continue Trial ${trialId} of Experiment ${experimentId}`} visible={contFormVisible} - onCancel={handleContFormVisibile(false)} + onCancel={handleContFormOnCancel} >