Hacking a Scala Async / Await syntax with Cats Effect and Continuations

NOTE: This document is WIP.

The story:

Virtual Threads were introduced into the JVM with Java 21, and with them we got these internal APIs:

import jdk.internal.vm.Continuation;
import jdk.internal.vm.ContinuationScope;

They are soooo internal that you need this flag:

--add-exports java.base/jdk.internal.vm=ALL-UNNAMED
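
As a rough sketch, and assuming the project is built with sbt (the post doesn't say which build tool is used), the flag could be passed to a forked JVM like this. sbt only applies javaOptions when forking is enabled:

```scala
// build.sbt (sketch, assuming sbt)
// javaOptions are only honoured by a forked JVM, so fork the run task.
run / fork := true
run / javaOptions += "--add-exports=java.base/jdk.internal.vm=ALL-UNNAMED"
```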

And I did use them. Quite a few months ago, I built koka-shokki (a name that is supposed to mean something related to effects). It was an experiment that used and abused these APIs to provide an effect system that looks imperative rather than classically monadic. This is what it looks like:

def divide(
      x: With[Optional, Int],
      y: With[Optional, Int]
  ): Int <~ Optional :| Lefted[String] =
    WithCapability[Optional :| Lefted[String], Int] {
      val result: With[Lefted[String], Int] =
        if (y.? == 0) asLeft("division by zero") else ok(x.? / y.?)
      result.?
    }

But there is a catch! I never actually implemented an IO effect. And for some reason, a few days ago my curiosity asked me to consider something related to, but not quite the same as, what koka-shokki is: what if this could allow Cats Effect's IO to run in an async/await syntax?

So this is about that! But it is an exploration rather than anything else, since I'm writing this at the same time as I'm writing the proof of concept. Just for fun :).

The API of continuations (from koka-shokki perspective):

This is what basically introduces the await syntax in koka-shokki:

def apply[C[_] <: Capability[?], A](f: Scope[C] ?=> A): With[C, A]

Since I don't quite remember what I wrote, let's analyze it a bit. We have a function with two type parameters, C and A. C is the "effect" being executed. It matters for koka-shokki, but not so much here, since we are constraining this to Cats Effect's IO only. A is just the normal return type of the burrito. Then there is f, which is kind of special: it is a Scala 3 context function, so this will only work in Scala 3 (in Scala 2.13 this would be more verbose). The point of using this kind of function, if I remember correctly, is that it should feel more natural: you shouldn't have to know that the Scope is there implicitly when calling apply:

WithCapability[Optional :| Lefted[String], Int] {
  // do something!
  awaitMe.?
}

So, why the scope then?

The scope is used when calling the await operator .?, since it has to do the yielding into the continuation. That is why .? requires a Scope as an implicit parameter. This improves UX, since we don't have to worry about explicitly passing a parameter around.
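
To make the role of the context function concrete, here is a minimal, self-contained sketch of the same shape. All names (MiniScope, Awaitable, unwrap, withScope, Bailed) are made up for illustration and are not koka-shokki's API. To keep it runnable without the internal JDK classes, it bails out with an exception instead of yielding a continuation:

```scala
// Hypothetical names; NOT the koka-shokki API.
trait MiniScope:
  def bail(reason: String): Nothing

// Exception used as a stand-in for Continuation.`yield`.
final class Bailed(val reason: String) extends RuntimeException(reason)

final case class Awaitable[A](value: Either[String, A]):
  // The "await" operator receives the scope through a using clause,
  // so call sites never mention it.
  def unwrap(using scope: MiniScope): A = value match
    case Left(reason) => scope.bail(reason)
    case Right(a)     => a

// `f` is a context function: the block passed here gets a MiniScope implicitly.
def withScope[A](f: MiniScope ?=> A): Either[String, A] =
  val scope = new MiniScope:
    def bail(reason: String): Nothing = throw new Bailed(reason)
  try Right(f(using scope))
  catch case b: Bailed => Left(b.reason)

def divide(x: Awaitable[Int], y: Awaitable[Int]): Either[String, Int] =
  withScope {
    if y.unwrap == 0 then Awaitable[Int](Left("division by zero")).unwrap
    else x.unwrap / y.unwrap
  }
```

The body of divide never names the scope; the compiler threads it from withScope into every unwrap call.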

How to Continuation and Scope

Internally, koka-shokki uses two Scopes. One is ContinuationScope from the package jdk.internal.vm. The other one is Scope which wraps the yield to the relevant ContinuationScope.

So we have this:

var r: With[C, A] = null.asInstanceOf[With[C, A]]
val scope: ContinuationScope = new ContinuationScope("WithCapability")
val scopeInstance: Scope[C] = new Scope {
  override def ko[B, K[_] <: Capability[?]](
      x: With[K, B]
  )(implicit ev: K[B] <:< C[B]): B = {
    r = x.asInstanceOf[
      With[C, A]
    ] // The cast is safe because of the proof from the implicit parameter
    Continuation.`yield`(scope)
    throw new RuntimeException("unreachable")
  }
}

Here r holds the actual return value. It is set when the await operator .? is executed (ko is the function behind .?). Basically, when we encounter a .? whose effect is "throwing", we set r to that effect and then yield to scope. See:

def ?(using scope: Scope[C]): A = value match {
  case Left(channel) => scope.ko(self)
  case Right(value)  => value
}

So either we continue execution or we yield with the effect. Then, returning to the previous function, it follows:

val cont: Continuation = new Continuation(
  scope,
  () => {
    val result = f(using scopeInstance)
    r = new With[C, A] {
      val value = Right(result)
    }
    Continuation.`yield`(scope)
    ()
  }
)
cont.run

which basically sets the return value and yields, so we continue execution from cont.run.
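
For readers who haven't seen the internal API before, this is a minimal sketch of the run/yield handshake on its own, outside koka-shokki. The function name trace is mine. It assumes Java 21 and needs the --add-exports flag from the beginning of the post at runtime:

```scala
import jdk.internal.vm.{Continuation, ContinuationScope}
import scala.collection.mutable.ArrayBuffer

// Requires: --add-exports java.base/jdk.internal.vm=ALL-UNNAMED
def trace(): List[String] =
  val scope = new ContinuationScope("demo")
  val log = ArrayBuffer.empty[String]
  val cont = new Continuation(
    scope,
    () => {
      log += "body: part 1"
      Continuation.`yield`(scope) // suspend; control returns to the caller of run()
      log += "body: part 2"
      ()
    }
  )
  cont.run() // runs the body until the yield
  log += "caller: between runs"
  cont.run() // resumes right after the yield
  log += s"caller: done = ${cont.isDone}"
  log.toList
// trace() == List("body: part 1", "caller: between runs",
//                 "body: part 2", "caller: done = true")
```

Each run() executes the body up to the next yield, which is why WithCapability can read r right after cont.run returns.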

How are we going to do it then?

Let's try first defining the type of the function that introduces the await syntax:

extension [A](effect: IO[A]) def await: A

def async[A](f: Scope ?=> IO[A]): IO[A]

If we think about the koka-shokki implementation, I used a variable to store the return value, which was set when awaiting. That was easy to do there, since we could pattern match and extract the "non-error" value from it. But with IO this could be quite tricky.

Here goes the first attempt:

def async[A](f: Scope ?=> IO[A]): IO[A] = for {
  dispatch <- Dispatcher.parallel[IO].allocated
  m <- Semaphore[IO](1)
  returnValue <- Deferred[IO, Either[Throwable, A]]
  doneV <- Ref.of[IO, Boolean](false)
  scope: ContinuationScope <- IO.delay(ContinuationScope(s"async-io"))
  scopeInstance: Scope <- IO.delay(new Scope {
    override val cscope: ContinuationScope = scope
    override val dispatcher: Dispatcher[IO] = dispatch._1
    override val sem: Semaphore[IO] = m
    override val deferred: Deferred[IO, Either[Throwable, Any]] =
      returnValue.asInstanceOf
    override val done: Ref[IO, Boolean] = doneV
  })
  continuation <- IO.delay(
    new Continuation(
      scope,
      () => {
        val result = Try {
          f(using scopeInstance)
        }
        dispatch._1.unsafeRunSync(
          IO.fromTry(result)
            .flatten
            .attempt
            .flatMap(v => returnValue.complete(v)) *> doneV.set(true)
        )
        Continuation.`yield`(scope)
        ()
      }
    )
  )
  _ <- m.acquire
  _ <- (m.release *> IO.delay(continuation.run()) *> m.acquire)
    .whileM_(doneV.get.map(!_))
  result <- returnValue.get.flatMap(IO.fromEither).guarantee(dispatch._2)
} yield result

A lot is happening here. Let's start with this section:

// ...
for {
  dispatch <- Dispatcher.parallel[IO].allocated
  m <- Semaphore[IO](1)
  returnValue <- Deferred[IO, Either[Throwable, A]]
  doneV <- Ref.of[IO, Boolean](false)
  // ...
} yield ()

Here dispatch is what does the trick: it provides the evaluation context, so we can get a result in direct style (we'll see how later). m is used to block the continuation from being resumed at the place where it yielded. I'm using a semaphore from Cats Effect since I'm assuming it will yield (well, actually "cede") execution to another fiber while the effect run by dispatch isn't completely evaluated. You may think (or maybe not, who knows) that a mutex should be enough, but I did it this way because in the "continuation loop" I'm inverting the order of acquiring and releasing the semaphore. So it is for convenience: the standard Mutex would be more annoying to work with.

returnValue is mostly obvious: it is the actual promise that is going to be fulfilled either at the successful end of the async block or when an exception is encountered during an await. doneV is just a helper that keeps track of whether returnValue has been completed.
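
As a small illustration of what "direct style through a Dispatcher" means, here is a sketch that is independent of the async implementation. directStyleSum is a made-up name. It uses .use instead of .allocated, so the dispatcher is released automatically, and it wraps the blocking calls in IO.blocking so they don't occupy a compute thread:

```scala
import cats.effect.IO
import cats.effect.std.Dispatcher

// Hypothetical helper: evaluates two IOs from plain, non-monadic code.
def directStyleSum(a: IO[Int], b: IO[Int]): IO[Int] =
  Dispatcher.parallel[IO].use { dispatcher =>
    IO.blocking {
      // unsafeRunSync blocks the calling thread until the effect completes
      // and hands back a plain value.
      val x: Int = dispatcher.unsafeRunSync(a)
      val y: Int = dispatcher.unsafeRunSync(b)
      x + y
    }
  }
```

The first attempt uses .allocated instead because the dispatcher has to outlive several steps of the for-comprehension. That is why dispatch is a pair: dispatch._1 is the dispatcher and dispatch._2 is its finalizer.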

// ...
for {
  scopeInstance: Scope <- IO.delay(new Scope {
    override val cscope: ContinuationScope = scope
    override val dispatcher: Dispatcher[IO] = dispatch._1
    override val sem: Semaphore[IO] = m
    override val deferred: Deferred[IO, Either[Throwable, Any]] =
      returnValue.asInstanceOf
    override val done: Ref[IO, Boolean] = doneV
  })
  // ...
} yield ()

scopeInstance is just going to be the implicit context passed by the context function into each await call. And then the continuation:

// ...
for {
  continuation <- IO.delay(
    new Continuation(
      scope,
      () => {
        val result = Try {
          f(using scopeInstance)
        }
        dispatch._1.unsafeRunSync(
          IO.fromTry(result)
            .flatten
            .attempt
            .flatMap(v => returnValue.complete(v)) *> doneV.set(true)
        )
        Continuation.`yield`(scope)
        ()
      }
    )
  )
  // ...
} yield ()

It simply wraps the execution of the async block to catch any impure exception, lifts that wrapped value into an IO, and finally completes the Deferred (the promise) with the outcome. Then the execution loop follows:

// ...
for {
  _ <- m.acquire
  _ <- (m.release *> IO.delay(continuation.run()) *> m.acquire)
    .whileM_(doneV.get.map(!_))
  result <- returnValue.get.flatMap(IO.fromEither).guarantee(dispatch._2)
  // ...
} yield ()

First we acquire the lock, and immediately it is released. This, even if quite redundant, is necessary to maintain the execution order inside the loop, which is as follows: we release the lock and run the continuation until we encounter a yield; then we try to acquire the lock, which will "lock" the current fiber (and allow other fibers to execute) until the awaited value is fully resolved; then we are allowed to immediately release the lock and continue execution of the function. This happens in a loop until the Deferred has been completed with a result.

Then the return value is just lifted to an IO and we free the dispatcher.
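
The permit hand-off can be looked at in isolation. The sketch below, under the made-up name handOff, replays a single loop iteration with plain Cats Effect and no continuations: the "await side" takes the permit and starts the effect in the background, and the "driver side" is semantically blocked on acquire until the effect gives the permit back:

```scala
import cats.effect.{IO, Ref}
import cats.effect.std.Semaphore
import scala.concurrent.duration.*

def handOff: IO[Vector[String]] =
  for
    sem <- Semaphore[IO](1)
    log <- Ref.of[IO, Vector[String]](Vector.empty)
    _ <- sem.acquire // before the loop: the driver holds the only permit
    _ <- sem.release // loop step 1: the driver gives the permit up
    _ <- sem.acquire // inside await: the permit is taken again...
    _ <- (IO.sleep(50.millis) *> log.update(_ :+ "awaited effect finished"))
      .guarantee(sem.release) // ...and handed back only when the effect ends
      .start
    _ <- sem.acquire // loop step 2: suspends this fiber, not the thread
    _ <- log.update(_ :+ "driver resumed")
    events <- log.get
  yield events
```

Because the last acquire can only succeed after the background effect releases the permit, "driver resumed" is always logged second.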

That's basically all for async. Now for the await part:

extension [A](effect: IO[A])
  def await(implicit scope: Scope): A = {
    var resumed: A = null.asInstanceOf
    scope.dispatcher.unsafeRunSync(scope.sem.acquire)
    scope.dispatcher.unsafeRunAndForget(
      effect
        .map(result => {
          resumed = result
        })
        .onError(throwable =>
          scope.deferred.complete(Left(throwable)) *> scope.done.set(true)
        )
        .guarantee(scope.sem.release)
    )
    Continuation.`yield`(scope.cscope)
    resumed
  }

We first initialize an impure var resumed to null. Then we acquire the lock from the semaphore, which prevents the execution loop from resuming the continuation until resumed actually holds the expected result of the awaited effect (this is the same lock discussed in the execution loop part). Then we fire the awaited effect asynchronously, such that either we set resumed or we complete the deferred promise with an error (and, by extension, we stop the execution loop). Either way, we release the lock once one of those two things is done. Then we yield execution; well, not "then" exactly, since the yield can occur at any moment during the execution of the effect. This yield is the reason for the locking mechanism in the main loop, which is a hacky way of locking just the fiber and not the actual thread.

And that's pretty much it.

Does it work?

Surprisingly, yes, it does indeed work. Here is the MWE:

def funtionAsyncTest(someNumber: IO[Int]): IO[Int] = async {
  val next = someNumber.map(_ + 1).await
  val ref = Ref.of[IO, Int](next).await
  val counter = Ref.of[IO, Int](100).await
  Console[IO].println(s"Current value is ${next}. Starting loop.").await
  async {
    val thisValue = ref.get.await
    Console[IO].println(s"Current value is ${next}. Ref ${thisValue}.").await
    ref.set(thisValue + 1).await
    counter.update(_ - 1).await
    Console[IO].println(s"Current value is ${next}. Ref has been set.").await
    IO.unit
  }.whileM_(counter.get.map(_ > 0)).await
  Console[IO].println(s"Current value is ${next}. End loop.").await
  ref.get
}

It even works with the non-monadic loops from the language:

object Main extends IOApp.Simple {
  override def run: IO[Unit] = async {
    var fibers = Seq.empty[FiberIO[Unit]]
    for (x <- 0 to 999) {
      val fiber = async {
        val returnValue = funtionAsyncTest(IO.pure(x)).await
        IO { assert(returnValue == x + 101) }
      }.start.await
      fibers = fibers.appended(fiber)
    }
    fibers.map(_.join.await)
    IO.unit
  }
}

It works, almost perfectly that is. But...

Here come the _problems_

I did manage to get a deadlock, and I think I know why. I suspect the cause is the same thing I expected to be a problem from the start: context switching. That's why I believed one attempt wouldn't be enough.

Normal context switching is fine. The problem (maybe) arises when the runtime decides to move the execution of a fiber to another worker thread.

Internally, Continuation.yield calls the following code:

public static boolean yield(ContinuationScope scope) {
    Continuation cont = JLA.getContinuation(currentCarrierThread());
    // ...
}

It then checks whether cont has the scope scope, and tries again with the parent continuation of cont if not. But that is mostly unimportant. What we care about is currentCarrierThread. This is a problem if the CE runtime moves the execution of the fiber to another thread, as it would not be able to also carry the continuation to its new thread.
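
A quick way to see whether fibers really move between threads is to record the thread name around yield points. This is only a probe with made-up names (currentThread, threadsSeen), not a reproduction of the deadlock. How many distinct names show up depends on the runtime and the machine, and it may well be just one:

```scala
import cats.effect.IO
import cats.syntax.all.*

val currentThread: IO[String] = IO(Thread.currentThread().getName)

// Cedes `hops` times within a single fiber and collects the names of the
// worker threads it found itself on afterwards.
def threadsSeen(hops: Int): IO[Set[String]] =
  List.fill(hops)(IO.cede *> currentThread).sequence.map(_.toSet)
```

If the resulting set has more than one element, the same fiber ran on several worker threads, which is exactly the situation where a lookup based on currentCarrierThread would be looking at the wrong thread.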

Misc problems

While they should not pose a semantic problem for the execution, there is a fair amount of data structures and allocations in this first implementation. I haven't benchmarked anything yet, but I believe that there is going to be a performance hit.

What's the plan?

  • First, I want to address the context switching issue, maybe by passing the cont through an IOLocal rather than relying on the JRE implementation that uses currentCarrierThread.
  • Second, I want to improve the actual usage of the locking mechanism, mostly to reduce allocations and make everything a little bit better performance-wise, even if it gets uglier and more impure. I mean, as long as the impure operations are well scoped, it should be fine; even Ref uses impure things internally (AtomicReference)!
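
For reference, this is how IOLocal behaves on its own. It is a sketch of the primitive mentioned in the first point, not of the planned fix: a string stands in for the continuation, and the name localFollowsFiber is made up. Whether this is enough to work around the carrier-thread lookup inside the JDK is still an open question:

```scala
import cats.effect.{IO, IOLocal}

def localFollowsFiber: IO[(String, String)] =
  for
    local <- IOLocal("unset")
    _ <- local.set("cont-A") // stand-in for storing a continuation
    _ <- IO.cede // the fiber may resume on another worker thread
    afterCede <- local.get // the value belongs to the fiber, not to the thread
    // A child fiber gets its own copy, so its writes don't leak back.
    _ <- local.set("cont-B").start.flatMap(_.joinWithNever)
    afterChild <- local.get
  yield (afterCede, afterChild)
```

Both reads return "cont-A": the value follows the fiber across the cede and is not affected by what the child fiber does with its copy.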

The second try

WIP.

That's it for now

This is the full code, btw. :)