Databricks loader: Generate STS tokens for copying from S3 (close #954)
spenes committed Jul 4, 2022
1 parent 07a98f2 commit c38cdef
Showing 17 changed files with 274 additions and 74 deletions.
26 changes: 26 additions & 0 deletions config/databricks.config.reference.hocon
@@ -33,6 +33,29 @@
"httpPath": "/databricks/http/path",
# User agent name for Databricks connection. Optional, default value "snowplow-rdbloader-oss"
"userAgent": "snowplow-rdbloader-oss"

# Optional, default method is 'NoCreds'
# Specifies the auth method to use with the 'COPY INTO' statement.
"loadAuthMethod": {
# With 'NoCreds', no credentials will be passed to the 'COPY INTO' statement.
# The Databricks cluster needs to have permission to access the transformer
# output S3 bucket. More information can be found here:
# https://docs.databricks.com/administration-guide/cloud-configurations/aws/instance-profiles.html
"type": "NoCreds"
}
#"loadAuthMethod": {
# # With 'TempCreds', temporary credentials will be created for every
# # load operation and passed to the 'COPY INTO' statement. This way, the
# # Databricks cluster doesn't need permission to access the transformer
# # output S3 bucket; that access is provided by the temporary credentials.
# "type": "TempCreds"
# # IAM role that is used while creating the temporary credentials.
# # The created credentials will allow access to the resources specified in the given role.
# # In our case, the role should grant "s3:GetObject*", "s3:ListBucket", and
# # "s3:GetBucketLocation" permissions on the transformer output S3 bucket.
# "roleArn": "arn:aws:iam::123456789:role/role_name"
#}
},

"schedules": {
@@ -173,6 +196,9 @@
# How long loading (actual COPY statements) can take before considering Databricks unhealthy
# Without any progress (i.e. a different subfolder) within this period, the loader
# will abort the transaction
# If the 'TempCreds' load auth method is used, this value is also used as the session
# duration of the temporary credentials. In that case, it can't be greater than the
# maximum session duration of the IAM role used for the temporary credentials
"loading": "1 hour",

# How long non-loading steps (such as ALTER TABLE or metadata queries) can take
4 changes: 4 additions & 0 deletions modules/databricks-loader/src/main/resources/application.conf
@@ -3,5 +3,9 @@
"type": "databricks"
"catalog": "hive_metastore"
"userAgent": "snowplow-rdbloader-oss"
"loadAuthMethod": {
"type": "NoCreds"
"roleSessionName": "rdb_loader"
}
}
}
Databricks.scala (modules/databricks-loader)
@@ -20,6 +20,7 @@ import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import com.snowplowanalytics.snowplow.rdbloader.db.Columns.{ColumnsToCopy, ColumnsToSkip, EventTableColumns}
import com.snowplowanalytics.snowplow.rdbloader.db.Migration.{Block, Entity}
import com.snowplowanalytics.snowplow.rdbloader.db.{Manifest, Statement, Target}
import com.snowplowanalytics.snowplow.rdbloader.db.AuthService.LoadAuthMethod
import com.snowplowanalytics.snowplow.rdbloader.discovery.{DataDiscovery, ShreddedType}
import com.snowplowanalytics.snowplow.rdbloader.loading.EventsTable
import doobie.Fragment
@@ -47,11 +48,11 @@ object Databricks {

override def extendTable(info: ShreddedType.Info): Option[Block] = None

override def getLoadStatements(discovery: DataDiscovery, eventTableColumns: EventTableColumns): LoadStatements = {
val toCopy = ColumnsToCopy.fromDiscoveredData(discovery)
override def getLoadStatements(discovery: DataDiscovery, eventTableColumns: EventTableColumns, loadAuthMethod: LoadAuthMethod): LoadStatements = {
val toCopy = ColumnsToCopy.fromDiscoveredData(discovery)
val toSkip = ColumnsToSkip(getEntityColumnsPresentInDbOnly(eventTableColumns, toCopy))

NonEmptyList.one(Statement.EventsCopy(discovery.base, discovery.compression, toCopy, toSkip))
NonEmptyList.one(Statement.EventsCopy(discovery.base, discovery.compression, toCopy, toSkip, loadAuthMethod))
}

override def createTable(schemas: SchemaList): Block = Block(Nil, Nil, Entity.Table(tgt.schema, schemas.latest.schemaKey))
@@ -90,24 +91,27 @@ object Databricks {
val frTableName = Fragment.const(qualify(AlertingTempTableName))
val frManifest = Fragment.const(qualify(Manifest.Name))
sql"SELECT run_id FROM $frTableName MINUS SELECT base FROM $frManifest"
case Statement.FoldersCopy(source) =>
case Statement.FoldersCopy(source, loadAuthMethod) =>
val frTableName = Fragment.const(qualify(AlertingTempTableName))
val frPath = Fragment.const0(source)
val frAuth = loadAuthMethodFragment(loadAuthMethod)

sql"""COPY INTO $frTableName
FROM (SELECT _C0::VARCHAR(512) RUN_ID FROM '$frPath')
FILEFORMAT = CSV""";
case Statement.EventsCopy(path, _, toCopy, toSkip) =>
FROM (SELECT _C0::VARCHAR(512) RUN_ID FROM '$frPath' $frAuth)
FILEFORMAT = CSV"""
case Statement.EventsCopy(path, _, toCopy, toSkip, loadAuthMethod) =>
val frTableName = Fragment.const(qualify(EventsTable.MainName))
val frPath = Fragment.const0(path.append("output=good"))
val nonNulls = toCopy.names.map(_.value)
val nulls = toSkip.names.map(c => s"NULL AS ${c.value}")
val currentTimestamp = "current_timestamp() AS load_tstamp"
val allColumns = (nonNulls ::: nulls) :+ currentTimestamp

val frSelectColumns = Fragment.const0(allColumns.mkString(","))
val allColumns = (nonNulls ::: nulls) :+ currentTimestamp
val frAuth = loadAuthMethodFragment(loadAuthMethod)
val frSelectColumns = Fragment.const0(allColumns.mkString(","))

sql"""COPY INTO $frTableName
FROM (
SELECT $frSelectColumns from '$frPath'
SELECT $frSelectColumns from '$frPath' $frAuth
)
FILEFORMAT = PARQUET
COPY_OPTIONS('MERGESCHEMA' = 'TRUE')""";
@@ -175,4 +179,12 @@ object Databricks {
.filter(name => name.value.startsWith(UnstructPrefix) || name.value.startsWith(ContextsPrefix))
.diff(toCopy.names)
}

private def loadAuthMethodFragment(loadAuthMethod: LoadAuthMethod): Fragment =
loadAuthMethod match {
case LoadAuthMethod.NoCreds =>
Fragment.empty
case LoadAuthMethod.TempCreds(awsAccessKey, awsSecretKey, awsSessionToken) =>
Fragment.const0(s"WITH ( CREDENTIAL (AWS_ACCESS_KEY = '$awsAccessKey', AWS_SECRET_KEY = '$awsSecretKey', AWS_SESSION_TOKEN = '$awsSessionToken') )")
}
}
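
For context, here is a minimal, self-contained sketch of how the credentials fragment above ends up inside the COPY INTO statement for the TempCreds case. It is not part of the commit; the table name, S3 path and credential values are placeholders.

import doobie.Fragment
import doobie.implicits._

object CopyIntoSketch {
  // Placeholder values standing in for the STS-issued temporary credentials
  val (key, secret, token) = ("ASIA...", "secretKey...", "sessionToken...")

  // Mirrors loadAuthMethodFragment for LoadAuthMethod.TempCreds
  val frAuth: Fragment =
    Fragment.const0(
      s"WITH ( CREDENTIAL (AWS_ACCESS_KEY = '$key', AWS_SECRET_KEY = '$secret', AWS_SESSION_TOKEN = '$token') )"
    )

  // The auth fragment is appended right after the source path inside the inner SELECT,
  // so Databricks reads the transformer output with the temporary credentials instead
  // of relying on an instance profile.
  val copyInto: Fragment =
    sql"""COPY INTO atomic.events
          FROM ( SELECT app_id, current_timestamp() AS load_tstamp from 's3://bucket/run=2022-07-04/output=good/' $frAuth )
          FILEFORMAT = PARQUET
          COPY_OPTIONS('MERGESCHEMA' = 'TRUE')"""
}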
DatabricksSpec.scala (modules/databricks-loader, tests)
@@ -21,6 +21,7 @@ import com.snowplowanalytics.snowplow.rdbloader.common.LoaderMessage.SnowplowEnt
import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import com.snowplowanalytics.snowplow.rdbloader.db.Columns.{ColumnName, ColumnsToCopy, ColumnsToSkip}
import com.snowplowanalytics.snowplow.rdbloader.db.{Statement, Target}
import com.snowplowanalytics.snowplow.rdbloader.db.AuthService.LoadAuthMethod

import scala.concurrent.duration.DurationInt
import org.specs2.mutable.Specification
@@ -50,8 +51,8 @@ class DatabricksSpec extends Specification {

val discovery = DataDiscovery(baseFolder, shreddedTypes, Compression.Gzip)

target.getLoadStatements(discovery, eventsColumns) should be like {
case NonEmptyList(Statement.EventsCopy(path, compression, columnsToCopy, columnsToSkip), Nil) =>
target.getLoadStatements(discovery, eventsColumns, LoadAuthMethod.NoCreds) should be like {
case NonEmptyList(Statement.EventsCopy(path, compression, columnsToCopy, columnsToSkip, LoadAuthMethod.NoCreds), Nil) =>
path must beEqualTo(baseFolder)
compression must beEqualTo(Compression.Gzip)

@@ -85,12 +86,30 @@
ColumnName("unstruct_event_com_acme_bbb_1"),
ColumnName("contexts_com_acme_yyy_1"),
))
val statement = Statement.EventsCopy(baseFolder, Compression.Gzip, toCopy, toSkip)
val statement = Statement.EventsCopy(baseFolder, Compression.Gzip, toCopy, toSkip, LoadAuthMethod.NoCreds)

target.toFragment(statement).toString must beLike { case sql =>
sql must contain("SELECT app_id,unstruct_event_com_acme_aaa_1,contexts_com_acme_xxx_1,NULL AS unstruct_event_com_acme_bbb_1,NULL AS contexts_com_acme_yyy_1,current_timestamp() AS load_tstamp from 's3://somewhere/path/output=good/'")
}
}

"create sql with credentials for loading" in {
val toCopy = ColumnsToCopy(List(
ColumnName("app_id"),
ColumnName("unstruct_event_com_acme_aaa_1"),
ColumnName("contexts_com_acme_xxx_1")
))
val toSkip = ColumnsToSkip(List(
ColumnName("unstruct_event_com_acme_bbb_1"),
ColumnName("contexts_com_acme_yyy_1"),
))
val loadAuthMethod = LoadAuthMethod.TempCreds("testAccessKey", "testSecretKey", "testSessionToken")
val statement = Statement.EventsCopy(baseFolder, Compression.Gzip, toCopy, toSkip, loadAuthMethod)

target.toFragment(statement).toString must beLike { case sql =>
sql must contain(s"SELECT app_id,unstruct_event_com_acme_aaa_1,contexts_com_acme_xxx_1,NULL AS unstruct_event_com_acme_bbb_1,NULL AS contexts_com_acme_yyy_1,current_timestamp() AS load_tstamp from 's3://somewhere/path/output=good/' WITH ( CREDENTIAL (AWS_ACCESS_KEY = '${loadAuthMethod.awsAccessKey}', AWS_SECRET_KEY = '${loadAuthMethod.awsSecretKey}', AWS_SESSION_TOKEN = '${loadAuthMethod.awsSessionToken}') )")
}
}
}
}

@@ -113,7 +132,8 @@ object DatabricksSpec {
"some/path",
StorageTarget.PasswordConfig.PlainText("xxx"),
None,
"useragent"
"useragent",
StorageTarget.LoadAuthMethod.NoCreds
),
Config.Schedules(Nil),
Config.Timeouts(1.minute, 1.minute, 1.minute),
Loader.scala (com.snowplowanalytics.snowplow.rdbloader)
@@ -15,11 +15,11 @@ package com.snowplowanalytics.snowplow.rdbloader
import scala.concurrent.duration._
import cats.{Apply, Monad}
import cats.implicits._
import cats.effect.{Clock, Concurrent, MonadThrow, Timer}
import cats.effect.{Clock, Concurrent, MonadThrow, Timer, ContextShift}
import fs2.Stream
import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import com.snowplowanalytics.snowplow.rdbloader.db.Columns._
import com.snowplowanalytics.snowplow.rdbloader.db.{AtomicColumns, HealthCheck, Manifest, Statement, Control => DbControl}
import com.snowplowanalytics.snowplow.rdbloader.db.{AtomicColumns, HealthCheck, Manifest, Statement, Control => DbControl, AuthService}
import com.snowplowanalytics.snowplow.rdbloader.discovery.{DataDiscovery, NoOperation, Retries}
import com.snowplowanalytics.snowplow.rdbloader.dsl.{AWS, Cache, DAO, FolderMonitoring, Iglu, Logging, Monitoring, StateMonitoring, Transaction}
import com.snowplowanalytics.snowplow.rdbloader.loading.{EventsTable, Load, Stage, TargetCheck}
@@ -47,10 +47,10 @@ object Loader {
* Unlike `F` it cannot pull `A` out of DB (perform a transaction), but just
* claim `A` is needed and `C[A]` later can be materialized into `F[A]`
*/
def run[F[_]: Transaction[*[_], C]: Concurrent: AWS: Clock: Iglu: Cache: Logging: Timer: Monitoring,
def run[F[_]: Transaction[*[_], C]: Concurrent: AWS: Clock: Iglu: Cache: Logging: Timer: Monitoring: ContextShift,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F]): F[Unit] = {
val folderMonitoring: Stream[F, Unit] =
FolderMonitoring.run[C, F](config.monitoring.folders, config.readyCheck, config.storage, control.isBusy)
FolderMonitoring.run[F, C](config.monitoring.folders, config.readyCheck, config.storage, config.timeouts, config.region.name, control.isBusy)
val noOpScheduling: Stream[F, Unit] =
NoOperation.run(config.schedules.noOperation, control.makePaused, control.signal.map(_.loading))
val healthCheck =
@@ -88,7 +88,7 @@
* A primary loading processing, pulling information from discovery streams
* (SQS and retry queue) and performing the load operation itself
*/
private def loadStream[F[_]: Transaction[*[_], C]: Concurrent: AWS: Iglu: Cache: Logging: Timer: Monitoring,
private def loadStream[F[_]: Transaction[*[_], C]: Concurrent: AWS: Iglu: Cache: Logging: Timer: Monitoring: ContextShift,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F]): Stream[F, Unit] = {
val sqsDiscovery: DiscoveryStream[F] =
DataDiscovery.discover[F](config, control.incrementMessages, control.isBusy)
@@ -106,7 +106,7 @@
* over to `Load`. A primary function handling the global state - everything
* downstream has access only to `F` actions, instead of whole `Control` object
*/
private def processDiscovery[F[_]: Transaction[*[_], C]: Concurrent: Iglu: Logging: Timer: Monitoring,
private def processDiscovery[F[_]: Transaction[*[_], C]: Concurrent: Iglu: Logging: Timer: Monitoring: ContextShift,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F])
(discovery: DataDiscovery.WithOrigin): F[Unit] = {
val folder = discovery.origin.base
@@ -122,7 +122,8 @@
val loading: F[Unit] = backgroundCheck {
for {
start <- Clock[F].instantNow
result <- Load.load[F, C](config, setStageC, control.incrementAttempts, discovery)
loadAuth <- AuthService.getLoadAuthMethod[F](config.storage.loadAuthMethod, config.region.name, config.timeouts.loading)
result <- Load.load[F, C](config, setStageC, control.incrementAttempts, discovery, loadAuth)
attempts <- control.getAndResetAttempts
_ <- result match {
case Right(ingested) =>
StorageTarget.scala (com.snowplowanalytics.snowplow.rdbloader.config)
@@ -48,6 +48,7 @@ sealed trait StorageTarget extends Product with Serializable {
def withAutoCommit: Boolean = false
def connectionUrl: String
def properties: Properties
def loadAuthMethod: StorageTarget.LoadAuthMethod
}

object StorageTarget {
Expand Down Expand Up @@ -93,6 +94,8 @@ object StorageTarget {
}
props
}

def loadAuthMethod: LoadAuthMethod = LoadAuthMethod.NoCreds
}

final case class Databricks(
@@ -103,7 +106,8 @@
httpPath: String,
password: PasswordConfig,
sshTunnel: Option[TunnelConfig],
userAgent: String
userAgent: String,
loadAuthMethod: LoadAuthMethod
) extends StorageTarget {

override def username: String = "token"
@@ -194,6 +198,8 @@
"Snowflake config requires either jdbcHost or both account and region".asLeft
}
}

def loadAuthMethod: LoadAuthMethod = LoadAuthMethod.NoCreds
}

object Snowflake {
@@ -300,6 +306,13 @@
}
}

sealed trait LoadAuthMethod extends Product with Serializable

object LoadAuthMethod {
final case object NoCreds extends LoadAuthMethod
final case class TempCreds(roleArn: String, roleSessionName: String) extends LoadAuthMethod
}

/**
* SSH configuration, enabling target to be loaded though tunnel
*
@@ -347,6 +360,26 @@
case Snowflake.AbortStatement => "ABORT_STATEMENT"
}

implicit def loadAuthMethodDecoder: Decoder[LoadAuthMethod] =
Decoder.instance { cur =>
val typeCur = cur.downField("type")
typeCur.as[String].map(_.toLowerCase) match {
case Right("nocreds") =>
Right(LoadAuthMethod.NoCreds)
case Right("tempcreds") =>
cur.as[LoadAuthMethod.TempCreds]
case Right(other) =>
Left(DecodingFailure(s"Load auth method of type $other is not supported yet. Supported types: 'NoCreds', 'TempCreds'", typeCur.history))
case Left(DecodingFailure(_, List(CursorOp.DownField("type")))) =>
Left(DecodingFailure("Cannot find 'type' string in load auth method", typeCur.history))
case Left(other) =>
Left(other)
}
}

implicit def tempCredsAuthMethodDecoder: Decoder[LoadAuthMethod.TempCreds] =
deriveDecoder[LoadAuthMethod.TempCreds]

implicit def storageTargetDecoder: Decoder[StorageTarget] =
Decoder.instance { cur =>
val typeCur = cur.downField("type")
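
As a hedged illustration (not part of the commit), the loadAuthMethodDecoder shown above accepts an object shaped like the one below. In practice the HOCON config is resolved first, so roleSessionName is usually supplied by the "rdb_loader" default in the loader's application.conf rather than by the user.

import io.circe.parser.decode
import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget
import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget.LoadAuthMethod

object LoadAuthMethodDecodingSketch {
  val json: String =
    """{
      |  "type": "TempCreds",
      |  "roleArn": "arn:aws:iam::123456789:role/role_name",
      |  "roleSessionName": "rdb_loader"
      |}""".stripMargin

  // Expected result:
  // Right(LoadAuthMethod.TempCreds("arn:aws:iam::123456789:role/role_name", "rdb_loader"))
  val decoded: Either[io.circe.Error, LoadAuthMethod] =
    decode[LoadAuthMethod](json)(StorageTarget.loadAuthMethodDecoder)
}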
AuthService.scala (com.snowplowanalytics.snowplow.rdbloader.db, new file)
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2014-2022 Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Apache License Version 2.0,
* and you may not use this file except in compliance with the Apache License Version 2.0.
* You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the Apache License Version 2.0 is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the Apache License Version 2.0 for the specific language governing permissions and limitations there under.
*/
package com.snowplowanalytics.snowplow.rdbloader.db

import scala.concurrent.duration.FiniteDuration

import cats.effect.{Concurrent, ContextShift}

import cats.implicits._

import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.sts.StsAsyncClient
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest

import com.snowplowanalytics.aws.Common
import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget

object AuthService {
/**
* Auth method that is used with the COPY INTO statement
*/
sealed trait LoadAuthMethod

object LoadAuthMethod {
/**
* Specifies an auth method that doesn't use credentials.
* The destination should already be configured by some other means
* for copying from the transformer output bucket.
*/
final case object NoCreds extends LoadAuthMethod

/**
* Specifies an auth method that passes temporary credentials to the COPY INTO statement
*/
final case class TempCreds(awsAccessKey: String, awsSecretKey: String, awsSessionToken: String) extends LoadAuthMethod
}

/**
* Get the load auth method according to the value specified in the config.
* If the temporary credentials method is specified in the config, temporary credentials
* are fetched by sending a request to the STS service and then returned.
*/
def getLoadAuthMethod[F[_]: Concurrent: ContextShift](authMethodConfig: StorageTarget.LoadAuthMethod,
region: String,
sessionDuration: FiniteDuration): F[LoadAuthMethod] =
authMethodConfig match {
case StorageTarget.LoadAuthMethod.NoCreds => Concurrent[F].pure(LoadAuthMethod.NoCreds)
case StorageTarget.LoadAuthMethod.TempCreds(roleArn, roleSessionName) =>
for {
stsAsyncClient <- Concurrent[F].delay(
StsAsyncClient.builder()
.region(Region.of(region))
.build()
)
assumeRoleRequest <- Concurrent[F].delay(
AssumeRoleRequest.builder()
.durationSeconds(sessionDuration.toSeconds.toInt)
.roleArn(roleArn)
.roleSessionName(roleSessionName)
.build()
)
response <- Common.fromCompletableFuture(
Concurrent[F].delay(stsAsyncClient.assumeRole(assumeRoleRequest))
)
creds = response.credentials()
} yield LoadAuthMethod.TempCreds(creds.accessKeyId(), creds.secretAccessKey(), creds.sessionToken())
}
}
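
A hedged usage sketch of the new service (cats-effect 2 IO, with placeholder role, session name and region values): it mirrors the call added to Loader.processDiscovery, where the configured loadAuthMethod, the loader's region and the 'loading' timeout are passed in.

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

import cats.effect.{ContextShift, IO}

import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget
import com.snowplowanalytics.snowplow.rdbloader.db.AuthService

object AuthServiceSketch {
  implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)

  // Placeholder role ARN, session name and region; in the loader these come from
  // storage.loadAuthMethod and the top-level region of the config.
  val loadAuth: IO[AuthService.LoadAuthMethod] =
    AuthService.getLoadAuthMethod[IO](
      StorageTarget.LoadAuthMethod.TempCreds("arn:aws:iam::123456789:role/role_name", "rdb_loader"),
      "eu-central-1",
      1.hour // session duration; the loader passes config.timeouts.loading here
    )
}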