diff --git a/master/internal/api_runs.go b/master/internal/api_runs.go index af0473b6ead..adced3a7bfb 100644 --- a/master/internal/api_runs.go +++ b/master/internal/api_runs.go @@ -233,6 +233,17 @@ func sortRuns(sortString *string, runQuery *bun.SelectQuery) error { hpQuery := strings.Join(hp, "->") queryArgs = append(queryArgs, bun.Safe(sortDirection)) runQuery.OrderExpr(fmt.Sprintf(`r.hparams->%s ?`, hpQuery), queryArgs...) + case strings.HasPrefix(paramDetail[0], "metadata."): + param := strings.ReplaceAll(paramDetail[0], "'", "") + mdt := strings.Split(strings.TrimPrefix(param, "metadata."), ".") + var queryArgs []interface{} + for i := 0; i < len(mdt); i++ { + queryArgs = append(queryArgs, mdt[i]) + mdt[i] = "?" + } + mdtQuery := strings.Join(mdt, "->") + queryArgs = append(queryArgs, bun.Safe(sortDirection)) + runQuery.OrderExpr(fmt.Sprintf(`rm.metadata->%s ?`, mdtQuery), queryArgs...) case strings.Contains(paramDetail[0], "."): metricGroup, metricName, metricQualifier, err := parseMetricsName(paramDetail[0]) if err != nil { diff --git a/master/internal/api_runs_intg_test.go b/master/internal/api_runs_intg_test.go index 97acd9bef7a..40355741b44 100644 --- a/master/internal/api_runs_intg_test.go +++ b/master/internal/api_runs_intg_test.go @@ -193,35 +193,81 @@ func TestSearchRunsSort(t *testing.T) { HParams: hyperparameters2, }, task2.TaskID)) - // Sort by start time - resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{ - ProjectId: req.ProjectId, - Sort: ptrs.Ptr("startTime=asc"), - }) - + // Get runs in project + resp, err = api.SearchRuns(ctx, req) require.NoError(t, err) - require.Equal(t, int32(exp.ID), resp.Runs[0].Experiment.Id) - require.Equal(t, int32(exp2.ID), resp.Runs[1].Experiment.Id) + require.Len(t, resp.Runs, 2) - // Sort by hyperparameter - resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{ - ProjectId: req.ProjectId, - Sort: ptrs.Ptr("hp.global_batch_size=desc"), + // add metadata + rawMetadata := map[string]any{ + "number_key": 1, + "nested": map[string]any{ + "number_key": 1, + }, + } + metadata := newProtoStruct(t, rawMetadata) + _, err = api.PostRunMetadata(ctx, &apiv1.PostRunMetadataRequest{ + RunId: resp.Runs[0].Id, + Metadata: metadata, }) - require.NoError(t, err) - require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id) - require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id) - // Sort by nested hyperparameter - resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{ - ProjectId: req.ProjectId, - Sort: ptrs.Ptr("hp.test1.test2=desc"), + rawMetadata = map[string]any{ + "number_key": 2, + "nested": map[string]any{ + "number_key": 2, + }, + } + metadata = newProtoStruct(t, rawMetadata) + _, err = api.PostRunMetadata(ctx, &apiv1.PostRunMetadataRequest{ + RunId: resp.Runs[1].Id, + Metadata: metadata, }) - require.NoError(t, err) - require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id) - require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id) + + tests := map[string]struct { + sortBy string + reverse bool + }{ + "StartTime": { + sortBy: "startTime=asc", + reverse: false, + }, + "Hyperparameter": { + sortBy: "hp.global_batch_size=desc", + reverse: true, + }, + "HyperparameterNested": { + sortBy: "hp.test1.test2=desc", + reverse: true, + }, + "Metadata": { + sortBy: "metadata.number_key=desc", + reverse: true, + }, + "MetadataNested": { + sortBy: "metadata.nested.number_key=desc", + reverse: true, + }, + } + + for testCase, testVars := range tests { + t.Run(testCase, func(t *testing.T) { + resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{ + ProjectId: &projectID, + Sort: ptrs.Ptr(testVars.sortBy), + }) + + require.NoError(t, err) + if testVars.reverse { + require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id) + require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id) + } else { + require.Equal(t, int32(exp.ID), resp.Runs[0].Experiment.Id) + require.Equal(t, int32(exp2.ID), resp.Runs[1].Experiment.Id) + } + }) + } } func TestSearchRunsFilter(t *testing.T) {