Skip to content

Commit

Permalink
feat(frontend&backend): Add UI support for object store customization…
Browse files Browse the repository at this point in the history
… and prefixes (#10787)

* feat: add store session info to artifact property

Signed-off-by: Humair Khan <[email protected]>

* chore: fix tests for api, artifact load, & runlist

Note that for runlist, the mock error response only returns one valid
run list with error set; the other is undefined. To support multiple
runIds, the mock error response will need to be adjusted.

Signed-off-by: Humair Khan <[email protected]>

* chore: support protocols in aws s3 endpoint config

Signed-off-by: Humair Khan <[email protected]>

* feat(ui): allow ui server to parse provider info

Signed-off-by: Humair Khan <[email protected]>

* feat(ui): parse artifact provider info in ui

Signed-off-by: Humair Khan <[email protected]>

* chore: add tests for provider info

Signed-off-by: Humair Khan <[email protected]>

* chore: update ec2 tests

...and clean up imports.

Signed-off-by: Humair Khan <[email protected]>

* chore: prettier fixes

Signed-off-by: Humair Khan <[email protected]>

---------

Signed-off-by: Humair Khan <[email protected]>
  • Loading branch information
HumairAK authored Jun 25, 2024
1 parent 991a610 commit 6723d3d
Show file tree
Hide file tree
Showing 27 changed files with 731 additions and 152 deletions.
20 changes: 20 additions & 0 deletions backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package component

import (
"context"
"encoding/json"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"

Expand Down Expand Up @@ -225,6 +227,10 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact

state := pb.Artifact_LIVE

provider, err := objectstore.ParseProviderFromPath(artifactUri)
if err != nil {
return nil, fmt.Errorf("No Provider scheme found in artifact Uri: %s", artifactUri)
}
artifact = &pb.Artifact{
TypeId: &artifactTypeId,
State: &state,
Expand All @@ -241,6 +247,20 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact
artifact.CustomProperties[k] = value
}
}

// Assume all imported artifacts will rely on execution environment for store provider session info
storeSessionInfo := objectstore.SessionInfo{
Provider: provider,
Params: map[string]string{
"fromEnv": "true",
},
}
storeSessionInfoJSON, err := json.Marshal(storeSessionInfo)
if err != nil {
return nil, err
}
storeSessionInfoStr := string(storeSessionInfoJSON)
artifact.CustomProperties["store_session_info"] = metadata.StringValue(storeSessionInfoStr)
return artifact, nil
}

Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
if err != nil {
return nil, fmt.Errorf("failed to determine schema for output %q: %w", name, err)
}
mlmdArtifact, err := opts.metadataClient.RecordArtifact(ctx, name, schema, outputArtifact, pb.Artifact_LIVE)
mlmdArtifact, err := opts.metadataClient.RecordArtifact(ctx, name, schema, outputArtifact, pb.Artifact_LIVE, opts.bucketConfig)
if err != nil {
return nil, metadataErr(err)
}
Expand Down
3 changes: 1 addition & 2 deletions backend/src/v2/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,10 @@ func InPodName() (string, error) {
}

func (c *Config) GetStoreSessionInfo(path string) (objectstore.SessionInfo, error) {
bucketConfig, err := objectstore.ParseBucketPathToConfig(path)
provider, err := objectstore.ParseProviderFromPath(path)
if err != nil {
return objectstore.SessionInfo{}, err
}
provider := strings.TrimSuffix(bucketConfig.Scheme, "://")
bucketProviders, err := c.getBucketProviders()
if err != nil {
return objectstore.SessionInfo{}, err
Expand Down
53 changes: 34 additions & 19 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package metadata

import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -90,7 +92,7 @@ type ClientInterface interface {
GetArtifactName(ctx context.Context, artifactId int64) (string, error)
GetArtifacts(ctx context.Context, ids []int64) ([]*pb.Artifact, error)
GetOutputArtifactsByExecutionId(ctx context.Context, executionId int64) (map[string]*OutputArtifact, error)
RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error)
RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State, bucketConfig *objectstore.Config) (*OutputArtifact, error)
GetOrInsertArtifactType(ctx context.Context, schema string) (typeID int64, err error)
FindMatchedArtifact(ctx context.Context, artifactToMatch *pb.Artifact, pipelineContextId int64) (matchedArtifact *pb.Artifact, err error)
}
Expand Down Expand Up @@ -301,11 +303,11 @@ func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace
}
glog.Infof("Pipeline Context: %+v", pipelineContext)
metadata := map[string]*pb.Value{
keyNamespace: stringValue(namespace),
keyResourceName: stringValue(runResource),
keyNamespace: StringValue(namespace),
keyResourceName: StringValue(runResource),
// pipeline root of this run
keyPipelineRoot: stringValue(GenerateOutputURI(pipelineRoot, []string{pipelineName, runID}, true)),
keyStoreSessionInfo: stringValue(storeSessionInfo),
keyPipelineRoot: StringValue(GenerateOutputURI(pipelineRoot, []string{pipelineName, runID}, true)),
keyStoreSessionInfo: StringValue(storeSessionInfo),
}
runContext, err := c.getOrInsertContext(ctx, runID, pipelineRunContextType, metadata)
glog.Infof("Pipeline Run Context: %+v", runContext)
Expand Down Expand Up @@ -401,7 +403,7 @@ func (c *Client) getExecutionTypeID(ctx context.Context, executionType *pb.Execu
return eType.GetTypeId(), nil
}

