Skip to content

Commit

Permalink
LIVY-998: Support connecting to existing sessions using session name …
Browse files Browse the repository at this point in the history
…via Thrift Server
  • Loading branch information
Asif Khatri committed Apr 22, 2024
1 parent f615f27 commit 6303398
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
}

/**
* If the user specified an existing sessionId to use, the corresponding session is returned,
* otherwise a new session is created and returned.
* If the user specified an existing sessionId or session name to use, the corresponding session
* is returned, otherwise a new session is created and returned.
*/
private def getOrCreateLivySession(
def getOrCreateLivySession(
sessionHandle: SessionHandle,
sessionId: Option[Int],
sessionName: Option[String],
username: String,
createLivySession: () => InteractiveSession): InteractiveSession = {
sessionId match {
Expand All @@ -183,7 +184,27 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
}
}
case None =>
createLivySession()
sessionName match {
case Some(name) =>
server.livySessionManager.get(name) match {
case None =>
createLivySession()
case Some(session) if !server.isAllowedToUse(username, session) =>
warn(s"$username has no modify permissions to InteractiveSession $name.")
throw new IllegalAccessException(
s"$username is not allowed to use InteractiveSession $name.")
case Some(session) =>
if (session.state.isActive) {
info(s"Reusing Session $name for $sessionHandle.")
session
} else {
warn(s"InteractiveSession $name is not active anymore.")
throw new IllegalArgumentException(s"Session $name is not active anymore.")
}
}
case None =>
createLivySession()
}
}
}

