Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LIVY-239: Moving the logic to generate session IDs from Session Manager to SessionStore #220

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package com.cloudera.livy.server.recovery

import java.util.concurrent.atomic.AtomicLong

import scala.reflect.ClassTag

import com.cloudera.livy.LivyConf
Expand All @@ -27,11 +29,16 @@ import com.cloudera.livy.LivyConf
* Livy will use this when session recovery is disabled.
*/
class BlackholeStateStore(livyConf: LivyConf) extends StateStore(livyConf) {

  // In-memory counter behind increment(). Seeded with -1 so that the first increment yields 0.
  val atomicLong: AtomicLong = new AtomicLong(-1L)

  /** Accepts the write and discards it; nothing is retained. */
  def set(key: String, value: Object): Unit = ()

  /** Always reports that no value exists for the key. */
  def get[T: ClassTag](key: String): Option[T] = None

  /** Always reports that the key has no children. */
  def getChildren(key: String): Seq[String] = Nil

  /** Nothing is ever stored, so there is nothing to remove. */
  def remove(key: String): Unit = ()

  /**
   * Advances the single in-memory counter and returns its new value.
   * NOTE(review): `key` is not consulted, so every key draws from the same sequence.
   */
  override def increment(key: String): Long = atomicLong.incrementAndGet()
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,10 @@ class FileSystemStateStore(
}

// Builds the Path for a key, using only the path component of fsUri as the parent.
private def absPath(key: String): Path = {
  val parent = fsUri.getPath()
  new Path(parent, key)
}

/**
 * Reads the Long stored under `key`, adds one, writes the result back and returns it.
 * The whole read-modify-write runs while holding this store instance's monitor.
 */
override def increment(key: String): Long = synchronized {
  // An absent key behaves as if it held -1, so the first increment produces 0.
  val updated = get[Long](key).fold(0L)(_ + 1)
  // set takes an Object, so pass the boxed form of the new counter value.
  set(key, java.lang.Long.valueOf(updated))
  updated
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ import scala.util.control.NonFatal
import com.cloudera.livy.{LivyConf, Logging}
import com.cloudera.livy.sessions.Session.RecoveryMetadata

/** Value persisted for a session manager: the next session id it should hand out. */
private[recovery] case class SessionManagerState(nextSessionId: Int)

/**
* SessionStore provides high level functions to get/save session state from/to StateStore.
*/
Expand All @@ -47,10 +45,6 @@ class SessionStore(
store.set(sessionPath(sessionType, m.id), m)
}

/**
 * Persists `id` as the next session id for the given session type, wrapped in a
 * SessionManagerState and written under that type's session manager path.
 */
def saveNextSessionId(sessionType: String, id: Int): Unit = {
  val managerKey = sessionManagerPath(sessionType)
  val managerState = SessionManagerState(id)
  store.set(managerKey, managerState)
}

/**
* Return all sessions stored in the store with specified session type.
*/
Expand All @@ -68,15 +62,14 @@ class SessionStore(
}

/**
* Return the next unused session id with specified session type.
* It checks the SessionManagerState stored and returns the next free session id.
* If no SessionManagerState is stored, it returns 0.
* Return the next unused session ID from state store with the specified session type.
* If no value is stored in the state store, it returns 0.
* It saves the next unused session ID to the session store before returning the current value.
*
* @throws Exception If SessionManagerState stored is corrupted, it throws an error.
* @throws Exception If session store is corrupted or unreachable, it throws an error.
*/
def getNextSessionId(sessionType: String): Int = {
store.get[SessionManagerState](sessionManagerPath(sessionType))
.map(_.nextSessionId).getOrElse(0)
store.increment(sessionManagerPath(sessionType)).toInt
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ abstract class StateStore(livyConf: LivyConf) extends JsonMapper {
* @throws Exception Throw when persisting the state store fails.
*/
def remove(key: String): Unit

/**
 * Reads the Long value stored under the given key, adds one to it, persists the new value
 * and then returns that new value.
 *
 * @param key key of the counter to advance
 * @return the value after the increment
 */
def increment(key: String): Long
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ package com.cloudera.livy.server.recovery

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.Try
import scala.util.matching.Regex

import org.apache.curator.RetryPolicy
import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory}
import org.apache.curator.framework.api.UnhandledErrorListener
import org.apache.curator.framework.recipes.atomic.{DistributedAtomicLong => DistributedLong}
import org.apache.curator.retry.RetryNTimes
import org.apache.zookeeper.KeeperException.NoNodeException

Expand All @@ -46,18 +50,22 @@ class ZooKeeperStateStore(
}

private val zkAddress = livyConf.get(LivyConf.RECOVERY_STATE_STORE_URL)

require(!zkAddress.isEmpty, s"Please config ${LivyConf.RECOVERY_STATE_STORE_URL.key}.")
private val zkKeyPrefix = livyConf.get(ZK_KEY_PREFIX_CONF)
private val curatorClient = mockCuratorClient.getOrElse {
val retryValue = livyConf.get(ZK_RETRY_CONF)

private val retryValue = livyConf.get(ZK_RETRY_CONF)
private val retryPolicy = Try {
// a regex to match patterns like "m, n" where m and n are both integer values
val retryPattern = """\s*(\d+)\s*,\s*(\d+)\s*""".r
val retryPolicy = retryValue match {
case retryPattern(n, sleepMs) => new RetryNTimes(5, 100)
case _ => throw new IllegalArgumentException(
s"$ZK_KEY_PREFIX_CONF contains bad value: $retryValue. " +
"Correct format is <max retry count>,<sleep ms between retry>. e.g. 5,100")
}
val retryPattern(retryTimes, sleepMsBetweenRetries) = retryValue
new RetryNTimes(retryTimes.toInt, sleepMsBetweenRetries.toInt)
}.getOrElse { throw new IllegalArgumentException(
s"$ZK_RETRY_CONF contains bad value: $retryValue. " +
"Correct format is <max retry count>,<sleep ms between retry>. e.g. 5,100")
}

private val zkKeyPrefix = livyConf.get(ZK_KEY_PREFIX_CONF)
private val curatorClient = mockCuratorClient.getOrElse {
CuratorFrameworkFactory.newClient(zkAddress, retryPolicy)
}

Expand Down Expand Up @@ -113,5 +121,15 @@ class ZooKeeperStateStore(
}
}

/**
 * Atomically increments the distributed counter kept for `key` and returns the new value.
 *
 * @param key store key of the counter, without the ZooKeeper key prefix
 * @return the counter value after the increment
 * @throws java.io.IOException if Curator reports that the increment did not succeed
 */
override def increment(key: String): Long = {
  // Use prefixKey so the counter lives under the configured key prefix and the node path
  // starts with "/"; the raw key carries neither.
  val counter = new DistributedLong(curatorClient, prefixKey(key), retryPolicy)
  val outcome = counter.increment()
  if (!outcome.succeeded()) {
    throw new java.io.IOException(s"Failed to atomically increment the value for $key")
  }
  // NOTE(review): Curator presumably treats a missing node as 0, which would make the first
  // value returned here 1 rather than 0 - confirm whether preValue() is the intended result.
  outcome.postValue()
}

// Turns a store key into "/<zkKeyPrefix>/<key>", i.e. an absolute path under the key prefix.
private def prefixKey(key: String) = s"/$zkKeyPrefix/$key"
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class SessionManager[S <: Session, R <: RecoveryMetadata : ClassTag](

protected implicit def executor: ExecutionContext = ExecutionContext.global

protected[this] final val idCounter = new AtomicInteger(0)
protected[this] final val sessions = mutable.LinkedHashMap[Int, S]()

private[this] final val sessionTimeout =
Expand All @@ -78,11 +77,8 @@ class SessionManager[S <: Session, R <: RecoveryMetadata : ClassTag](
mockSessions.getOrElse(recover()).foreach(register)
new GarbageCollector().start()

/**
 * Hands out the current counter value as a session id and saves the counter's following
 * value through the session store. Runs while holding this manager's monitor.
 */
def nextId(): Int = synchronized {
  val allocated = idCounter.getAndIncrement()
  // Re-read the counter so the saved value is whatever it holds after the increment.
  val upcoming = idCounter.get()
  sessionStore.saveNextSessionId(sessionType, upcoming)
  allocated
}
// sessionStore.getNextSessionId is atomic and is guaranteed to return unique IDs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Extra space between return & atomic.

/**
 * Returns the next session id for this manager's session type, as supplied by
 * sessionStore.getNextSessionId.
 */
def nextId(): Int = sessionStore.getNextSessionId(sessionType)

def register(session: S): S = {
info(s"Registering new session ${session.id}")
Expand Down Expand Up @@ -136,18 +132,13 @@ class SessionManager[S <: Session, R <: RecoveryMetadata : ClassTag](
}

private def recover(): Seq[S] = {
// Recover next session id from state store and create SessionManager.
idCounter.set(sessionStore.getNextSessionId(sessionType))

// Retrieve session recovery metadata from state store.
val sessionMetadata = sessionStore.getAllSessions[R](sessionType)

// Recover session from session recovery metadata.
val recoveredSessions = sessionMetadata.flatMap(_.toOption).map(sessionRecovery)

info(s"Recovered ${recoveredSessions.length} $sessionType sessions." +
s" Next session id: $idCounter")

// Print recovery error.
val recoveryFailure = sessionMetadata.filter(_.isFailure).map(_.failed.get)
recoveryFailure.foreach(ex => error(ex.getMessage, ex.getCause))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,11 @@ class SessionStoreSpec extends FunSpec with LivyBaseUnitTestSuite {
val stateStore = mock[StateStore]
val sessionStore = new SessionStore(conf, stateStore)

when(stateStore.get[SessionManagerState](sessionManagerPath)).thenReturn(None)
when(stateStore.increment(sessionManagerPath)).thenReturn(0L)
sessionStore.getNextSessionId(sessionType) shouldBe 0

val sms = SessionManagerState(100)
when(stateStore.get[SessionManagerState](sessionManagerPath)).thenReturn(Some(sms))
sessionStore.getNextSessionId(sessionType) shouldBe sms.nextSessionId
when(stateStore.increment(sessionManagerPath)).thenReturn(100)
sessionStore.getNextSessionId(sessionType) shouldBe 100
}

it("should remove session") {
Expand Down