func stringValue(s string) *pb.Value {
func StringValue(s string) *pb.Value {
return &pb.Value{Value: &pb.Value_StringValue{StringValue: s}}
}

Expand Down Expand Up @@ -531,8 +533,8 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
TypeId: &typeID,
CustomProperties: map[string]*pb.Value{
// We should support overriding display name in the future, for now it defaults to task name.
keyDisplayName: stringValue(config.TaskName),
keyTaskName: stringValue(config.TaskName),
keyDisplayName: StringValue(config.TaskName),
keyTaskName: StringValue(config.TaskName),
},
}
if config.Name != "" {
Expand All @@ -555,15 +557,15 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
e.CustomProperties[keyIterationCount] = intValue(int64(*config.IterationCount))
}
if config.ExecutionType == ContainerExecutionTypeName {
e.CustomProperties[keyPodName] = stringValue(config.PodName)
e.CustomProperties[keyPodUID] = stringValue(config.PodUID)
e.CustomProperties[keyNamespace] = stringValue(config.Namespace)
e.CustomProperties[keyImage] = stringValue(config.Image)
e.CustomProperties[keyPodName] = StringValue(config.PodName)
e.CustomProperties[keyPodUID] = StringValue(config.PodUID)
e.CustomProperties[keyNamespace] = StringValue(config.Namespace)
e.CustomProperties[keyImage] = StringValue(config.Image)
if config.CachedMLMDExecutionID != "" {
e.CustomProperties[keyCachedExecutionID] = stringValue(config.CachedMLMDExecutionID)
e.CustomProperties[keyCachedExecutionID] = StringValue(config.CachedMLMDExecutionID)
}
if config.FingerPrint != "" {
e.CustomProperties[keyCacheFingerPrint] = stringValue(config.FingerPrint)
e.CustomProperties[keyCacheFingerPrint] = StringValue(config.FingerPrint)
}
}
if config.InputParameters != nil {
Expand Down Expand Up @@ -623,9 +625,9 @@ func (c *Client) PrePublishExecution(ctx context.Context, execution *Execution,
if e.CustomProperties == nil {
e.CustomProperties = make(map[string]*pb.Value)
}
e.CustomProperties[keyPodName] = stringValue(config.PodName)
e.CustomProperties[keyPodUID] = stringValue(config.PodUID)
e.CustomProperties[keyNamespace] = stringValue(config.Namespace)
e.CustomProperties[keyPodName] = StringValue(config.PodName)
e.CustomProperties[keyPodUID] = StringValue(config.PodUID)
e.CustomProperties[keyNamespace] = StringValue(config.Namespace)
e.LastKnownState = pb.Execution_RUNNING.Enum()

_, err := c.svc.PutExecution(ctx, &pb.PutExecutionRequest{
Expand Down Expand Up @@ -889,7 +891,7 @@ func SchemaToArtifactType(schema string) (*pb.ArtifactType, error) {
}

// RecordArtifact ...
func (c *Client) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error) {
func (c *Client) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State, bucketConfig *objectstore.Config) (*OutputArtifact, error) {
artifact, err := toMLMDArtifact(runtimeArtifact)
if err != nil {
return nil, err
Expand All @@ -911,7 +913,20 @@ func (c *Client) RecordArtifact(ctx context.Context, outputName, schema string,
}
if _, ok := artifact.CustomProperties["display_name"]; !ok {
// display name default value
artifact.CustomProperties["display_name"] = stringValue(outputName)
artifact.CustomProperties["display_name"] = StringValue(outputName)
}

// An artifact can belong to an external store specified via kfp-launcher
// or via the executor environment (e.g. IRSA).
// This allows us to easily identify where to locate the artifact, both
// in the user executor environment and in the KFP UI.
if _, ok := artifact.CustomProperties["store_session_info"]; !ok {
storeSessionInfoJSON, err1 := json.Marshal(bucketConfig.SessionInfo)
if err1 != nil {
return nil, err1
}
storeSessionInfoStr := string(storeSessionInfoJSON)
artifact.CustomProperties["store_session_info"] = StringValue(storeSessionInfoStr)
}

res, err := c.svc.PutArtifacts(ctx, &pb.PutArtifactsRequest{
Expand Down
3 changes: 2 additions & 1 deletion backend/src/v2/metadata/client_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package metadata

import (
"context"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"
Expand Down Expand Up @@ -82,7 +83,7 @@ func (c *FakeClient) GetOutputArtifactsByExecutionId(ctx context.Context, execut
return nil, nil
}

func (c *FakeClient) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error) {
func (c *FakeClient) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State, bucketConfig *objectstore.Config) (*OutputArtifact, error) {
return nil, nil
}

Expand Down
10 changes: 10 additions & 0 deletions backend/src/v2/objectstore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ func ParseBucketConfigForArtifactURI(uri string) (*Config, error) {
}, nil
}

// ParseProviderFromPath parses uri into a bucket config and returns that
// config's scheme with any trailing "://" removed. The result is used as
// the Provider string. Any parse error is returned unchanged alongside an
// empty provider.
func ParseProviderFromPath(uri string) (string, error) {
	cfg, parseErr := ParseBucketPathToConfig(uri)
	if parseErr != nil {
		return "", parseErr
	}
	provider := strings.TrimSuffix(cfg.Scheme, "://")
	return provider, nil
}

func MinioDefaultEndpoint() string {
// Discover minio-service in the same namespace by env var.
// https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables
Expand Down
6 changes: 4 additions & 2 deletions backend/src/v2/objectstore/object_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"k8s.io/client-go/kubernetes"
"os"
"path/filepath"
"regexp"
"strings"
)

Expand Down Expand Up @@ -261,12 +262,13 @@ func createS3BucketSession(ctx context.Context, namespace string, sessionInfo *S

// AWS Specific:
// Path-style S3 endpoints, which are commonly used, may fall into either of two subdomains:
// 1) s3.amazonaws.com
// 1) [https://]s3.amazonaws.com
// 2) s3.<AWS Region>.amazonaws.com
// for (1) the endpoint is not required, thus we skip it, otherwise the writer will fail to close due to region mismatch.
// https://aws.amazon.com/blogs/infrastructure-and-automation/best-practices-for-using-amazon-s3-endpoints-in-aws-cloudformation-templates/
// https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
if strings.ToLower(params.Endpoint) != "s3.amazonaws.com" {
awsEndpoint, _ := regexp.MatchString(`^(https://)?s3.amazonaws.com`, strings.ToLower(params.Endpoint))
if !awsEndpoint {
config.Endpoint = aws.String(params.Endpoint)
}

Expand Down
16 changes: 8 additions & 8 deletions frontend/server/aws-helper.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import fetch from 'node-fetch';
import { awsInstanceProfileCredentials, isS3Endpoint } from './aws-helper';
import { awsInstanceProfileCredentials, isAWSS3Endpoint } from './aws-helper';

// mock node-fetch module
jest.mock('node-fetch');
Expand Down Expand Up @@ -107,30 +107,30 @@ describe('awsInstanceProfileCredentials', () => {

describe('isS3Endpoint', () => {
it('checks a valid s3 endpoint', () => {
expect(isS3Endpoint('s3.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('s3.amazonaws.com')).toBe(true);
});

it('checks a valid s3 regional endpoint', () => {
expect(isS3Endpoint('s3.dualstack.us-east-1.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('s3.dualstack.us-east-1.amazonaws.com')).toBe(true);
});

it('checks a valid s3 cn endpoint', () => {
expect(isS3Endpoint('s3.cn-north-1.amazonaws.com.cn')).toBe(true);
expect(isAWSS3Endpoint('s3.cn-north-1.amazonaws.com.cn')).toBe(true);
});

it('checks a valid s3 fips GovCloud endpoint', () => {
expect(isS3Endpoint('s3-fips.us-gov-west-1.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('s3-fips.us-gov-west-1.amazonaws.com')).toBe(true);
});

it('checks a valid s3 PrivateLink endpoint', () => {
expect(isS3Endpoint('vpce-1a2b3c4d-5e6f.s3.us-east-1.vpce.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('vpce-1a2b3c4d-5e6f.s3.us-east-1.vpce.amazonaws.com')).toBe(true);
});

it('checks an invalid s3 endpoint', () => {
expect(isS3Endpoint('amazonaws.com')).toBe(false);
expect(isAWSS3Endpoint('amazonaws.com')).toBe(false);
});

it('checks non-s3 endpoint', () => {
expect(isS3Endpoint('minio.kubeflow')).toBe(false);
expect(isAWSS3Endpoint('minio.kubeflow')).toBe(false);
});
});
2 changes: 1 addition & 1 deletion frontend/server/aws-helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async function getIAMInstanceProfile(): Promise<string | undefined> {
*
* @param endpoint minio endpoint to check.
*/
export function isS3Endpoint(endpoint: string = ''): boolean {
export function isAWSS3Endpoint(endpoint: string = ''): boolean {
return !!endpoint.match(/s3.{0,}\.amazonaws\.com\.?.{0,}/i);
}

Expand Down
Loading

0 comments on commit 6723d3d

Please sign in to comment.