diff --git a/extensions/spark/kyuubi-spark-authz/benchmarks/RuleAuthorizationBenchmark-jdk17-results.txt b/extensions/spark/kyuubi-spark-authz/benchmarks/RuleAuthorizationBenchmark-jdk17-results.txt
index 9fb611aedbe..87f7c6e3dd8 100644
--- a/extensions/spark/kyuubi-spark-authz/benchmarks/RuleAuthorizationBenchmark-jdk17-results.txt
+++ b/extensions/spark/kyuubi-spark-authz/benchmarks/RuleAuthorizationBenchmark-jdk17-results.txt
@@ -2,5 +2,5 @@ Java HotSpot(TM) 64-Bit Server VM 17.0.12+8-LTS-286 on Mac OS X 14.6
 Apple M3
 Collecting files ranger access request:   Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------------------------------
-50000 files benchmark                              1332           1372          44         -0.0 -1331675209.0       1.0X
+50000 files benchmark                              1281           1310          33         -0.0 -1280563000.0       1.0X
 
diff --git a/extensions/spark/kyuubi-spark-authz/src/main/scala/org/apache/kyuubi/plugin/spark/authz/ranger/RuleAuthorization.scala b/extensions/spark/kyuubi-spark-authz/src/main/scala/org/apache/kyuubi/plugin/spark/authz/ranger/RuleAuthorization.scala
index 1c6b17db741..fbb830d6533 100644
--- a/extensions/spark/kyuubi-spark-authz/src/main/scala/org/apache/kyuubi/plugin/spark/authz/ranger/RuleAuthorization.scala
+++ b/extensions/spark/kyuubi-spark-authz/src/main/scala/org/apache/kyuubi/plugin/spark/authz/ranger/RuleAuthorization.scala
@@ -36,17 +36,18 @@ case class RuleAuthorization(spark: SparkSession) extends Authorization(spark) {
     val ugi = getAuthzUgi(spark.sparkContext)
     val (inputs, outputs, opType) = PrivilegesBuilder.build(plan, spark)
 
-    // Use a HashMap to deduplicate the same AccessResource and AccessType, it's values will be all
+    // Use a HashSet to deduplicate the same AccessResource and AccessType, the requests will be all
     // the non-duplicate requests.
-    val requests = new mutable.HashMap[(AccessResource, AccessType), (Int, AccessRequest)]()
+    val requests = new mutable.ArrayBuffer[AccessRequest]()
+    val requestsSet = new mutable.HashSet[(AccessResource, AccessType)]()
 
     def addAccessRequest(objects: Iterable[PrivilegeObject], isInput: Boolean): Unit = {
-      objects.zipWithIndex.foreach { case (obj, idx) =>
+      objects.foreach { obj =>
         val resource = AccessResource(obj, opType)
         val accessType = ranger.AccessType(obj, opType, isInput)
-        if (accessType != AccessType.NONE) {
-          requests += (resource, accessType) ->
-            (requests.size, AccessRequest(resource, ugi, opType, accessType))
+        if (accessType != AccessType.NONE && !requestsSet.contains((resource, accessType))) {
+          requests += AccessRequest(resource, ugi, opType, accessType)
+          requestsSet.add(resource, accessType)
         }
       }
     }
@@ -54,7 +55,7 @@ case class RuleAuthorization(spark: SparkSession) extends Authorization(spark) {
     addAccessRequest(inputs, isInput = true)
     addAccessRequest(outputs, isInput = false)
 
-    val requestArrays = requests.values.toSeq.sortBy(_._1).map(_._2).map { request =>
+    val requestArrays = requests.map { request =>
       val resource = request.getResource.asInstanceOf[AccessResource]
       resource.objectType match {
         case ObjectType.COLUMN if resource.getColumns.nonEmpty =>
@@ -71,7 +72,7 @@ case class RuleAuthorization(spark: SparkSession) extends Authorization(spark) {
           }
         case _ => Seq(request)
       }
-    }
+    }.toSeq
 
     if (authorizeInSingleCall) {
       verify(requestArrays.flatten, auditHandler)
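
The core of the Scala change is swapping the mutable.HashMap keyed by (AccessResource, AccessType), whose values carried an insertion index that later required a sortBy to restore request order, for a mutable.ArrayBuffer that already preserves insertion order plus a mutable.HashSet used only for constant-time duplicate checks. The following is a minimal, self-contained sketch of that order-preserving deduplication pattern; DedupBuffer, add, and values are illustrative names invented for this example and are not part of the Kyuubi code.

import scala.collection.mutable

// Illustrative sketch: keep only the first occurrence of each key while
// preserving insertion order, mirroring the ArrayBuffer + HashSet pattern
// introduced by the patch above.
final class DedupBuffer[K, V] {
  private val ordered = new mutable.ArrayBuffer[V]() // preserves insertion order
  private val seen = new mutable.HashSet[K]()        // O(1) membership checks

  def add(key: K, value: => V): Unit = {
    if (!seen.contains(key)) {
      seen += key
      ordered += value // value is only evaluated for previously unseen keys
    }
  }

  def values: Seq[V] = ordered.toSeq
}

object DedupBufferExample extends App {
  val buf = new DedupBuffer[(String, String), String]()
  buf.add(("db1.tbl1", "SELECT"), "request-1")
  buf.add(("db1.tbl1", "SELECT"), "request-2") // dropped: same key as request-1
  buf.add(("db1.tbl2", "UPDATE"), "request-3")
  println(buf.values) // the two distinct requests, in insertion order
}

Compared with the previous HashMap approach, this avoids the extra (index, request) tuples and the O(n log n) sortBy while still emitting the deduplicated requests in the order they were built.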