Skip to content

Commit

Permalink
[core] async gather (#860)
Browse files Browse the repository at this point in the history
New `gather` methods in `Fiber` and `Async` to collect successful
computations out of a sequence. It also allows specifying a max number
of expected results. See scaladocs for more information.
  • Loading branch information
fwbrasil authored Nov 26, 2024
1 parent e996c17 commit 70f91b4
Show file tree
Hide file tree
Showing 5 changed files with 705 additions and 8 deletions.
91 changes: 91 additions & 0 deletions kyo-core/shared/src/main/scala/kyo/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ object Async:
/** Races multiple computations and returns the result of the first to complete. When one computation completes, all other computations
* are interrupted.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* @param seq
* The sequence of computations to race
* @return
Expand Down Expand Up @@ -188,6 +190,95 @@ object Async:
): A < (Abort[E] & Async & Ctx) =
race[E, A, Ctx](first +: rest)

/** Concurrently executes effects and collects their successful results.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* Similar to the sequence-based gather, but accepts varargs input.
*
* @param first
* First effect to execute
* @param rest
* Rest of the effects to execute
* @return
* Successful results as a Chunk
*/
inline def gather[E, A: Flat, Ctx](
first: A < (Abort[E] & Async & Ctx),
rest: A < (Abort[E] & Async & Ctx)*
)(
using frame: Frame
): Chunk[A] < (Abort[E] & Async & Ctx) =
gather(first +: rest)

/** Concurrently executes effects and collects up to `max` successful results.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* Similar to the sequence-based gather with max, but accepts varargs input.
*
* @param max
* Maximum number of successful results to collect
* @param first
* First effect to execute
* @param rest
* Rest of the effects to execute
* @return
* Successful results as a Chunk (size <= max)
*/
inline def gather[E, A: Flat, Ctx](max: Int)(
first: A < (Abort[E] & Async & Ctx),
rest: A < (Abort[E] & Async & Ctx)*
)(
using frame: Frame
): Chunk[A] < (Abort[E] & Async & Ctx) =
gather(max)(first +: rest)

/** Concurrently executes effects and collects their successful results.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* Executes all effects concurrently and returns successful results in completion order. If all computations fail, the last encountered
* error is propagated. The operation completes when all effects have either succeeded or failed.
*
* @tparam Ctx
* Context requirements
* @param seq
* Sequence of effects to execute
* @return
* Successful results as a Chunk
*/
inline def gather[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
using frame: Frame
): Chunk[A] < (Abort[E] & Async & Ctx) =
_gather(seq.size)(seq)

/** Concurrently executes effects and collects up to `max` successful results.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* Similar to `gather`, but completes early once the specified number of `max` successful results is reached. If not enough successes
* occur and all remaining computations fail, the last encountered error is propagated.
*
* @param max
* Maximum number of successful results to collect
* @param seq
* Sequence of effects to execute
* @return
* Successful results as a Chunk (size <= max)
*/
inline def gather[E, A: Flat, Ctx](max: Int)(seq: Seq[A < (Abort[E] & Async & Ctx)])(
using frame: Frame
): Chunk[A] < (Abort[E] & Async & Ctx) =
_gather(max)(seq)

private def _gather[E, A: Flat, Ctx](max: Int)(seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): Chunk[A] < (Abort[E] & Async & Ctx) =
Fiber._gather(max)(seq.size, seq).map(_.get)

/** Runs multiple computations in parallel with unlimited parallelism and returns their results.
*
* Unlike [[parallel]], this method starts all computations immediately without any concurrency control. This can lead to resource
Expand Down
199 changes: 192 additions & 7 deletions kyo-core/shared/src/main/scala/kyo/Fiber.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package kyo

export Fiber.Promise
import java.util.concurrent.atomic.AtomicInteger
import java.util.Arrays
import kyo.Result.Panic
import kyo.Tag
import kyo.internal.FiberPlatformSpecific
Expand Down Expand Up @@ -253,6 +253,8 @@ object Fiber extends FiberPlatformSpecific:
/** Races multiple Fibers and returns a Fiber that completes with the result of the first to complete. When one Fiber completes, all
* other Fibers are interrupted.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* @param seq
* The sequence of Fibers to race
* @return
Expand All @@ -271,9 +273,9 @@ object Fiber extends FiberPlatformSpecific:
frame: Frame,
safepoint: Safepoint
): Fiber[E, A] < (IO & Ctx) =
IO {
IO.Unsafe {
class State extends IOPromise[E, A] with Function1[Result[E, A], Unit]:
val pending = new AtomicInteger(seq.size)
val pending = AtomicInt.Unsafe.init(seq.size)
def apply(result: Result[E, A]): Unit =
val last = pending.decrementAndGet() == 0
result.fold(e => if last then completeDiscard(e))(v => completeDiscard(Result.success(v)))
Expand All @@ -283,7 +285,7 @@ object Fiber extends FiberPlatformSpecific:
import state.*
boundary { (trace, context) =>
IO {
val interruptPanic = Result.Panic(Fiber.Interrupted(frame))
inline def interruptPanic = Result.Panic(Fiber.Interrupted(frame))
foreach(seq) { (_, v) =>
val fiber = IOTask(v, safepoint.copyTrace(trace), context)
state.onComplete(_ => discard(fiber.interrupt(interruptPanic)))
Expand All @@ -294,6 +296,189 @@ object Fiber extends FiberPlatformSpecific:
}
}

/** Concurrently executes effects and collects their successful results.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* Executes all effects concurrently and returns successful results in completion order. If all computations fail, the last encountered
* error is propagated. The operation completes when all effects have either succeeded or failed.
*
* @tparam Ctx
* Context requirements
* @param seq
* Sequence of effects to execute
* @return
* Fiber containing successful results as a Chunk
*/
inline def gather[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
frame: Frame,
safepoint: Safepoint
): Fiber[E, Chunk[A]] < (IO & Ctx) =
val total = seq.size
_gather(total)(total, seq)
end gather

/** Concurrently executes effects and collects up to `max` successful results.
*
* WARNING: Executes all computations in parallel without bounds. Use with caution on large sequences to avoid resource exhaustion.
*
* Similar to `gather`, but completes early once the specified number of `max` successful results is reached. If not enough successes
* occur and all remaining computations fail, the last encountered error is propagated.
*
* @param max
* Maximum number of successful results to collect
* @param seq
* Sequence of effects to execute
* @return
* Fiber containing successful results as a Chunk (size <= max)
*/
inline def gather[E, A: Flat, Ctx](max: Int)(seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
frame: Frame,
safepoint: Safepoint
): Fiber[E, Chunk[A]] < (IO & Ctx) =
_gather(max)(seq.size, seq)

private[kyo] def _gather[E, A: Flat, Ctx](max: Int)(total: Int, seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
boundary: Boundary[Ctx, IO & Abort[E]],
frame: Frame,
safepoint: Safepoint
): Fiber[E, Chunk[A]] < (IO & Ctx) =
if total == 0 || max <= 0 then Fiber.success(Chunk.empty)
else
IO.Unsafe {
class State extends IOPromise[E, Chunk[A]] with Function2[Int, Result[E, A], Unit]:
val results = new Array[AnyRef](max)

// Helper array to store original indices to maintain ordering
// Initialized to Int.MaxValue to handle partial results
val indices = new Array[Int](max)
Arrays.fill(indices, Int.MaxValue)

// Packed representation to avoid allocations and ensure atomicity
// - lower 32 bits => successful results count (ok)
// - higher 32 bits => failed results count (nok)
val packed = AtomicLong.Unsafe.init(0)

def apply(idx: Int, result: Result[E, A]): Unit =
@tailrec def loop(): Unit =
// Atomically update both ok/nok counters using CAS
val p = packed.get()
val ok = (p & 0xffffffffL) + (if result.isSuccess then 1 else 0)
val nok = (p >>> 32) + (if result.isFail then 1 else 0)
val np = (nok << 32) | ok
if !packed.cas(p, np) then
// CAS failed, retry the update
loop()
else
val okInt = ok.toInt
result match
case Result.Success(v) =>
if ok <= max then
// Store successful result and its original index for ordering
indices(okInt - 1) = idx
results(okInt - 1) = v.asInstanceOf[AnyRef]
case result: Result.Error[?] =>
if ok == 0 && ok + nok == total then
// If we have no successful results and all computations have completed,
// propagate the last encountered error since there's nothing else to return
completeDiscard(result)
end match
// Complete if we have max successes or all results are in
if ok > 0 && (ok == max || ok + nok == total) then
val size = okInt.min(max)

// Handle race condition
waitForResults(results, size)

// Restore original ordering but limit size since later
// results might still arrive and we want to avoid races
quickSort(indices, results, size)

// Limit final result to max successful results
completeDiscard(Result.success(Chunk.fromNoCopy(results).take(size)))
end if
end if
end loop
loop()
end apply
end State
val state = new State
import state.*
boundary { (trace, context) =>
IO {
inline def interruptPanic = Result.Panic(Fiber.Interrupted(frame))
foreach(seq) { (idx, v) =>
val fiber = IOTask(v, safepoint.copyTrace(trace), context)
state.onComplete(_ => discard(fiber.interrupt(interruptPanic)))
fiber.onComplete(state(idx, _))
}
state
}
}
}

/** Busy waits until all results are present in the `_gather` array.
*
* This is necessary because there's a race condition between:
* - One fiber successfully incrementing the counter via CAS
* - Another fiber seeing the updated counter and trying to complete the gather
* - The first fiber hasn't written its result to the array yet
*
* Without this wait, we might start processing results before all fibers have finished writing their values to the array.
*/
private def waitForResults(results: Array[AnyRef], size: Int): Unit =
@tailrec def loop(i: Int): Unit =
if i < size then
if results(i) == null then
loop(0)
else
loop(i + 1)
loop(0)
end waitForResults

/** Custom quicksort that sorts both indices and results arrays together.
*
* Since `_gather` collects results as they complete but needs to preserve input sequence order, we sort before returning. This
* specialized implementation avoids allocating tuples or wrapper objects by sorting both arrays in-place.
*/
private[kyo] def quickSort(indices: Array[Int], results: Array[AnyRef], items: Int): Unit =

def swap(i: Int, j: Int): Unit =
val tempIdx = indices(i)
indices(i) = indices(j)
indices(j) = tempIdx

val tempRes = results(i)
results(i) = results(j)
results(j) = tempRes
end swap

@tailrec def partitionLoop(low: Int, hi: Int, pivot: Int, i: Int, j: Int): Int =
if j >= hi then
swap(i, pivot)
i
else if indices(j) < indices(pivot) then
swap(i, j)
partitionLoop(low, hi, pivot, i + 1, j + 1)
else
partitionLoop(low, hi, pivot, i, j + 1)

def partition(low: Int, hi: Int): Int =
partitionLoop(low, hi, hi, low, low)

def loop(low: Int, hi: Int): Unit =
if low < hi then
val p = partition(low, hi)
loop(low, p - 1)
loop(p + 1, hi)

if items > 0 then
loop(0, items - 1)
end quickSort

/** Runs multiple computations in parallel with a specified level of parallelism and returns a Fiber that completes with their results.
*
* This method allows you to execute a sequence of computations with controlled parallelism by grouping them into batches. If any
Expand Down Expand Up @@ -356,15 +541,15 @@ object Fiber extends FiberPlatformSpecific:
seq.size match
case 0 => Fiber.success(Seq.empty)
case _ =>
IO {
IO.Unsafe {
class State extends IOPromise[E, Seq[A]] with ((Int, Result[E, A]) => Unit):
val results = (new Array[Any](seq.size)).asInstanceOf[Array[A]]
val pending = new AtomicInteger(seq.size)
val pending = AtomicInt.Unsafe.init(seq.size)
def apply(idx: Int, result: Result[E, A]): Unit =
result.fold(this.interruptDiscard) { value =>
results(idx) = value
if pending.decrementAndGet() == 0 then
this.completeDiscard(Result.success(ArraySeq.unsafeWrapArray(results)))
this.completeDiscard(Result.success(Chunk.fromNoCopy(results)))
}
end State
val state = new State
Expand Down
34 changes: 34 additions & 0 deletions kyo-core/shared/src/test/scala/kyo/AsyncTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -669,4 +669,38 @@ class AsyncTest extends Test:
}
}

"gather" - {
"sequence" - {
"delegates to Fiber.gather" in run {
for
result <- Async.gather(Seq(IO(1), IO(2), IO(3)))
yield assert(result == Chunk(1, 2, 3))
}

"with max limit delegates to Fiber.gather" in run {
for
result <- Async.gather(2)(Seq(IO(1), IO(2), IO(3)))
yield
assert(result.size == 2)
assert(result.forall(Seq(1, 2, 3).contains))
}
}

"varargs" - {
"delegates to sequence-based gather" in run {
for
result <- Async.gather(IO(1), IO(2), IO(3))
yield assert(result == Chunk(1, 2, 3))
}

"with max limit delegates to sequence-based gather" in run {
for
result <- Async.gather(2)(IO(1), IO(2), IO(3))
yield
assert(result.size == 2)
assert(result.forall(Seq(1, 2, 3).contains))
}
}
}

end AsyncTest
Loading

0 comments on commit 70f91b4

Please sign in to comment.