Skip to content

Commit

Permalink
Unify vault token (apache#116)
Browse files Browse the repository at this point in the history
* Vault variables Unified

* Unified Token
  • Loading branch information
Gschiavon authored Dec 27, 2017
1 parent e550a59 commit 9d2a96b
Show file tree
Hide file tree
Showing 14 changed files with 123 additions and 244 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 2.2.0.3 (upcoming)

* Unify Vault variables

## 2.2.0.2 (December 26, 2017)

* Added mesos constraints management to spark driver
Expand Down
52 changes: 2 additions & 50 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -709,18 +709,9 @@ object SparkSubmit extends CommandLineUtils {
case _ => None
}

val vaultProtocol = args.sparkProperties.get("spark.secret.vault.protocol")
val vaultHost = args.sparkProperties.get("spark.secret.vault.hosts")
val vaultPort = args.sparkProperties.get("spark.secret.vault.port")

val vaultUrlParams = (vaultProtocol, vaultHost, vaultPort)
val vaultUrl = buildVaultUrl(vaultUrlParams)
lazy val vaultToken = getToken(tempToken, roleSecret, vaultUrl)

val (principal, keytab) =
if (vaultUrl.nonEmpty && vaultToken.isDefined) {
val environment = ConfigSecurity.prepareEnvironment(
Option (vaultToken.get), Option(vaultUrl))
if (ConfigSecurity.vaultURI.isDefined) {
val environment = ConfigSecurity.prepareEnvironment
val principal = environment.getOrElse("principal", args.principal)
val keytab = environment.getOrElse("keytabPath", args.keytab)

Expand All @@ -736,45 +727,6 @@ object SparkSubmit extends CommandLineUtils {
(childArgs, childClasspath, sysProps, childMainClass, principal, keytab)
}

/**
*
* @param tempToken Temporal token, either Property one or Environment one
* @param roleSecret Role and Secret ID, either Property one or Environment one
* @param vaultUrl a Vault Url protocol://vaultHost:vaultPort
* @return An option of a token
*/
private def getToken(tempToken: Option[String],
roleSecret: Option[(String, String)],
vaultUrl: String): Option[String] = {

(tempToken, roleSecret) match {
case (Some(tempToken), _) => Some(VaultHelper.getRealToken(vaultUrl, tempToken))
case (_, Some((role, secret))) =>
Some(VaultHelper.getTokenFromAppRole(vaultUrl, role, secret))
case _ => None
}
}

/**
*
* @param vaultUrlParams Is composed of Vault Protocol,
* Vault Host and Vault Port
* @return a Vault Url protocol://vaultHost:vaultPort
*/
private def buildVaultUrl(vaultUrlParams: (Option[String],
Option[String],
Option[String])): String = {

val vaultUrl = vaultUrlParams match {
case (Some(protocol), Some(hosts), Some(port)) =>
s"${protocol}://${
hosts.split(",")
.map(host => s"$host:${port}").mkString(",")}"
case _ => ""
}
vaultUrl
}

/**
* Run the main method of the child class using the provided launch environment.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ object HistoryServer extends Logging {

def main(argStrings: Array[String]): Unit = {
Utils.initDaemon(log)
ConfigSecurity.prepareEnvironment()
ConfigSecurity.prepareEnvironment
new HistoryServerArguments(conf, argStrings)
initSecurity()
val securityManager = createSecurityManager(conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
appId == null) {
printUsageAndExit()
}
ConfigSecurity.prepareEnvironment(scala.util.Try{
VaultHelper.getRealToken(sys.env("VAULT_URI"), sys.env("VAULT_TEMP_TOKEN"))}.toOption)
ConfigSecurity.prepareEnvironment

run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
System.exit(0)
Expand Down
77 changes: 41 additions & 36 deletions core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,50 @@ import scala.util.{Failure, Success, Try}

import org.apache.spark.internal.Logging

object ConfigSecurity extends Logging{

var vaultToken: Option[String] = None
val vaultUri: Option[String] = getVaultUri(sys.env.get("VAULT_PROTOCOL"),
sys.env.get("VAULT_HOSTS"), sys.env.get("VAULT_PORT"))

def getVaultUri(vaultProtocol: Option[String],
vaultHost: Option[String],
vaultPort: Option[String]): Option[String] = {
(vaultProtocol, vaultHost, vaultPort) match {
case (Some (vaultProtocol), Some (vaultHost), Some (vaultPort) ) =>
val vaultUri = s"$vaultProtocol://$vaultHost:$vaultPort"
logDebug (s"vault uri: $vaultUri found, any Vault Connection will use it")
Option (vaultUri)
case _ =>
logDebug ("No Vault information found, any Vault Connection will fail")
None
}
object ConfigSecurity extends Logging {

lazy val vaultToken: Option[String] =

if (sys.env.get("VAULT_TOKEN").isDefined) {
logDebug("Obtaining vault token using VAULT_TOKEN")
sys.env.get("VAULT_TOKEN")
} else if (sys.env.get("VAULT_TEMP_TOKEN").isDefined) {
logDebug("Obtaining vault token using VAULT_TEMP_TOKEN")
scala.util.Try {
VaultHelper.getRealToken(sys.env.get("VAULT_TEMP_TOKEN"))
}.toOption
} else if (sys.env.get("VAULT_ROLE_ID").isDefined && sys.env.get("VAULT_SECRET_ID").isDefined) {
logDebug("Obtaining vault token using ROLE_ID and SECRET_ID")
Option(VaultHelper.getTokenFromAppRole(
sys.env("VAULT_ROLE_ID"),
sys.env("VAULT_SECRET_ID")))
} else {
logInfo("No Vault token variables provided. Skipping Vault token retrieving")
None
}

def prepareEnvironment(vaultAppToken: Option[String] = None,
vaulHost: Option[String] = None): Map[String, String] = {
lazy val vaultURI: Option[String] = {
if (sys.env.get("VAULT_PROTOCOL").isDefined
&& sys.env.get("VAULT_HOSTS").isDefined
&& sys.env.get("VAULT_PORT").isDefined) {
val vaultProtocol = sys.env.get("VAULT_PROTOCOL").get
val vaultHost = sys.env.get("VAULT_HOSTS").get
val vaultPort = sys.env.get("VAULT_PORT").get
Option(s"$vaultProtocol://$vaultHost:$vaultPort")
} else {
logInfo("No Vault variables provided")
None
}
}

def prepareEnvironment: Map[String, String] = {

logDebug(s"env VAR: ${sys.env.mkString("\n")}")
val secretOptionsMap = ConfigSecurity.extractSecretFromEnv(sys.env)
logDebug(s"secretOptionsMap: ${secretOptionsMap.mkString("\n")}")
loadingConf(secretOptionsMap)
vaultToken = if (vaultAppToken.isDefined) {
vaultAppToken
} else sys.env.get("VAULT_TOKEN")
if(vaultToken.isDefined) {
require(vaultUri.isDefined, "A proper vault host is required")
logDebug(s"env VAR: ${sys.env.mkString("\n")}")
prepareEnvironment(vaultUri.get, vaultToken.get, secretOptionsMap)
}
else Map()
prepareEnvironment(secretOptionsMap)

}


Expand Down Expand Up @@ -101,18 +109,15 @@ object ConfigSecurity extends Logging{
}
}

private def prepareEnvironment(vaultHost: String,
vaultToken: String,
secretOptions: Map[String,
private def prepareEnvironment(secretOptions: Map[String,
Map[String, String]]): Map[String, String] =
secretOptions flatMap {
case ("kerberos", options) =>
KerberosConfig.prepareEnviroment(vaultHost, vaultToken, options)
KerberosConfig.prepareEnviroment(options)
case ("datastore", options) =>
SSLConfig.prepareEnvironment(
vaultHost, vaultToken, SSLConfig.sslTypeDataStore, options)
SSLConfig.prepareEnvironment(SSLConfig.sslTypeDataStore, options)
case ("db", options) =>
DBConfig.prepareEnvironment(vaultHost, vaultToken, options)
DBConfig.prepareEnvironment(options)
case _ => Map.empty[String, String]
}
}
6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/security/DBConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package org.apache.spark.security

object DBConfig {
def prepareEnvironment(vaultHost: String,
vaultToken: String,
options: Map[String, String]): Map[String, String] = {
def prepareEnvironment(options: Map[String, String]): Map[String, String] = {
options.filter(_._1.endsWith("DB_USER_VAULT_PATH")).flatMap{case (_, path) =>
val (pass, user) = VaultHelper.getPassPrincipalFromVault(vaultHost, path, vaultToken)
val (pass, user) = VaultHelper.getPassPrincipalFromVault(path)
Seq(("spark.db.enable", "true"), ("spark.db.user", user), ("spark.db.pass", pass))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ import org.apache.spark.internal.Logging

object KerberosConfig extends Logging{

def prepareEnviroment(vaultUrl: String,
vaultToken: String,
options: Map[String, String]): Map[String, String] = {
def prepareEnviroment(options: Map[String, String]): Map[String, String] = {
val kerberosVaultPath = options.get("KERBEROS_VAULT_PATH")
if(kerberosVaultPath.isDefined) {
val (keytab64, principal) =
VaultHelper.getKeytabPrincipalFromVault(vaultUrl, vaultToken, kerberosVaultPath.get)
VaultHelper.getKeytabPrincipalFromVault(kerberosVaultPath.get)
val keytabPath = getKeytabPrincipal(keytab64, principal)
Map("principal" -> principal, "keytabPath" -> keytabPath)
} else {
Expand Down
16 changes: 6 additions & 10 deletions core/src/main/scala/org/apache/spark/security/SSLConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,15 @@ object SSLConfig extends Logging {

val sslTypeDataStore = "DATASTORE"

def prepareEnvironment(vaultHost: String,
vaultToken: String,
sslType: String,
def prepareEnvironment(sslType: String,
options: Map[String, String]): Map[String, String] = {

val sparkSSLPrefix = "spark.ssl."

val vaultTrustStorePath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PATH")
val vaultTrustStorePassPath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PASS_PATH")
val trustStore = VaultHelper.getTrustStore(vaultHost, vaultToken, vaultTrustStorePath.get)
val trustPass = VaultHelper.getCertPassForAppFromVault(
vaultHost, vaultTrustStorePassPath.get, vaultToken)
val trustStore = VaultHelper.getTrustStore(vaultTrustStorePath.get)
val trustPass = VaultHelper.getCertPassForAppFromVault(vaultTrustStorePassPath.get)
val trustStorePath = generateTrustStore(sslType, trustStore, trustPass)

logInfo(s"Setting SSL values for $sslType")
Expand All @@ -63,14 +60,13 @@ object SSLConfig extends Logging {
val keyStoreOptions = if (vaultKeystorePath.isDefined && vaultKeystorePassPath.isDefined) {

val (key, certs) =
VaultHelper.getCertKeyForAppFromVault(vaultHost, vaultKeystorePath.get, vaultToken)
VaultHelper.getCertKeyForAppFromVault(vaultKeystorePath.get)

pemToDer(key)
generatePemFile(certs, "cert.crt")
generatePemFile(trustStore, "ca.crt")

val pass = VaultHelper.getCertPassForAppFromVault(
vaultHost, vaultKeystorePassPath.get, vaultToken)
val pass = VaultHelper.getCertPassForAppFromVault( vaultKeystorePassPath.get)

val keyStorePath = generateKeyStore(sslType, certs, key, pass)

Expand All @@ -89,7 +85,7 @@ object SSLConfig extends Logging {
val vaultKeyPassPath = options.get(s"${sslType}_VAULT_KEY_PASS_PATH")

val keyPass = Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyPassword"
-> VaultHelper.getCertPassForAppFromVault(vaultHost, vaultKeyPassPath.get, vaultToken))
-> VaultHelper.getCertPassForAppFromVault(vaultKeyPassPath.get))

val certFilesPath =
Map(s"$sparkSSLPrefix${sslType.toLowerCase}.certPem.path" -> "/tmp/cert.crt",
Expand Down
Loading

0 comments on commit 9d2a96b

Please sign in to comment.