diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index dd8798497a2..19c662699e3 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -343,10 +343,15 @@ def _update_generation_strategy( "should be saved before generation strategy." ) + curr_index = ( + generation_strategy.current_node_name + if generation_strategy.is_node_based + else generation_strategy.current_step_index + ) with session_scope() as session: session.query(gs_sqa_class).filter_by(id=gs_id).update( { - "curr_index": generation_strategy.current_step_index, + "curr_index": curr_index, "experiment_id": experiment_id, } )