Hacking a Scala Async/Await syntax with Cats Effect and Continuations
NOTE: This document is WIP.
The story:
Virtual Threads were introduced into the JVM with Java 21, and with them, we got these internal APIs:
import jdk.internal.vm.Continuation;
import jdk.internal.vm.ContinuationScope;
They are soooo internal that you need this flag:
--add-exports java.base/jdk.internal.vm=ALL-UNNAMED
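To get a feel for them, here is a minimal sketch of how the raw API behaves (assuming JDK 21+ and the flag above; the demo names are mine):

import jdk.internal.vm.{Continuation, ContinuationScope}

// run() executes the body until it hits Continuation.yield; a later run()
// resumes right after that yield.
@main def continuationDemo(): Unit = {
  val scope = new ContinuationScope("demo")
  val cont = new Continuation(scope, () => {
    println("step 1")
    Continuation.`yield`(scope) // suspend here, control returns to run()
    println("step 2")
  })
  cont.run()           // prints "step 1", returns at the yield
  cont.run()           // resumes, prints "step 2"
  println(cont.isDone) // true
}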
And I did use them. Quite a few months ago, I built koka-shokki (a name that is supposed to mean something about effects or so). It was an experiment that used and abused these APIs to provide an effect system that looks imperative rather than classically monadic. This is what it looks like:
def divide(
    x: With[Optional, Int],
    y: With[Optional, Int]
): Int <~ Optional :| Lefted[String] =
  WithCapability[Optional :| Lefted[String], Int] {
    val result: With[Lefted[String], Int] =
      if (y.? == 0) asLeft("division by zero") else ok(x.? / y.?)
    result.?
  }
But there is a catch! I never actually implemented an IO effect. And for some reason, my curiosity asked me a few days ago to consider something related to, but not quite the same as, koka-shokki: what if this could let Cats Effect's IO run with an async/await syntax?
So this is about that! But it is an exploration more than anything else, since I'm writing this at the same time as the proof of concept. Just for fun :).
The API of continuations (from the koka-shokki perspective):
This is what basically introduces the await syntax in koka-shokki:
def apply[C[_] <: Capability[?], A](f: Scope[C] ?=> A): With[C, A]
Since I don't quite remember what I wrote, let's analyze it a bit.
So we have a function with two generic parameters, C and A, where C is the "effect" being executed, which matters explicitly for koka-shokki (here we are constraining it to Cats IO only), and A is just the normal return type of the burrito. Then there is f, which is kind of special. The function f uses Scala 3 context functions, so this will only work in Scala 3 (in Scala 2.13 this would be more verbose). The point of using this kind of function, if I remember correctly, is that it should feel more natural: you shouldn't have to know that the Scope is there implicitly when calling apply:
WithCapability[Optional :| Lefted[String], Int] {
  // do something!
  awaitMe.?
}
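A tiny, self-contained illustration of the language feature at play (nothing koka-shokki specific, the names here are just for the demo):

case class Ctx(name: String)

def withCtxExplicit[A](f: Ctx => A): A  = f(Ctx("demo"))
def withCtxImplicit[A](f: Ctx ?=> A): A = f(using Ctx("demo"))

def greet(using ctx: Ctx): String = s"hello, ${ctx.name}"

val viaPlainFn   = withCtxExplicit { ctx => greet(using ctx) } // the caller threads ctx by hand
val viaContextFn = withCtxImplicit { greet }                   // Ctx is just "there", like Scope above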
So, why the scope then?
The scope is used when calling the await operator .?, since it has to do the yielding into the continuation. So .? requires a Scope as an implicit parameter. This improves the UX, since we don't have to worry about explicitly passing a parameter around.
How to Continuation and Scope
Internally, koka-shokki uses two scopes. One is ContinuationScope, from the package jdk.internal.vm. The other one is Scope, which wraps the yield to the relevant ContinuationScope.
So we have this:
var r: With[C, A] = null.asInstanceOf[With[C, A]]
val scope: ContinuationScope = new ContinuationScope("WithCapability")
val scopeInstance: Scope[C] = new Scope {
  override def ko[B, K[_] <: Capability[?]](
      x: With[K, B]
  )(implicit ev: K[B] <:< C[B]): B = {
    r = x.asInstanceOf[
      With[C, A]
    ] // The cast is safe because of the proof from the implicit parameter
    Continuation.`yield`(scope)
    throw new RuntimeException("unreachable")
  }
}
Here r is the actual return value, which is set when the await operator .? is executed (as ko is the relevant function for .?). Basically, when we encounter a .?, we set r to the current Effect being executed with that .? when the Effect is "throwing", and then we yield into scope. See:
def ?(using scope: Scope[C]): A = value match {
  case Left(channel) => scope.ko(self)
  case Right(value)  => value
}
So we either continue execution or we yield with the effect. Then, returning to the previous function, it continues:
val cont: Continuation = new Continuation(
  scope,
  () => {
    val result = f(using scopeInstance)
    r = new With[C, A] {
      val value = Right(result)
    }
    Continuation.`yield`(scope)
    ()
  }
)
cont.run
which basically sets the return value and yields, so we continue execution from cont.run.
How are we going to do it then?
Let's try first defining the type of the function that introduces the await syntax:
extension [A](effect: IO[A]) def await: A
def async[A](f: Scope ?=> IO[A]): IO[A]
If we think back to the koka-shokki implementation, I used a variable to store the return value, which was set when awaiting. That was easy to do since we can pattern match and extract the "non-error" value from it. But with IOs this could be quite tricky.
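One more thing before the code: the implementation leans on a Scope that carries everything an await call needs. Its definition isn't shown anywhere in this post, so here is a sketch reconstructed from exactly the members the code below relies on (the real trait may differ slightly):

import cats.effect.{Deferred, IO, Ref}
import cats.effect.std.{Dispatcher, Semaphore}
import jdk.internal.vm.ContinuationScope

trait Scope {
  val cscope: ContinuationScope                       // the JDK scope we yield into
  val dispatcher: Dispatcher[IO]                      // runs IOs from plain, non-IO code
  val sem: Semaphore[IO]                              // parks the loop while an await is pending
  val deferred: Deferred[IO, Either[Throwable, Any]]  // the final result (or error)
  val done: Ref[IO, Boolean]                          // whether deferred has been completed
}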
Here goes the first attempt:
def async[A](f: Scope ?=> IO[A]): IO[A] = for {
  dispatch <- Dispatcher.parallel[IO].allocated
  m <- Semaphore[IO](1)
  returnValue <- Deferred[IO, Either[Throwable, A]]
  doneV <- Ref.of[IO, Boolean](false)
  scope: ContinuationScope <- IO.delay(ContinuationScope(s"async-io"))
  scopeInstance: Scope <- IO.delay(new Scope {
    override val cscope: ContinuationScope = scope
    override val dispatcher: Dispatcher[IO] = dispatch._1
    override val sem: Semaphore[IO] = m
    override val deferred: Deferred[IO, Either[Throwable, Any]] =
      returnValue.asInstanceOf
    override val done: Ref[IO, Boolean] = doneV
  })
  continuation <- IO.delay(
    new Continuation(
      scope,
      () => {
        val result = Try {
          f(using scopeInstance)
        }
        dispatch._1.unsafeRunSync(
          IO.fromTry(result)
            .flatten
            .attempt
            .flatMap(v => returnValue.complete(v)) *> doneV.set(true)
        )
        Continuation.`yield`(scope)
        ()
      }
    )
  )
  _ <- m.acquire
  _ <- (m.release *> IO.delay(continuation.run()) *> m.acquire)
    .whileM_(doneV.get.map(!_))
  result <- returnValue.get.flatMap(IO.fromEither).guarantee(dispatch._2)
} yield result
A lot is happening here. Let's start with this section:
// ...
for {
  dispatch <- Dispatcher.parallel[IO].allocated
  m <- Semaphore[IO](1)
  returnValue <- Deferred[IO, Either[Throwable, A]]
  doneV <- Ref.of[IO, Boolean](false)
  // ...
} yield ()
Here dispatch is what does the trick: it provides the evaluation context, so we can get a result in direct style (we'll see how later). m is used to block the continuation from returning into the place where it yielded. I'm using a semaphore from Cats Effect since I'm assuming it will yield (well, actually "cede") execution to another fiber while the effect executed by dispatch isn't completely evaluated. And you may think (or maybe not, who knows) that a mutex should be enough, but I did it this way because in the "continuation loop" I'm inverting the order of acquiring and releasing the semaphore. So it's for convenience: the standard Mutex would be more annoying to work with.
returnValue is mostly obvious; it's the actual promise that is going to be fulfilled either at the successful end of the async block or when an exception is encountered during await. doneV is just a helper that keeps track of the current state of returnValue.
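To see why the Dispatcher is the key piece, here is a tiny, isolated sketch (runDirectly is a made-up name; dispatch._1 above plays the role of d):

import cats.effect.IO
import cats.effect.std.Dispatcher

// A Dispatcher lets plain, non-IO code (like the body of a Continuation)
// evaluate an IO to a value, blocking only the calling thread.
def runDirectly(d: Dispatcher[IO]): Int =
  d.unsafeRunSync(IO.pure(20).map(_ + 22)) // 42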
// ...
for {
  scopeInstance: Scope <- IO.delay(new Scope {
    override val cscope: ContinuationScope = scope
    override val dispatcher: Dispatcher[IO] = dispatch._1
    override val sem: Semaphore[IO] = m
    override val deferred: Deferred[IO, Either[Throwable, Any]] =
      returnValue.asInstanceOf
    override val done: Ref[IO, Boolean] = doneV
  })
  // ...
} yield ()
scopeInstance is just going to be the implicit context passed by the context function into each await call. And then the continuation:
// ...
for {
  continuation <- IO.delay(
    new Continuation(
      scope,
      () => {
        val result = Try {
          f(using scopeInstance)
        }
        dispatch._1.unsafeRunSync(
          IO.fromTry(result)
            .flatten
            .attempt
            .flatMap(v => returnValue.complete(v)) *> doneV.set(true)
        )
        Continuation.`yield`(scope)
        ()
      }
    )
  )
  // ...
} yield ()
It simply wraps the execution of the async block to catch any impure exception, then lifts that wrapped value into an IO, and finally completes the deferred IO / promise sequentially. Then comes the execution loop:
// ...
for {
  _ <- m.acquire
  _ <- (m.release *> IO.delay(continuation.run()) *> m.acquire)
    .whileM_(doneV.get.map(!_))
  result <- returnValue.get.flatMap(IO.fromEither).guarantee(dispatch._2)
  // ...
} yield ()
First we acquire the lock, and immediately afterwards it is released. This, even if quite redundant, is necessary to maintain the execution order inside the loop, which is as follows: we release the current lock and continue with the execution until we encounter a yield; then we try to acquire the lock, which will "lock" the current fiber (and allow execution of other fibers) until the awaited value is fully resolved; then we are allowed to immediately free the mutex and continue execution of the function. This happens in a loop until deferred has been set with a result.
Then the return value is just lifted into an IO and we free the dispatcher.
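Spelled out as a rough timeline (my mental model for a block with a single await):

// m.acquire            -- the loop grabs the permit up front
// m.release            -- hand it over so the await inside the block can take it
// continuation.run()   -- user code runs; on an await it acquires m, dispatches
//                         the effect and yields back here before it finishes
// m.acquire            -- parks this fiber until await's guarantee releases m,
//                         i.e. the awaited IO has produced its result
// doneV is still false -- loop again: release, run() resumes right after the yield...
// doneV is true        -- the block finished, returnValue is set, exit the loop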
That's basically all for async; now for the await part:
extension [A](effect: IO[A])
  def await(implicit scope: Scope): A = {
    var resumed: A = null.asInstanceOf
    scope.dispatcher.unsafeRunSync(scope.sem.acquire)
    scope.dispatcher.unsafeRunAndForget(
      effect
        .map(result => {
          resumed = result
        })
        .onError(throwable =>
          scope.deferred.complete(Left(throwable)) *> scope.done.set(true)
        )
        .guarantee(scope.sem.release)
    )
    Continuation.`yield`(scope.cscope)
    resumed
  }
We first initialize an impure var resumed to null. Then we acquire the lock from the semaphore, which prevents the execution loop from returning into execution until resumed actually holds the expected result from the awaited effect (this lock is the same lock discussed in the part about the execution loop). Then we fire, in an asynchronous way, the execution of the awaited effect, such that either we set resumed or we complete the deferred promise with an error (and, by extension, we interrupt the execution loop). Either way, we free the lock after one of those two things is done. Then we yield execution; well, not exactly "then", since it could happen at any moment during the execution of the effect. This yield is the reason for the locking mechanism in the main loop, and a hacky way of locking just the fiber and not the actual thread.
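As a small aside on "locking just the fiber": the semaphore blocks semantically, so the carrier thread stays free to run other fibers, which is exactly what we want here. A sketch, with sem standing for the semaphore from the Scope:

import cats.effect.IO
import cats.effect.std.Semaphore

def fiberVsThreadBlocking(sem: Semaphore[IO]): (IO[Unit], IO[Unit]) = {
  val fiberOnly   = sem.acquire              // parks the fiber; the thread keeps working
  val wholeThread = IO(Thread.sleep(1000L))  // pins the underlying thread (what we avoid)
  (fiberOnly, wholeThread)
}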
And that's pretty much it.
Does it work?
Surprisingly, yes, it does indeed work. Here is the MWE:
def functionAsyncTest(someNumber: IO[Int]): IO[Int] = async {
  val next = someNumber.map(_ + 1).await
  val ref = Ref.of[IO, Int](next).await
  val counter = Ref.of[IO, Int](100).await
  Console[IO].println(s"Current value is ${next}. Starting loop.").await
  async {
    val thisValue = ref.get.await
    Console[IO].println(s"Current value is ${next}. Ref ${thisValue}.").await
    ref.set(thisValue + 1).await
    counter.update(_ - 1).await
    Console[IO].println(s"Current value is ${next}. Ref has been set.").await
    IO.unit
  }.whileM_(counter.get.map(_ > 0)).await
  Console[IO].println(s"Current value is ${next}. End loop.").await
  ref.get
}
and even using the non-monadic loops from the language:
object Main extends IOApp.Simple {
  override def run: IO[Unit] = async {
    var fibers = Seq.empty[FiberIO[Unit]]
    for (x <- 0 to 999) {
      val fiber = async {
        val returnValue = functionAsyncTest(IO.pure(x)).await
        IO { assert(returnValue == x + 101) }
      }.start.await
      fibers = fibers.appended(fiber)
    }
    fibers.map(_.join.await)
    IO.unit
  }
}
It works, almost perfectly that is. But...
Here come the _problems_
I did manage to get a deadlock, and I think I know why. It occurs to me that the reason this is happening is the same reason I already suspected could be a problem: context switching. That's why I believed it wouldn't take just one attempt.
Normal context switching is fine. The problem (maybe) arises when the runtime decides to switch the execution of a fiber to another worker thread.
Internally, Continuation.yield calls the following code:
public static boolean yield(ContinuationScope scope) {
    Continuation cont = JLA.getContinuation(currentCarrierThread());
    // ...
}
It then checks whether cont has the scope scope, and tries again with the parent continuation of cont if not. But that is mostly unimportant. What we care about is currentCarrierThread. This is a problem whenever the CE runtime switches execution of the fiber to another thread, as it would not be able to also carry the continuation over to its new thread.
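To make the hazard concrete, here is a tiny sketch showing that a fiber is in no way pinned to its carrier thread (not a reproduction of the deadlock, just the property that breaks the lookup above):

import cats.effect.IO

val whereAmI: IO[String] = IO(Thread.currentThread().getName)

val migrationDemo: IO[Unit] =
  for {
    before <- whereAmI
    _      <- IO.cede // give the scheduler a chance to move the fiber
    after  <- whereAmI
    _      <- IO.println(s"started on $before, resumed on $after") // these may differ
  } yield ()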
Misc problems
While they should not cause semantic problems during execution, there is a fair amount of data structures and allocations in this first implementation. I haven't benchmarked anything yet, but I believe there is going to be a performance hit.
What's the plan?
- First, I want to address the issue regarding context switching, maybe by passing the continuation around using IOLocal rather than relying on the JDK implementation's use of currentCarrierThread (a very rough sketch of that idea follows below).
- Second, I want to improve the actual usage of the locking mechanism, mostly to reduce allocations and make everything a little bit better performance-wise, even if more ugly and impure. I mean, as long as it is a well-scoped usage of impure operations, it should be fine; even Ref uses impure things internally (AtomicReference)!
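A very rough sketch of that IOLocal idea, nothing implemented yet:

import cats.effect.{IO, IOLocal}
import jdk.internal.vm.Continuation

// Keep the current Continuation in an IOLocal so it travels with the fiber
// instead of being looked up through the carrier thread.
def withLocalContinuation: IO[Unit] =
  for {
    local <- IOLocal(Option.empty[Continuation]) // created when async starts
    // the async machinery would do `local.set(Some(continuation))` once the
    // continuation exists, and await would read it back with `local.get`
    _     <- local.get
  } yield ()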
The second try
WIP.