Expand Down Expand Up @@ -248,7 +269,8 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
override def run(): InteractiveSession = {
livySession =
getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession)
getOrCreateLivySession(sessionHandle, sessionId, createInteractiveRequest.name,
username, createLivySession)
synchronized {
managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers =>
managedLivySessionActiveUsers(livySession.id) = numUsers + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import scala.concurrent.duration.Duration
import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
import org.junit.Assert._
import org.junit.Test
import org.mockito.Mockito.mock
import org.mockito.Mockito.{mock, when}

import org.apache.livy.LivyConf
import org.apache.livy.server.AccessManager
import org.apache.livy.server.interactive.InteractiveSession
import org.apache.livy.server.recovery.{SessionStore, StateStore}
import org.apache.livy.sessions.InteractiveSessionManager
import org.apache.livy.server.recovery.SessionStore
import org.apache.livy.server.AccessManager
import org.apache.livy.sessions.{InteractiveSessionManager, SessionState}
import org.apache.livy.utils.Clock.sleep

object ConnectionLimitType extends Enumeration {
Expand All @@ -46,7 +46,7 @@ class TestLivyThriftSessionManager {
import ConnectionLimitType._

private def createThriftSessionManager(
limitTypes: ConnectionLimitType*): LivyThriftSessionManager = {
limitTypes: ConnectionLimitType*): (LivyThriftSessionManager, LivyThriftServer) = {
val conf = new LivyConf()
conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
val limit = 3
Expand All @@ -62,21 +62,23 @@ class TestLivyThriftSessionManager {
}

private def createThriftSessionManager(
maxSessionWait: Option[String]): LivyThriftSessionManager = {
maxSessionWait: Option[String]): (LivyThriftSessionManager, LivyThriftServer) = {
val conf = new LivyConf()
conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
maxSessionWait.foreach(conf.set(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT, _))
this.createThriftSessionManager(conf)
}

private def createThriftSessionManager(conf: LivyConf): LivyThriftSessionManager = {
private def createThriftSessionManager(conf: LivyConf): (LivyThriftSessionManager,
LivyThriftServer) = {
val server = new LivyThriftServer(
conf,
mock(classOf[InteractiveSessionManager]),
mock(classOf[SessionStore]),
mock(classOf[AccessManager])
)
new LivyThriftSessionManager(server, conf)
val sessionManager = new LivyThriftSessionManager(server, conf)
(sessionManager, server)
}

private def testLimit(
Expand All @@ -99,7 +101,7 @@ class TestLivyThriftSessionManager {

@Test
def testLimitConnectionsByUser(): Unit = {
val thriftSessionMgr = createThriftSessionManager(User)
val (thriftSessionMgr, _) = createThriftSessionManager(User)
val user = "alice"
val forwardedAddresses = new java.util.ArrayList[String]()
thriftSessionMgr.incrementConnections(user, "10.20.30.40", forwardedAddresses)
Expand All @@ -111,7 +113,7 @@ class TestLivyThriftSessionManager {

@Test
def testLimitConnectionsByIpAddress(): Unit = {
val thriftSessionMgr = createThriftSessionManager(IpAddress)
val (thriftSessionMgr, _) = createThriftSessionManager(IpAddress)
val ipAddress = "10.20.30.40"
val forwardedAddresses = new java.util.ArrayList[String]()
thriftSessionMgr.incrementConnections("alice", ipAddress, forwardedAddresses)
Expand All @@ -123,7 +125,7 @@ class TestLivyThriftSessionManager {

@Test
def testLimitConnectionsByUserAndIpAddress(): Unit = {
val thriftSessionMgr = createThriftSessionManager(UserIpAddress)
val (thriftSessionMgr, _) = createThriftSessionManager(UserIpAddress)
val user = "alice"
val ipAddress = "10.20.30.40"
val userAndAddress = user + ":" + ipAddress
Expand All @@ -149,7 +151,7 @@ class TestLivyThriftSessionManager {

@Test
def testMultipleConnectionLimits(): Unit = {
val thriftSessionMgr = createThriftSessionManager(User, IpAddress)
val (thriftSessionMgr, _) = createThriftSessionManager(User, IpAddress)
val user = "alice"
val ipAddress = "10.20.30.40"
val forwardedAddresses = new java.util.ArrayList[String]()
Expand All @@ -166,7 +168,7 @@ class TestLivyThriftSessionManager {

@Test(expected = classOf[TimeoutException])
def testGetLivySessionWaitForTimeout(): Unit = {
val thriftSessionMgr = createThriftSessionManager(Some("10ms"))
val (thriftSessionMgr, _) = createThriftSessionManager(Some("10ms"))
val sessionHandle = mock(classOf[SessionHandle])
val future = Future[InteractiveSession] {
sleep(100)
Expand All @@ -178,7 +180,7 @@ class TestLivyThriftSessionManager {

@Test(expected = classOf[TimeoutException])
def testGetLivySessionWithTimeoutException(): Unit = {
val thriftSessionMgr = createThriftSessionManager(None)
val (thriftSessionMgr, _) = createThriftSessionManager(None)
val sessionHandle = mock(classOf[SessionHandle])
val future = Future[InteractiveSession] {
throw new TimeoutException("Actively throw TimeoutException in Future.")
Expand All @@ -187,4 +189,72 @@ class TestLivyThriftSessionManager {
Await.ready(future, Duration(30, TimeUnit.SECONDS))
thriftSessionMgr.getLivySession(sessionHandle)
}


@Test
def testGetOrCreateLivySessionDifferentSessions(): Unit = {
val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
val sessionHandle = mock(classOf[SessionHandle])
val sessionUser = "testUser"
val sessionId1 = Some(1)
val session1 = mock(classOf[InteractiveSession])
when(session1.state).thenReturn(SessionState.Running)
when(session1.owner).thenReturn(sessionUser)
when(server.livySessionManager.get(1)).thenReturn(Some(session1))
val sessionId2 = Some(2)
val session2 = mock(classOf[InteractiveSession])
when(session2.state).thenReturn(SessionState.Running)
when(session2.owner).thenReturn(sessionUser)
when(server.livySessionManager.get(2)).thenReturn(Some(session2))
val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId1, None,
sessionUser, () => null)
val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId2, None,
sessionUser, () => null)

assertNotNull(result1)
assertNotNull(result2)
assertNotEquals(result1, result2)
}

@Test
def testGetOrCreateLivySessionExistingSessionByID(): Unit = {
val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
val sessionHandle = mock(classOf[SessionHandle])
val sessionUser = "testUser"
val sessionId = Some(1)
val session1 = mock(classOf[InteractiveSession])
when(session1.state).thenReturn(SessionState.Running)
when(session1.owner).thenReturn(sessionUser)
when(server.livySessionManager.get(1)).thenReturn(Some(session1))
val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None,
sessionUser, () => null)
val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None,
sessionUser, () => null)

assertNotNull(result1)
assertNotNull(result2)
assertEquals(result1, result2)
}


@Test
def testGetOrCreateLivySessionExistingSessionByName(): Unit = {
val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
val sessionHandle = mock(classOf[SessionHandle])
val sessionUser = "testUser"
val sessionName = Some("sessionName")
val session1 = mock(classOf[InteractiveSession])
when(session1.state).thenReturn(SessionState.Running)
when(session1.owner).thenReturn(sessionUser)
when(server.livySessionManager.get("sessionName")).thenReturn(Some(session1))
val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
sessionUser, () => null)
val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
sessionUser, () => null)

assertNotNull(result1)
assertNotNull(result2)
assertEquals(result1, result2)
}

}

0 comments on commit 6303398

Please sign in to comment.