diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index b95a689811..4dd2cff4c2 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -217,24 +217,27 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo func addConfig(sparkConfig map[string]string, key string, value string) { if strings.ToLower(strings.TrimSpace(value)) != "true" { + sparkConfig[key] = value return } matches := featureRegex.FindAllStringSubmatch(key, -1) if len(matches) == 0 || len(matches[0]) == 0 { + sparkConfig[key] = value return } featureName := matches[0][len(matches[0])-1] + // Use the first matching feature in-case of duplicates. for _, feature := range GetSparkConfig().Features { if feature.Name == featureName { for k, v := range feature.SparkConfig { sparkConfig[k] = v } - break + return } - } + sparkConfig[key] = value } // Convert SparkJob ApplicationType to Operator CRD ApplicationType diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index fc3586c11b..1cebcc6cab 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -45,6 +45,7 @@ var ( "spark.executor.memory": "500M", "spark.flyte.feature1.enabled": "true", "spark.lyft.feature2.enabled": "true", + "spark.lyft.feature3.enabled": "true", } dummyEnvVars = []*core.KeyValuePair{ @@ -391,7 +392,7 @@ func TestBuildResourceSpark(t *testing.T) { for confKey, confVal := range dummySparkConf { exists := false - if featureRegex.MatchString(confKey) { + if featureRegex.MatchString(confKey) && confKey != "spark.lyft.feature3.enabled" { match := featureRegex.FindAllStringSubmatch(confKey, -1) feature := match[0][len(match[0])-1] assert.True(t, feature == "feature1" || feature == "feature2") @@ -417,6 +418,7 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.driver.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.driver.limit.cores"]) assert.Equal(t, dummySparkConf["spark.executor.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.executor.limit.cores"]) assert.Greater(t, len(sparkApp.Spec.SparkConf["spark.kubernetes.driverEnv.FLYTE_START_TIME"]), 1) + assert.Equal(t, dummySparkConf["spark.lyft.feature3.enabled"], sparkApp.Spec.SparkConf["spark.lyft.feature3.enabled"]) assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)