diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go index bec6f3c2f8..6244ba152a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go @@ -25,6 +25,11 @@ import ( "google.golang.org/api/bigquery/v2" ) +const ( + httpPost string = "POST" + httpGet string = "GET" +) + func TestEndToEnd(t *testing.T) { server := newFakeBigQueryServer() defer server.Close() @@ -44,19 +49,46 @@ func TestEndToEnd(t *testing.T) { plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext()) assert.NoError(t, err) + inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) + template := flyteIdlCore.TaskTemplate{ + Type: bigqueryQueryJobTask, + Target: &flyteIdlCore.TaskTemplate_Sql{Sql: &flyteIdlCore.Sql{Statement: "SELECT 1", Dialect: flyteIdlCore.Sql_ANSI}}, + } + t.Run("SELECT 1", func(t *testing.T) { queryJobConfig := QueryJobConfig{ ProjectID: "flyte", } - inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) custom, _ := pluginUtils.MarshalObjToStruct(queryJobConfig) - template := flyteIdlCore.TaskTemplate{ - Type: bigqueryQueryJobTask, - Custom: custom, - Target: &flyteIdlCore.TaskTemplate_Sql{Sql: &flyteIdlCore.Sql{Statement: "SELECT 1", Dialect: flyteIdlCore.Sql_ANSI}}, + template.Custom = custom + + phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) + + assert.Equal(t, true, phase.Phase().IsSuccess()) + }) + + t.Run("cache job result", func(t *testing.T) { + queryJobConfig := QueryJobConfig{ + ProjectID: "cache", + } + + custom, _ := pluginUtils.MarshalObjToStruct(queryJobConfig) + template.Custom = custom + + phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) + + assert.Equal(t, true, phase.Phase().IsSuccess()) + }) + + t.Run("pending job", func(t *testing.T) { + queryJobConfig := QueryJobConfig{ + ProjectID: "pending", } + custom, _ := pluginUtils.MarshalObjToStruct(queryJobConfig) + template.Custom = custom + phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) assert.Equal(t, true, phase.Phase().IsSuccess()) @@ -65,17 +97,17 @@ func TestEndToEnd(t *testing.T) { func newFakeBigQueryServer() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - if request.URL.Path == "/projects/flyte/jobs" && request.Method == "POST" { + if request.URL.Path == "/projects/flyte/jobs" && request.Method == httpPost { writer.WriteHeader(200) - job := bigquery.Job{Status: &bigquery.JobStatus{State: "RUNNING"}} + job := bigquery.Job{Status: &bigquery.JobStatus{State: bigqueryStatusRunning}} bytes, _ := json.Marshal(job) _, _ = writer.Write(bytes) return } - if strings.HasPrefix(request.URL.Path, "/projects/flyte/jobs/") && request.Method == "GET" { + if strings.HasPrefix(request.URL.Path, "/projects/flyte/jobs/") && request.Method == httpGet { writer.WriteHeader(200) - job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}, + job := bigquery.Job{Status: &bigquery.JobStatus{State: bigqueryStatusDone}, Configuration: &bigquery.JobConfiguration{ Query: &bigquery.JobConfigurationQuery{ DestinationTable: &bigquery.TableReference{ @@ -85,6 +117,42 @@ func newFakeBigQueryServer() *httptest.Server { return } + if request.URL.Path == "/projects/cache/jobs" && request.Method == httpPost { + writer.WriteHeader(200) + job := bigquery.Job{Status: &bigquery.JobStatus{State: bigqueryStatusDone}} + bytes, _ := json.Marshal(job) + _, _ = writer.Write(bytes) + return + } + + if strings.HasPrefix(request.URL.Path, "/projects/cache/jobs/") && request.Method == httpGet { + writer.WriteHeader(200) + job := bigquery.Job{Status: &bigquery.JobStatus{State: bigqueryStatusDone}, + Configuration: &bigquery.JobConfiguration{ + Query: &bigquery.JobConfigurationQuery{ + DestinationTable: &bigquery.TableReference{ + ProjectId: "project", DatasetId: "dataset", TableId: "table"}}}} + bytes, _ := json.Marshal(job) + _, _ = writer.Write(bytes) + return + } + + if request.URL.Path == "/projects/pending/jobs" && request.Method == httpPost { + writer.WriteHeader(200) + job := bigquery.Job{Status: &bigquery.JobStatus{State: bigqueryStatusPending}} + bytes, _ := json.Marshal(job) + _, _ = writer.Write(bytes) + return + } + + if strings.HasPrefix(request.URL.Path, "/projects/pending/jobs/") && request.Method == httpGet { + writer.WriteHeader(200) + job := bigquery.Job{Status: &bigquery.JobStatus{State: bigqueryStatusDone}} + bytes, _ := json.Marshal(job) + _, _ = writer.Write(bytes) + return + } + writer.WriteHeader(500) })) } diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go index 202937c233..bc1f6df83f 100644 --- a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go @@ -33,8 +33,11 @@ import ( ) const ( - bigqueryQueryJobTask = "bigquery_query_job_task" - bigqueryConsolePath = "https://console.cloud.google.com/bigquery" + bigqueryQueryJobTask = "bigquery_query_job_task" + bigqueryConsolePath = "https://console.cloud.google.com/bigquery" + bigqueryStatusRunning = "RUNNING" + bigqueryStatusPending = "PENDING" + bigqueryStatusDone = "DONE" ) type Plugin struct { @@ -148,7 +151,22 @@ func (p Plugin) createImpl(ctx context.Context, taskCtx webapi.TaskExecutionCont return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "failed to create query job") } - resource := ResourceWrapper{Status: resp.Status} + var outputLocation string + if resp.Status != nil && resp.Status.State == bigqueryStatusDone { + getResp, err := client.Jobs.Get(job.JobReference.ProjectId, job.JobReference.JobId).Do() + + if err != nil { + err := pluginErrors.Wrapf( + pluginErrors.RuntimeFailure, + err, + "failed to get job [%s]", + formatJobReference(*job.JobReference)) + + return nil, nil, err + } + outputLocation = constructOutputLocation(ctx, getResp) + } + resource := ResourceWrapper{Status: resp.Status, OutputLocation: outputLocation} resourceMeta := ResourceMetaWrapper{ JobReference: *job.JobReference, Namespace: namespace, @@ -214,9 +232,7 @@ func (p Plugin) getImpl(ctx context.Context, taskCtx webapi.GetContext) (wrapper return nil, err } - dst := job.Configuration.Query.DestinationTable - outputLocation := fmt.Sprintf("bq://%v:%v.%v", dst.ProjectId, dst.DatasetId, dst.TableId) - + outputLocation := constructOutputLocation(ctx, job) return &ResourceWrapper{ Status: job.Status, OutputLocation: outputLocation, @@ -267,13 +283,13 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co } switch resource.Status.State { - case "PENDING": + case bigqueryStatusPending: return core.PhaseInfoQueuedWithTaskInfo(version, "Query is PENDING", taskInfo), nil - case "RUNNING": + case bigqueryStatusRunning: return core.PhaseInfoRunning(version, taskInfo), nil - case "DONE": + case bigqueryStatusDone: if resource.Status.ErrorResult != nil { return handleErrorResult( resource.Status.ErrorResult.Reason, @@ -291,6 +307,16 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.Status.State) } +func constructOutputLocation(ctx context.Context, job *bigquery.Job) string { + if job == nil || job.Configuration == nil || job.Configuration.Query == nil || job.Configuration.Query.DestinationTable == nil { + return "" + } + dst := job.Configuration.Query.DestinationTable + outputLocation := fmt.Sprintf("bq://%v:%v.%v", dst.ProjectId, dst.DatasetId, dst.TableId) + logger.Debugf(ctx, "BigQuery saves query results to [%v]", outputLocation) + return outputLocation +} + func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation string) error { taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go index d39d844184..151ab6d2d7 100644 --- a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go @@ -42,6 +42,26 @@ func TestFormatJobReference(t *testing.T) { }) } +func TestConstructOutputLocation(t *testing.T) { + job := &bigquery.Job{ + Configuration: &bigquery.JobConfiguration{ + Query: &bigquery.JobConfigurationQuery{ + DestinationTable: &bigquery.TableReference{ + ProjectId: "project", + DatasetId: "dataset", + TableId: "table", + }, + }, + }, + } + ol := constructOutputLocation(context.Background(), job) + assert.Equal(t, ol, "bq://project:dataset.table") + + job.Configuration.Query.DestinationTable = nil + ol = constructOutputLocation(context.Background(), job) + assert.Equal(t, ol, "") +} + func TestCreateTaskInfo(t *testing.T) { t.Run("create task info", func(t *testing.T) { resourceMeta := ResourceMetaWrapper{