Fix partitioning issue (#3)
* Fix for batching

* Add extra safety around the batch partitioner
alexjbush authored Jun 21, 2019
1 parent fd7a868 commit 2a51bce
Showing 4 changed files with 52 additions and 22 deletions.
8 changes: 5 additions & 3 deletions src/main/scala/com/coxautodata/SparkDistCP.scala
@@ -6,9 +6,9 @@ import com.coxautodata.objects._
import com.coxautodata.utils.{CopyUtils, FileListUtils, PathUtils}
import org.apache.hadoop.fs._
import org.apache.log4j.Level
- import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SaveMode, SparkSession}
+ import org.apache.spark.{HashPartitioner, TaskContext}

/**
* Spark-based DistCp application.
@@ -188,9 +188,11 @@ object SparkDistCP extends Logging {
* and repartition the RDD so files in the same batches are in the same partitions
*/
private[coxautodata] def batchAndPartitionFiles(rdd: RDD[CopyDefinitionWithDependencies], maxFilesPerTask: Int, maxBytesPerTask: Long): RDD[((Int, Int), CopyDefinitionWithDependencies)] = {
- val batched = rdd.mapPartitionsWithIndex(generateBatchedFileKeys(maxFilesPerTask, maxBytesPerTask))
+ val partitioner = rdd.partitioner.getOrElse(new HashPartitioner(rdd.partitions.length))
+ val sorted = rdd.map(v => (v.source.uri.toString, v)).repartitionAndSortWithinPartitions(partitioner).map(_._2)
+ val batched = sorted.mapPartitionsWithIndex(generateBatchedFileKeys(maxFilesPerTask, maxBytesPerTask)) //sorted

- batched.partitionBy(CopyPartitioner(batched.map(_._1).reduceByKey(_ max _).collect()))
+ batched.partitionBy(CopyPartitioner(batched))
}

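Note on the batching step: generateBatchedFileKeys is not part of this diff, so the snippet below is only a rough sketch of how such a per-partition batching function could work, with invented names (BatchingSketch, batchKeys, sizeOf) and an assumed greedy policy that opens a new batch whenever adding the next file would exceed either threshold. It illustrates why the keys are (partitionIndex, batchId) pairs, and why the deterministic per-partition ordering added above via repartitionAndSortWithinPartitions makes those keys reproducible: under a greedy policy the key sequence depends only on the order in which files arrive.

object BatchingSketch {

  // Greedily assign the files of one upstream partition to (partitionIndex, batchId) keys.
  // A new batch starts whenever adding the next file would exceed either threshold.
  def batchKeys[T](maxFilesPerTask: Int, maxBytesPerTask: Long, sizeOf: T => Long)(
      partitionIndex: Int, files: Iterator[T]): Iterator[((Int, Int), T)] = {
    var batchId = 0
    var filesInBatch = 0
    var bytesInBatch = 0L
    files.map { file =>
      val size = sizeOf(file)
      if (filesInBatch >= maxFilesPerTask || (filesInBatch > 0 && bytesInBatch + size > maxBytesPerTask)) {
        batchId += 1
        filesInBatch = 0
        bytesInBatch = 0L
      }
      filesInBatch += 1
      bytesInBatch += size
      ((partitionIndex, batchId), file)
    }
  }

  def main(args: Array[String]): Unit = {
    // Same thresholds as the new test below: at most 3 files or 2000 bytes per batch.
    val files = Iterator(("/file1", 1L), ("/file2", 3000L), ("/file3", 1L))
    batchKeys[(String, Long)](3, 2000L, _._2)(0, files).foreach(println)
    // Prints keys (0,0), (0,1), (0,2): the 3000-byte file exceeds the byte budget,
    // so it and the file after it each open a new batch.
  }
}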
11 changes: 8 additions & 3 deletions src/main/scala/com/coxautodata/objects/CopyPartitioner.scala
@@ -1,14 +1,15 @@
package com.coxautodata.objects

import org.apache.spark.Partitioner
+ import org.apache.spark.rdd.RDD

/**
* Custom partitioner based on the indexes array containing (partitionid, number of batches within partition)
* Will handle missing partitions.
*/
case class CopyPartitioner(indexes: Array[(Int, Int)]) extends Partitioner {

- private val indexesAsMap = indexes.toMap
+ val indexesAsMap: Map[Int, Int] = indexes.toMap

override val numPartitions: Int = indexes.map(_._2).sum + indexes.length

@@ -19,9 +20,13 @@ case class CopyPartitioner(indexes: Array[(Int, Int)]) extends Partitioner {
override def getPartition(key: Any): Int = key match {
case (p: Int, i: Int) =>
if (!indexesAsMap.keySet.contains(p)) throw new RuntimeException(s"Key partition $p of key [($p, $i)] was not found in the indexes [${indexesAsMap.keySet.mkString(", ")}].")
- else if (i > indexesAsMap(p)) throw new RuntimeException(s"Key index $i of key [($p, $i)] is outside range [<=${indexesAsMap(p)}].")
- else partitionOffsets(p) + i
+ // Modulo the batch id to prevent exceptions if the batch id is out of the range
+ partitionOffsets(p) + (i % (indexesAsMap(p) + 1))
case u => throw new RuntimeException(s"Partitioned does not support key [$u]. Must be (Int, Int).")
}

}

+ object CopyPartitioner {
+   def apply(rdd: RDD[((Int, Int), CopyDefinitionWithDependencies)]): CopyPartitioner = new CopyPartitioner(rdd.map(_._1).reduceByKey(_ max _).collect())
+ }
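For readers unfamiliar with the partitioner, here is a small self-contained sketch of how the (partition, batch) keys map onto output partitions and what the new modulo buys. partitionOffsets is not shown in this diff, so its layout here is an assumption (one block of output partitions per upstream partition, in ascending partition id order), and the indexes values are invented for the example; they merely happen to reproduce the getPartition((2, 0)) == 5 behaviour the updated TestCopyPartitioner asserts.

object CopyPartitionerSketch {

  // Illustrative indexes: (upstream partition id, highest batch id seen in that partition).
  val indexes: Array[(Int, Int)] = Array((0, 3), (1, 0), (2, 0), (3, 2))
  val indexesAsMap: Map[Int, Int] = indexes.toMap

  // One output partition per batch: a partition whose highest batch id is n holds n + 1 batches.
  val numPartitions: Int = indexes.map(_._2).sum + indexes.length // 9 here

  // Assumed layout: each upstream partition owns a contiguous block of output partitions.
  val partitionOffsets: Map[Int, Int] = {
    val sorted = indexes.sortBy(_._1)
    val starts = sorted.map(_._2 + 1).scanLeft(0)(_ + _) // 0, 4, 5, 6, 9
    sorted.map(_._1).zip(starts).toMap                   // Map(0 -> 0, 1 -> 4, 2 -> 5, 3 -> 6)
  }

  def getPartition(key: Any): Int = key match {
    case (p: Int, i: Int) =>
      if (!indexesAsMap.contains(p))
        throw new RuntimeException(s"Key partition $p was not found in the indexes.")
      // The modulo from this commit: an out-of-range batch id wraps back into the
      // partition's own block instead of throwing.
      partitionOffsets(p) + (i % (indexesAsMap(p) + 1))
    case u => throw new RuntimeException(s"Unsupported key [$u]. Must be (Int, Int).")
  }

  def main(args: Array[String]): Unit = {
    println(numPartitions)        // 9
    println(getPartition((2, 0))) // 5
    println(getPartition((2, 1))) // 5: 1 % (0 + 1) == 0, so it folds back onto batch 0
  }
}

Before this commit a batch id above the recorded maximum threw a RuntimeException; after it, the id wraps into the partition's own block, so a stray key can no longer abort the copy.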
49 changes: 36 additions & 13 deletions src/test/scala/com/coxautodata/TestSparkDistCP.scala
@@ -1,7 +1,7 @@
package com.coxautodata

import com.coxautodata.SparkDistCP._
- import com.coxautodata.objects.{CopyDefinitionWithDependencies, Directory, File, SerializableFileStatus}
+ import com.coxautodata.objects.{CopyDefinitionWithDependencies, CopyPartitioner, Directory, File, SerializableFileStatus}
import com.coxautodata.utils.FileListUtils.listFiles
import com.coxautodata.utils.FileListing
import org.apache.hadoop.fs.Path
@@ -67,12 +67,12 @@ class TestSparkDistCP extends TestSpec {
case ((pp, i), d) => ((p, pp, i), d.source.getPath.toString)
}
} should contain theSameElementsAs Seq(
- ((0, 0, 0), "/one"),
- ((2, 1, 0), "/two"),
- ((0, 0, 0), "/three"),
- ((2, 1, 0), "/file1"),
- ((1, 0, 1), "/file2"),
- ((3, 1, 1), "/file3")
+ ((0, 0, 0), "/file1"),
+ ((0, 0, 0), "/file3"),
+ ((1, 1, 0), "/file2"),
+ ((1, 1, 0), "/one"),
+ ((2, 1, 1), "/three"),
+ ((2, 1, 1), "/two")
)


@@ -106,18 +106,41 @@ class TestSparkDistCP extends TestSpec {
case ((pp, i), d) => ((p, pp, i), d.source.getPath.toString)
}
} should contain theSameElementsAs Seq(
- ((0, 0, 0), "/one"),
- ((0, 0, 0), "/two"),
- ((1, 0, 1), "/three"),
- ((1, 0, 1), "/file1"),
- ((2, 0, 2), "/file2"),
- ((3, 0, 3), "/file3")
+ ((0, 0, 0), "/file1"),
+ ((0, 0, 0), "/file2"),
+ ((1, 0, 1), "/file3"),
+ ((1, 0, 1), "/one"),
+ ((2, 0, 2), "/three"),
+ ((2, 0, 2), "/two")
)


spark.stop()
}

+ it("produce predictable batching") {
+   val spark = new SparkContext(new SparkConf().setAppName("test").setMaster("local[1]"))
+
+   val in = List(
+     CopyDefinitionWithDependencies(SerializableFileStatus(new Path("/1").toUri, 1, File), new Path("/dest/file1").toUri, Seq.empty),
+     CopyDefinitionWithDependencies(SerializableFileStatus(new Path("/3").toUri, 3000, File), new Path("/dest/file3").toUri, Seq.empty),
+     CopyDefinitionWithDependencies(SerializableFileStatus(new Path("/2").toUri, 1, File), new Path("/dest/file2").toUri, Seq.empty)
+   )
+
+   val inRDD = spark
+     .parallelize(in)
+     .repartition(1)
+
+   val unsorted = batchAndPartitionFiles(inRDD, 3, 2000).partitioner.get.asInstanceOf[CopyPartitioner]
+
+   val sorted = batchAndPartitionFiles(inRDD.sortBy(_.source.uri.toString), 3, 2000).partitioner.get.asInstanceOf[CopyPartitioner]
+
+   unsorted.indexesAsMap should be (sorted.indexesAsMap)
+
+   spark.stop()
+ }

}

describe("run") {
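The new test above relies on repartitionAndSortWithinPartitions producing the same per-partition contents and ordering regardless of how the input RDD happens to be ordered. Below is a minimal standalone illustration of that property, with invented paths and a plain HashPartitioner; nothing in it is SparkDistCP code.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object SortStabilitySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sort-stability").setMaster("local[2]"))

    // Invented paths; only their ordering matters for the illustration.
    val paths = Seq("/a/file3", "/a/file1", "/a/file2")
    val partitioner = new HashPartitioner(2)

    // Key each path by itself and shuffle with repartitionAndSortWithinPartitions:
    // every key lands in the same partition, in the same order, for any input ordering.
    def layout(input: Seq[String]): Seq[(Int, List[String])] =
      sc.parallelize(input)
        .map(p => (p, ()))
        .repartitionAndSortWithinPartitions(partitioner)
        .mapPartitionsWithIndex((i, it) => Iterator((i, it.map(_._1).toList)))
        .collect()
        .toSeq

    println(layout(paths) == layout(paths.reverse)) // true
    sc.stop()
  }
}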
6 changes: 3 additions & 3 deletions src/test/scala/com/coxautodata/objects/TestCopyPartitioner.scala
@@ -30,9 +30,9 @@ class TestCopyPartitioner extends FunSpec with Matchers {
partitioner.getPartition((4, 1))
}.getMessage should be("Key partition 4 of key [(4, 1)] was not found in the indexes [0, 1, 2, 3].")

- intercept[RuntimeException] {
-   partitioner.getPartition((2, 1))
- }.getMessage should be("Key index 1 of key [(2, 1)] is outside range [<=0].")
+ partitioner.getPartition((2, 0)) should be(5)
+
+ partitioner.getPartition((2, 1)) should be(5)

}

