diff --git a/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Environment.scala b/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Environment.scala index c943e2c21..e9e4e3dd4 100644 --- a/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Environment.scala +++ b/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Environment.scala @@ -39,7 +39,6 @@ import com.snowplowanalytics.snowplow.rdbloader.config.{CliConfig, Config, Stora import com.snowplowanalytics.snowplow.rdbloader.config.Config.Cloud import com.snowplowanalytics.snowplow.rdbloader.db.Target import com.snowplowanalytics.snowplow.rdbloader.dsl.metrics._ -import com.snowplowanalytics.snowplow.rdbloader.utils.SSH /** Container for most of interepreters to be used in Main @@ -110,7 +109,6 @@ object Environment { periodicMetrics <- Resource.eval(Metrics.PeriodicMetrics.init[F](reporters, cli.config.monitoring.metrics.period)) implicit0(monitoring: Monitoring[F]) = Monitoring.monitoringInterpreter[F](tracker, sentry, reporters, cli.config.monitoring.webhook, httpClient, periodicMetrics) implicit0(secretStore: SecretStore[F]) = cloudServices.secretStore - _ <- SSH.resource(cli.config.storage.sshTunnel) transaction <- Transaction.interpreter[F](cli.config.storage, blocker) telemetry <- Telemetry.build[F]( cli.config.telemetry, diff --git a/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Transaction.scala b/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Transaction.scala index da66c811c..23bf2901c 100644 --- a/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Transaction.scala +++ b/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/dsl/Transaction.scala @@ -16,7 +16,7 @@ import cats.~> import cats.arrow.FunctionK import cats.implicits._ -import cats.effect.{ContextShift, Blocker, Async, Resource, Timer, ConcurrentEffect, Sync, Effect} +import cats.effect.{ContextShift, Blocker, Resource, Timer, ConcurrentEffect, Sync, Effect} import cats.effect.implicits._ import doobie._ @@ -27,6 +27,7 @@ import doobie.hikari._ import java.sql.SQLException import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget +import com.snowplowanalytics.snowplow.rdbloader.utils.SSH import com.snowplowanalytics.snowplow.rdbloader.common.cloud.SecretStore @@ -91,7 +92,7 @@ object Transaction { def apply[F[_], C[_]](implicit ev: Transaction[F, C]): Transaction[F, C] = ev - def buildPool[F[_]: Async: ContextShift: Timer: SecretStore]( + def buildPool[F[_]: ConcurrentEffect: ContextShift: Timer: SecretStore]( target: StorageTarget, blocker: Blocker ): Resource[F, Transactor[F]] = @@ -112,6 +113,7 @@ object Transaction { ds.setDataSourceProperties(target.properties) } }) + xa <- target.sshTunnel.fold(Resource.pure[F, Transactor[F]](xa))(SSH.transactor(_, blocker, xa)) } yield xa /** diff --git a/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/utils/SSH.scala b/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/utils/SSH.scala index bc6d7da8f..051b962b4 100644 --- a/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/utils/SSH.scala +++ b/modules/loader/src/main/scala/com/snowplowanalytics/snowplow/rdbloader/utils/SSH.scala @@ -13,9 +13,11 @@ package com.snowplowanalytics.snowplow.rdbloader.utils import cats.Monad -import cats.effect.{ConcurrentEffect, Resource, Sync} +import cats.effect.{Blocker, ContextShift, ConcurrentEffect, Effect, Resource, Sync} +import cats.effect.concurrent.Semaphore import cats.syntax.all._ import cats.effect.syntax.all._ +import doobie.Transactor import com.jcraft.jsch.{JSch, Session, Logger => JLogger} import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget.TunnelConfig import org.typelevel.log4cats.Logger @@ -29,29 +31,49 @@ object SSH { /** Actual SSH identity data. Both passphrase and key are optional */ case class Identity(passphrase: Option[Array[Byte]], key: Option[Array[Byte]]) - /** Open SSH tunnel, which will be guaranteed to be closed when application exits */ - def resource[F[_]:ConcurrentEffect: Sync: SecretStore](tunnelConfig: Option[TunnelConfig]): Resource[F, Unit] = - tunnelConfig match { - case Some(tunnel) => - Resource.eval{ + final class SSHException(cause: Throwable) extends Exception(s"Error setting up SSH tunnel: ${cause.getMessage}", cause) - Sync[F].delay(JSch.setLogger(new JLogger{ - override def isEnabled(level: Int): Boolean = true + /** A doobie transactor that ensures the SSH tunnel is connected before attempting a connection to the warehouse */ + def transactor[F[_]: ConcurrentEffect: ContextShift: SecretStore, A](config: TunnelConfig, blocker: Blocker, inner: Transactor.Aux[F, A]): Resource[F, Transactor.Aux[F, A]] = + for { + _ <- Resource.eval(configureLogging) + identity <- Resource.eval(getIdentity(config)) + session <- Resource.make(createSession(config, identity))(s => Sync[F].delay(s.disconnect())) + _ <- setPortForwarding(config, session) + sem <- Resource.eval(Semaphore[F](1)) + } yield inner.copy(connect0 = a => Resource.eval(ensureTunnel(session, blocker, sem)) *> inner.connect(a)) - override def log(level: Int, message: String): Unit = level match { - case JLogger.INFO => Logger[F].info("JCsh: " + message).toIO.unsafeRunSync() - case JLogger.ERROR => Logger[F].error("JCsh: " + message).toIO.unsafeRunSync() - case JLogger.DEBUG => Logger[F].debug("JCsh: " + message).toIO.unsafeRunSync() - case JLogger.WARN => Logger[F].warn("JCsh: " + message).toIO.unsafeRunSync() - case JLogger.FATAL => Logger[F].error("JCsh: " + message).toIO.unsafeRunSync() - case _ => Logger[F].warn("NO LOG LEVEL JCsh: " + message).toIO.unsafeRunSync() - } - }))} >> - Resource.make(getIdentity[F](tunnel).flatMap(i => createSession(tunnel, i)))(s => Sync[F].delay(s.disconnect())).void - case None => - Resource.pure[F, Unit](()) + + /** Ensure the SSH tunnel is connected. + * + * Uses a semaphore to prevent multiple fibers trying to connect the session at the same time + */ + def ensureTunnel[F[_]: Sync: ContextShift](session: Session, blocker: Blocker, sem: Semaphore[F]): F[Unit] = + sem.withPermit { + Sync[F].delay(session.isConnected()) + .ifM( + Logger[F].debug("SSH session is already connected"), + blocker.delay(session.connect()) + ) + .adaptError { + case t: Throwable => new SSHException(t) + } } + def configureLogging[F[_]: Effect]: F[Unit] = + Sync[F].delay(JSch.setLogger(new JLogger{ + override def isEnabled(level: Int): Boolean = true + + override def log(level: Int, message: String): Unit = level match { + case JLogger.INFO => Logger[F].info("JCsh: " + message).toIO.unsafeRunSync() + case JLogger.ERROR => Logger[F].error("JCsh: " + message).toIO.unsafeRunSync() + case JLogger.DEBUG => Logger[F].debug("JCsh: " + message).toIO.unsafeRunSync() + case JLogger.WARN => Logger[F].warn("JCsh: " + message).toIO.unsafeRunSync() + case JLogger.FATAL => Logger[F].error("JCsh: " + message).toIO.unsafeRunSync() + case _ => Logger[F].warn("NO LOG LEVEL JCsh: " + message).toIO.unsafeRunSync() + } + })) + /** Convert pure tunnel configuration to configuration with actual key and passphrase */ def getIdentity[F[_]: Monad: Sync: SecretStore](tunnelConfig: TunnelConfig): F[Identity] = tunnelConfig @@ -61,9 +83,12 @@ object SSH { .map { key => Identity(tunnelConfig.bastion.passphrase.map(_.getBytes), key.map(_.getBytes)) } /** - * Create a SSH tunnel to bastion host and set port forwarding to target DB + * Create a SSH session configured for the bastion host. + * + * The returned session is not yet connected and is not yet listening on a local port. + * * @param tunnelConfig SSH-tunnel configuration - * @return either nothing on success and error message on failure + * @param identity SSH identity data */ def createSession[F[_]: Sync](tunnelConfig: TunnelConfig, identity: Identity): F[Session] = Sync[F].delay { @@ -73,9 +98,21 @@ object SSH { jsch.addIdentity("rdb-loader-tunnel-key", identity.key.orNull, null, identity.passphrase.orNull) val sshSession = jsch.getSession(tunnelConfig.bastion.user, tunnelConfig.bastion.host, tunnelConfig.bastion.port) sshSession.setConfig("StrictHostKeyChecking", "no") - sshSession.connect() - val _ = sshSession.setPortForwardingL(tunnelConfig.localPort, tunnelConfig.destination.host, tunnelConfig.destination.port) sshSession } -} + /** + * Start the Session listening on the local port + */ + def setPortForwarding[F[_]: Sync](config: TunnelConfig, session: Session): Resource[F, Unit] = { + val acquire = Sync[F].delay { + session.setPortForwardingL(config.localPort, config.destination.host, config.destination.port) + }.adaptError { + case t: Throwable => new SSHException(t) + }.void + val release = Sync[F].delay { + session.delPortForwardingL(config.localPort) + } + Resource.make(acquire)(_ => release) + } +}