文章

深入 Kotlin 协程(三)封装一个协程

经过前面的学习,我们已经知道 Continuation 不是真正意义上的协程,也知道这套 API 足够灵活,可以封装出各种协程实现。遵循 Java 中 Thread 的习俗,这节开始来自己打造一个上层 API,定义出协程对应的对象 Coroutine,并实现相关操作的接口。

对于大部分人,学习协程的目的是使用协程,Kotlin 官方的协程框架也是采用这种风格,因此完成了自己打造的框架,对于官方 API 的理解可以提升一个层次,再使用的话自然手到擒来了。

源码在每一章节的开头(如果有必要的话)

delay 函数

CoroutineA 的目标是只使用协程标准库,即到不导入任何 kotlinx.coroutines 包。一个非常方便的函数 delay() 其实属于 kotlinx,所以也自己写一个吧!

所谓延时,无非是挂起协程,在某段时间后恢复。挂起协程的操作我们已经轻车熟路,至于定时,用原生工具就好啦。

private val executor = Executors.newScheduledThreadPool(1) { runnable ->
  Thread(runnable).apply { isDaemon = true }
}

suspend fun delay(delay: Long, timeUnit: TimeUnit = TimeUnit.MILLISECONDS) {
  if (delay <= 0) return
  suspendCoroutine { continuation ->
    executor.schedule({ continuation.resume(Unit) }, delay, timeUnit)
  }
}

这个函数没有适配协程取消,后面再完善。

定义协程

定义协程描述

参考 Thread,想完整描述一个协程,至少应该具有以下字段与方法:

  • isActive: 协程是否仍在执行(非完成或取消,不考虑挂起)
  • join: 阻塞当前协程,直到这个协程完成
  • cancel: 取消协程
  • addOnCancelCallback: 协程取消的回调
  • addOnCompleteCallback: 协程完成的回调
  • removeCallback: 移除注册的回调

那么可以定义出接口,为和 Kotlin 官方的实现靠拢,这里也取名为 Job:

interface Job : CoroutineContext.Element {
  object Key : CoroutineContext.Key<Job>
  interface Callback
  override val key: CoroutineContext.Key<*> get() = Key

  val isActive: Boolean
  fun addOnCancelCallback(onCancel: () -> Unit): Callback
  fun addOnCompleteCallback(onComplete: () -> Unit): Callback
  fun removeCallback(callback: Callback)
  fun cancel()
  suspend fun join()
}

internal class CancellationCallback(val onCancel: () -> Unit) : Job.Callback
internal class CompletionCallback<T>(val onComplete: (Result<T>) -> Unit) : Job.Callback

有几个需要注意的地方:

  • 继承 Element 接口是为了便于放入上下文中保存。
  • join() 定义为挂起函数,因为只有在协程环境里执行才能阻塞协程,否则没有意义。
  • 完成回调没有返回协程结果。结果用更优雅的手段返回,否则又出现回调地狱了。
  • 两个 Callback 的包装类是为了能够有统一的父接口,方便后续代码编写。
  • 添加回调要把这个回调返回是便于使用 lambda 语法调用,可以储存这个回调供后续移除。

定义状态

协程肯定得有状态。这里指的不是挂起/执行这种微观的,而是已完成/已取消这种宏观级别的。目前来讲需要三个状态:

  • 未完成:协程已经启动,还未执行完成。协程创建后默认启动,简单起见我们不再支持 Lazy 模式。
  • 正在取消:协程被取消,正在等待内部配合取消。此状态下协程返回后应该抛出取消异常,然后进入完成状态。
  • 完成:协程执行完成(包括正常结束、抛出异常、被取消)

注意一个细节,「未完成」之所以不定义成「正在运行」是为了避免歧义。「未完成」状态的协程不一定正在运行,也可能被挂起,不过这是微观状态,不是我们需要操心的。

流转如下:

stateDiagram-v2
direction LR

[*] --> Incomplete
Incomplete --> Complete: body end
Incomplete --> Cancelling: cancel
Cancelling --> Complete: body end
Complete --> [*]

注意,状态是单向流转的。这意味着所有状态回调最多被调用一次。

那么状态类如下:

sealed class CoroutineState(prev: CoroutineState?) {
  private var callbacks: List<Job.Callback>

  init {
    callbacks = prev?.callbacks ?: emptyList()
  }

  /** 协程启动后的状态,直到完成或取消 */
  class Incomplete(prev: CoroutineState? = null) : CoroutineState(prev)

  /** 被取消后的状态,等待内部配合取消。内部返回后抛出 CancellationException,进入 [Complete] 状态。 */
  class Cancelling(prev: CoroutineState? = null) : CoroutineState(prev)

  /** 协程执行完成(包括正常结束、抛出异常、被取消) */
  class Complete<T>(val value: T? = null, val exception: Throwable? = null, prev: CoroutineState? = null) : CoroutineState(prev)
}

估摸着很多人(包括我)都搞不懂为啥要把回调储存在状态里。 其实这是为了解决并发问题。后面协程实现中可以看到,状态使用了 AtomicReference 存储来确保更新的原子性,假设按照一般思路,把回调储存在 Job 中,那么回调的注册与状态的更新就不是原子操作(暂时不考虑加锁,对于这种通用库来说太重了)。设想这样一个场景:

注册回调时要先判断状态(例如若已经被取消了,就要立即调用取消回调),A 请求注册取消回调,内部先判断状态,判断完毕后 A 被挂起,B 请求取消,然后 A 恢复执行。此时 A 以为当前还在执行,所以正常注册回调,而这个取消回调永远不会被触发。如下图所示:

sequenceDiagram

participant CallerA
participant CallerB
participant Coroutine

Note right of Coroutine: Incomplete
CallerA ->> Coroutine: getState
Coroutine ->> CallerA: Incomplete
CallerB ->> Coroutine: cancel
Note right of Coroutine: Cancelling
CallerA ->> Coroutine: save callback to list (never called)

由此可见,回调的注册必须与状态绑定变成一个原子操作才行。

除此之外,储存回调的容器也必须是不可变的。否则又会出现并发问题,如下:

sequenceDiagram

participant CallerA
participant CallerB
participant Coroutine

Note right of Coroutine: Incomplete
CallerA ->> Coroutine: confirmState
CallerA ->> Coroutine: getCallbacks
Coroutine ->> CallerA: callbacks
CallerB ->> Coroutine: cancel
Note right of Coroutine: Cancelling
CallerA ->> Coroutine: add callback to list (never called)

至于容器那个变量,这里设计成可变的便于 API 设计,我们每次状态改变都会创建新的状态对象,所以不会导致问题。

List.toMutableList() 默认使用的是 ArrayList 实现,插入与删除性能不高。所以这里自己实现一些扩展函数,改用 LinkedList 实现。

private fun <T> List<T>.add(ele: T): List<T> = LinkedList(this).apply {
  add(ele)
}

private fun <T> List<T>.remove(ele: T): List<T> = LinkedList<T>().also { newList ->
  for (t in this)
    if (t != ele) newList.add(t)
}

/** 遍历指定类型的元素 */
private inline fun <reified T> List<*>.loopOn(crossinline action: (T) -> Unit) {
  for (s in this)
    if (s is T) action(s)
}

基于这些,可以再写几个便捷函数,帮助添加/删除回调:

fun with(callback: Job.Callback): CoroutineState = this.apply {
  callbacks = callbacks.add(callback)
}

fun without(callback: Job.Callback): CoroutineState = this.apply {
  callbacks = callbacks.remove(callback)
}

终于把状态类定义完了。

实现协程

</> 源码

定义完了下面就是实现。考虑到未来带有返回值的协程等,这里先写一个抽象实现定义如下:

abstract class AbstractCoroutine<T>(context: CoroutineContext) : Job, Continuation<T>{
  protected val state = AtomicReference<CoroutineState>(CoroutineState.Incomplete())
  override val context: CoroutineContext = context + this
}

虽然是抽象的,实际上几乎实现了所有功能。

它需要实现 Continuation 才能拿到协程执行结果,进而维护内部状态并通知外部注册的回调。

isActive 的实现没什么好说的,根据 state 的值来返回就行了。

背景知识

AtomicReference.updateAndGet() 等方法内部使用 CAS 算法,其 labdam 参数可能被多次调用(如果这一次更新没有成功的话),所以注意幂等性。

添加回调

前面提到过,添加回调不能直接添加,要先判断状态,具体来讲有三种可能:

  • 还没到达回调状态:添加回调
  • 已经到达回调状态:直接执行回调
  • 已经到达后续状态:忽略回调

同时记得每次状态改变(包括添加回调)都必须创建新的状态对象来保证并发安全。

添加取消回调:

override fun addOnCancelCallback(onCancel: () -> Unit): Job.Callback =
  CancellationCallback(onCancel).also { callback ->
    val newState = state.updateAndGet { prev ->
      when (prev) {
        is CoroutineState.Cancelling,
        is CoroutineState.Complete<*> -> prev // 不可能被取消了,忽略这个回调
        is CoroutineState.Incomplete -> CoroutineState.Incomplete(prev).with(callback)
      }
    }
    if (newState is CoroutineState.Cancelling) {
      onCancel() // 已经被取消,立即调用
    }
  }

添加完成回调:

// 内部使用的带有结果的完成回调
private fun doOnComplete(block: (Result<T>) -> Unit): Job.Callback =
  CompletionCallback(block).also { callback ->
    val newState = state.updateAndGet { prev ->
      when (prev) {
        // 这里不能偷懒合并写成 prev.with(callback),必须要创建新对象才可以
        is CoroutineState.Incomplete -> CoroutineState.Incomplete(prev).with(callback)
        is CoroutineState.Cancelling -> CoroutineState.Cancelling(prev).with(callback)
        is CoroutineState.Complete<*> -> prev // 已经完成了,不用添加
      }
    }
    if (newState is CoroutineState.Complete<*>) {
      // 已经完成了,立即调用
      if (newState.exception != null) block(Result.failure(newState.exception))
      else block(Result.success(newState.value as T))
    }
  }

override fun addOnCompleteCallback(onComplete: () -> Unit): Job.Callback =
  doOnComplete { onComplete() }

移除回调

取消回调比较简单,不用判断状态了。

override fun removeCallback(callback: Job.Callback) {
  state.updateAndGet { prev ->
    when (prev) {
      is CoroutineState.Incomplete -> CoroutineState.Incomplete(prev).without(callback)
      is CoroutineState.Cancelling -> CoroutineState.Cancelling(prev).without(callback)
      is CoroutineState.Complete<*> -> CoroutineState.Complete(prev).without(callback)
    }
  }
}

不厌其烦地提醒,不能偷懒写成 prev.without(callback),必须创建新对象保证并发安全。

join

有了回调注册,join() 的实现就简单了。听起来比较高大上,所谓「阻塞协程直到完成」,翻译一下不就是挂起当前协程,等触发完成回调后再恢复就行了呗。通过 suspendCoroutine() 的挂起操作用过很多遍了。

别忘了要判断状态,如果协程已经完成就立即返回,不然就真的挂而不起了。

private suspend fun joinSuspend() = suspendCoroutine { continuation ->
  doOnComplete {
    continuation.resume(Unit)
  }
}

override suspend fun join() {
  when (state.get()) {
    is CoroutineState.Complete<*> -> return
    is CoroutineState.Incomplete,
    is CoroutineState.Cancelling -> joinSuspend()
  }
}

这里在概念上有一点混淆。如果你有 「挂起了协程,那协程怎么完成?」 的疑问,就是被绕进去了。挂起的是当前的协程,也就是执行 join() 的协程,它不影响内部的协程。

cancel

取消操作也比较简单,因为并不需要真的取消,只要设置一个标记(也就是状态)。记得通知回调们。

override fun cancel() {
  // 注意这里不是 updateAndGet 了,我们需要判断的是之前的状态
  val prevState = state.getAndUpdate { prev ->
    when (prev) {
      is CoroutineState.Incomplete -> CoroutineState.Cancelling(prev)
      is CoroutineState.Cancelling,
      is CoroutineState.Complete<*> -> prev
    }
  }
  if (prevState is CoroutineState.Incomplete) {
    prevState.notifyCancellation() // 通知回调
  }
}

不要粗心把逻辑写成了「如果新的状态是取消,则通知回调」,这样多次取消可能多次通知回调。就算每次回调后就把对应的 Callback 也不行,cancel() 可能被并发调用,取消回调的触发与移除并不被 AtomicReference 保护,如果在移除之前线程被挂起,那么后一个调用将重复回调。

resumeWith

最后就是处理 Continuation 的完成回调了。这也意味着我们这个协程已经执行完成。更新一下状态并通知回调就行。

override fun resumeWith(result: Result<T>) {
  val newState = state.updateAndGet { prev ->
    when (prev) {
      is CoroutineState.Incomplete,
      is CoroutineState.Cancelling -> CoroutineState.Complete(
        result.getOrNull(), result.exceptionOrNull(), prev
      )
      is CoroutineState.Complete<*> -> throw IllegalStateException("Already completed!")
    }
  }
  newState.notifyCompletion(result)
  newState.clear()
}

和取消不同,Continuation 的状态流转决定了不可能多次调用 resumeWith() 方法,也无需额外判断了。

构造器

到此为止,一个简单的协程已经封装完了,但用起来还有一点烦琐。那就再加个构造器。

首先添加一个基本的无返回值的协程类:

class SimpleCoroutine(context: CoroutineContext) : AbstractCoroutine<Unit>(context)

然后是构造器:

fun launch(context: CoroutineContext = EmptyCoroutineContext, block: suspend () -> Unit): Job =
  SimpleCoroutine(context).apply {
    block.startCoroutine(this)
  }

来试试我们的迷你协程库吧:

suspend fun main() {
  val job = launch {
    delay(1000)
  }
  job.addOnCompleteCallback {
    println("Complete")
  }
  job.join()
}
// launch, delay, job 都是我们自己实现的,没有导入 kotlinx.coroutines

完美!

带返回值的协程

前面实现的抽象协程已经支持返回值了,只是没有提供 API 把结果取出来。现在创建一个新接口,同样模仿官方框架,取名为 Deferred,定义如下:

interface Deferred<T> : Job {
  suspend fun await(): T
}

await() 行为如下:

  • 若协程已经完成则返回结果或抛出异常。
  • 若协程未完成则挂起直到完成(类似 join

基于抽象协程,几乎没有什么要写的代码了,实现一下 join() 返回结果就行。

class DeferredCoroutine<T>(context: CoroutineContext) : AbstractCoroutine<T>(context), Deferred<T> {
  override suspend fun await(): T = when (val s = state.get()) {
    is CoroutineState.Complete<*> -> s.exception?.let { throw it } ?: s.value as T
    else -> awaitSuspend()
  }

  // 和 join 的实现几乎一样
  private suspend fun awaitSuspend() = suspendCoroutine { continuation ->
    doOnComplete { result ->
      continuation.resumeWith(result)
    }
  }
}

最后再补一个构造器:

fun <T> async(context: CoroutineContext = EmptyCoroutineContext, block: suspend () -> T): Deferred<T> =
  DeferredCoroutine<T>(context).apply {
    block.startCoroutine(this)
  }

完成!

调度器

</> 源码

默认行为

目前我们的协程还没有任何线程调度器,默认执行在启动协程的线程。但不完全是这样! 执行下面的代码:

suspend fun main() {
  val job = launch {
    println(Thread.currentThread()) // Thread[main,5,main]
    delay(1000)
    println(Thread.currentThread()) // Thread[Thread-0,5,main]
  }
  job.join()
}

虽然看起来没有切线程,但实际上两行打印执行在不同的线程上。代码执行在哪个线程取决于挂起后从哪里调用的 continuation.resume(Unit)。目前我们的 delay 实现用到了 executor,显然提交的任务在后台线程执行。

设计调度器

调度器的本质是执行某一代码,所以定义非常简单:

interface Dispatcher {
  fun dispatch(block: () -> Unit)
}

除了首次启动,后面只有挂起的时候才有机会调度。为了在挂起时能够插一脚,需要实现一个协程拦截器。拦截器前几篇接触过,稍稍回顾一下,它是一类协程上下文,对应的 Key 是 ContinuationInterceptor.Key,根据 Kotlin 伴生对象的语法,使用时可以缩写为 ContinuationInterceptor。核心方法 interceptContinuation() 用于把要恢复执行的 Continuation 包装成另一个,包装类里面可以选择是否恢复、如何恢复等,自然也包括在哪个线程恢复。

那么就先来实现这个包装类吧:

private class DispatchedContinuation<T>(private val delegate: Continuation<T>, private val dispatcher: Dispatcher) : Continuation<T> {
  override val context: CoroutineContext = delegate.context
  override fun resumeWith(result: Result<T>) {
    dispatcher.dispatch { delegate.resumeWith(result) }
  }
}

根据设计,在哪个线程恢复是 Dispatcher 的职责,所以直接甩锅过去。

最后是上下文的实现,用上面的包装类包一下:

open class DispatcherContext(private val dispatcher: Dispatcher) : ContinuationInterceptor {
  override val key: CoroutineContext.Key<*> = ContinuationInterceptor.Key

  override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
    DispatchedContinuation(continuation, dispatcher)
}

这样调度器的框架就搭好了,可以随时添加到协程上下文中。

实现调度器

Kotlin 官方框架中默认调度到后台线程,并且为 CPU 密集型任务设计,我们也来模仿实现一个。显然这里需要一个线程池,并且因为是 CPU 密集任务,超过物理 CPU 个数的线程没有实际意义。

实现如下,比较简单:

object DefaultDispatcher : Dispatcher {
  private val executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()) { runnable ->
    Thread(runnable).apply { isDaemon = true }
  }

  override fun dispatch(block: () -> Unit) {
    executor.submit(block)
  }
}

唯一要补充的是 isDaemon,设置为守护线程后就不会阻止 JVM 退出。否则线程池始终维护这几个线程,很可能导致 JVM 虚拟机长期运行。

在 Android 等 UI 程序中通常存在一个「主线程」,开发者也经常需要切换到主线程去操作 UI。这类调度器的实现原理一样,只是需要一个工具把任务抛到主线程上去。这个工具带有事件循环的框架都会提供,比如 Android 的 Handler。这个 Demo 是纯 kotlin 项目,为了保持精简就不实现了。

默认添加调度器

本节一开始就演示了,因为 delay 的实现,什么调度器都不添加时行为是不确定的。那就很有必要默认给一个调度器,好让任务执行在确定类型的线程上。

注意,这里说的是「确定类型的线程」,而不是「确定的线程」。根据调度器的实现,通常不能也无需调度到某个具体的线程。比如刚才的 DefaultDispatcher 就只调度到后台线程。当然,某些平台的主线程调度器除外。

千万别犯糊涂在构造器函数中直接加 DefaultDispatcher 了,这样可能会覆盖先前手动指定的调度器。方便起见封装一个函数:

fun withDefaultDispatcher(context: CoroutineContext): CoroutineContext =
  if (context[ContinuationInterceptor] == null)
    context + Dispatchers.Default // 如果当前没有配置拦截器,就添加默认调度器
  else
    context

最后在构造器函数中调用一下:

fun launch(context: CoroutineContext = EmptyCoroutineContext, block: suspend () -> Unit): Job =
  SimpleCoroutine(context.withDefaultDispatcher()).apply { // 添加了默认调度器
    block.startCoroutine(this)
  }

现在再执行本节一开始的测试代码,可以发现前后都在后台线程中了:

suspend fun main() {
  val job = launch {
    println(Thread.currentThread()) // Thread[Thread-0,5,main]
    delay(1000)
    println(Thread.currentThread()) // Thread[Thread-2,5,main]
  }
  job.join()
}

实际上 delay() 后代码也是由调度器调度的。虽然 delay() 的实现在它内部的线程中恢复了协程的执行,但这个 Continuation 只是一个包装,它的 resumeWith() 方法在调度器里恢复了真正的 Continuation

协程的取消

</> 源码

众所周知,协程的取消需要内部配合。所谓的配合有两种形式:

  1. 协程体只有普通的函数,需要定时检查协程状态,如果已经取消就尽早返回。
  2. 协程体内部挂起,启动了异步操作,此时要想办法注册取消回调,在回调中通知自己的异步操作取消。(delay 是一个典型的这种案例)

第一个情况没啥好说的,协程体内部自己判断就行了。第二个情况则需要扩展给 lambda 的 Continuation 的参数,给它增加相关的 API。既然要扩展,那先看看默认的是个什么东西:

@SinceKotlin("1.3")
@InlineOnly
public suspend inline fun <T> suspendCoroutine(crossinline block: (Continuation<T>) -> Unit): T {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    return suspendCoroutineUninterceptedOrReturn { c: Continuation<T> ->
        val safe = SafeContinuation(c.intercepted())
        block(safe)
        safe.getOrThrow()
    }
}

看到它是个 SafeContinuation,这个东西在[第一章]({{< ref "深入 Kotlin 协程1.md" >}})中提到过,它有两个作用:

  • 保证这个 Continuation 只被恢复(resume)一次。希望各位还记得,Continuation 是协程中的某一段代码,那当然只应该执行一次。
  • 去除不必要的挂起。

什么是不必要的挂起? 看下面两个函数:

suspend fun noSuspend1(): Int {
  return 1 + 1
}

suspend fun noSuspend2() = suspendCoroutine<Int> { continuation ->
  continuation.resume(1 + 1)
}

大家应该能看出来 noSuspend1() 虽然定义为挂起函数,但它不会挂起,IDEA 里也会提示我们这里的 suspend 关键字多余。但第二个或许有人就搞不清了,其实它也不挂起!只有当挂起点切换的函数调用栈(说人话就是或切线程,或扔到某个事件循环里等待执行)时才会真正挂起。SafeContinuation 的职责之一就是区分这几种情况。

具体来讲咋区分的?根据刚才的总结,只要在 lambda 里实际调用了 resume() 那必然不需要挂起(搁这直接把工作做完就恢复了还挂起啥?你得把工作放到其他地方做,做完的时候在回调中 resume() 才需要挂起呢)。所以方法就是 SafeContinuation 内部维护一个标记,看看执行完 blockresume() 方法有没有被调用过。如果已经调用过了就不再需要挂起。

知道了怎么区分要不要挂起,那到底谁来执行「挂起」?当然是 Kotlin 来执行。我们要做的就是在 suspendCoroutineUninterceptedOrReturn{} lambda 中返回一个标记告诉 Kotlin 要不要挂起就行了,不用操心执行。这个标记就是 kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED,根据约定,如果不需要挂起则应该返回这个挂起函数的结果(返回值)。因此这个 lambda 声明的返回值类型是 Any

CancellableContinuation

SafeContinuation 不可以继承,为了提供取消相关的 API 得建一个 Continuation 的包装类,暂且叫 CancellableContinuation,并且它还得提供 SafeContinuation 原先的功能。为了判断有无必要挂起,需要一个变量记录 resume() 方法是否被调用了。为了防止 resume() 重复调用,内部还需要实现一个状态。

可能有伙伴好奇,为啥不给 Continuation 弄几个扩展函数,从上下文中取出 Job 来注册回调监听呢,这样就不用额外实现一个 CancellableContinuation 了,也省得把 SafeContinuation 的逻辑再写一遍。原因有好几个:

  • 最直观的,收到取消回调后应该抛出 CancellationException,这个工作交给开发者自己做太不靠谱了,还是我们封装一下统一抛吧。

  • 最必要的,只有确定挂起时才能注册 Job 的取消监听。否则开发者很可能写出下面的代码:

    suspend fun calc(): Int = suspendCoroutine { continuation ->
      (continuation.context[Job.Key])?.addOnCancelCallback {
        continuation.resumeWithException(CancellationException()) // 规范要求的
      }
      continuation.resume(1 + 1)
    }
      
    

    这个代码根本不会挂起,但的确可能收到取消回调,后果就是 continuation 被恢复了两遍,出错了。把是否需要注册取消监听交给广大开发者自行判断太难为它们了。

我们的类定义如下:

class CancellableContinuation<T>(private val continuation: Continuation<T>) : Continuation<T> {
  private sealed class State {
    class Incomplete(val cancelCallback: (() -> Unit)? = null) : State()
    class Complete<T>(val value: T?, val exception: Throwable? = null) : State()
    object Cancelled : State()
  }

  private enum class Decision {
    UNDECIDED, SUSPENDED, RESUMED
  }
  
  override val context: CoroutineContext
    get() = continuation.context
  
  private val state = AtomicReference<State>(State.Incomplete())
  private val decision = AtomicReference(Decision.UNDECIDED)
  
  val isCompleted: Boolean
    get() = state.get().let { it is State.Complete<*> || it is State.Cancelled }
}

Job 类似,把回调保存在状态里保证并发安全。但这里简单许多,因为只需要保存一个取消回调就行了,相关代码如下:

fun registerOnCancelCallback(cancelCallback: () -> Unit) {
  val newState = state.updateAndGet { prev ->
    when (prev) {
      is State.Incomplete -> State.Incomplete(cancelCallback)
      State.Cancelled,
      is State.Complete<*> -> prev // 已经完成了就没必要保存取消回调 —— 永远不会被触发
    }
  }
  if (newState === State.Cancelled)
    cancelCallback() // 已经取消的情况下也不用保存回调,直接触发就行
}

接着是收到 Job 取消回调时应该做的事情:

private fun doOnCancel() {
  val prev = state.getAndUpdate { prev ->
    when (prev) {
      is State.Incomplete -> State.Cancelled
      State.Cancelled,
      is State.Complete<*> -> prev
    }
  }
  if (prev is State.Incomplete) {
    prev.cancelCallback?.invoke() // 触发取消回调
      resumeWithException(CancellationException()) // 抛出取消异常
  }
}

注册 Job 取消回调的方法,比较简单:

private fun installCoroutineCancelledCallback() {
  if (isCompleted) return // 当前 Continuation 已经完成,就不用理会回调了
  continuation.context[Job.Key]?.addOnCancelCallback {
    doOnCancel()
  }
}

接着是 resumeWith() 实现:

override fun resumeWith(result: Result<T>) {
  if (decision.compareAndSet(Decision.UNDECIDED, Decision.RESUMED)) {
    // 代表在 lambda 里直接就 resume 了,不需要挂起
    state.set(State.Complete(result.getOrNull(), result.exceptionOrNull()))
  } else if (decision.compareAndSet(Decision.SUSPENDED, Decision.RESUMED)) {
    state.updateAndGet { prev ->
      when (prev) {
        is State.Complete<*> -> throw IllegalStateException("Already completed.")
        else -> State.Complete(result.getOrNull(), result.exceptionOrNull())
      }
    }
    // 挂起后的恢复执行
    continuation.resumeWith(result)
  }
}

最后别忘了给出一个 API 来取得返回值:

fun getResult(): Any? {
  installCoroutineCancelledCallback() // 内部判断了只有未完成时才注册

  if (decision.compareAndSet(Decision.UNDECIDED , Decision.SUSPENDED))
    // 如果还是 UNDECIDED 就代表 resumeWith() 未被调用过,所以需要挂起,返回挂起标记
    return COROUTINE_SUSPENDED
  return when (val s = state.get()) {
    is State.Incomplete -> COROUTINE_SUSPENDED
    State.Cancelled -> throw CancellationException()
    
    // 如果此时已经完成了表示 lambda 内部就执行了 resumeWith(),那么同步返回结果
    is State.Complete<*> -> (s as State.Complete<T>).run {
      exception?.let { e -> throw e } ?: value
    }
  }
}

最最后,模仿 Kotlin,给出一个可取消版本的挂起函数:

suspend fun <T> suspendCancellableCoroutine(block: (CancellableContinuation<T>) -> Unit) =
  suspendCoroutineUninterceptedOrReturn<T> {
    val cancellable = CancellableContinuation(it.intercepted())
    block(cancellable)
    return@suspendCoroutineUninterceptedOrReturn cancellable.getResult()
  }

可取消的 delay

用刚刚完成的 API 来改写 dealy

suspend fun delay(delay: Long, timeUnit: TimeUnit = TimeUnit.MILLISECONDS) {
  if (delay <= 0) return
  suspendCancellableCoroutine { continuation ->
    val future = executor.schedule({ continuation.resume(Unit) }, delay, timeUnit)
    continuation.registerOnCancelCallback {
      future.cancel(true) // 把取消回调传递给内部的异步任务
      // 我们封装的 CancellableContinuation 内部会执行 resumeWithException(CancellationException()),
      // 不用开发者操心了
    }
  }
}

可取消的 join

// AbstractCoroutine
private suspend fun joinSuspend() = suspendCancellableCoroutine { continuation ->
  val callback = doOnComplete {
    continuation.resume(Unit)
  }
  continuation.registerOnCancelCallback {
    removeCallback(callback)
  }
}

override suspend fun join() {
  when (state.get()) {
    is CoroutineState.Complete<*> -> {
      if (coroutineContext[Job.Key]?.isActive == false)
        throw CancellationException("Coroutine is cancelled")
    }

    is CoroutineState.Incomplete,
    is CoroutineState.Cancelling -> joinSuspend()
  }
}	

注意两个地方:

  • join() 中读取的状态是要等待的协程的状态而 coroutineContext 拿到的是调用 join() 的协程。
  • joinSuspend() 中拿到的 continuation 属于调用 joinSuspend() 的协程,而不是等待的协程。但回调是注册到要等待的协程上的。

joinSuspend() 不用额外处理要等待的协程的取消,因为取消也是完成的一种,注册到完成回调会即使触发从而恢复挂起等待协程的执行。只需要关注自己所在的协程的取消就行,这种情况应该取消回调,因为我们的框架会自动用 resumeWithException() 恢复当前协程的执行,后续等待的协程结束触发回调,会再次尝试恢复当前协程,导致异常。

可取消的 await

await 的取消与 join 类似:

private suspend fun awaitSuspend() = suspendCancellableCoroutine { continuation ->
  val callback = doOnComplete { result ->
    continuation.resumeWith(result)
  }
  continuation.registerOnCancelCallback {
    removeCallback(callback)
  }
}

override suspend fun await(): T = when (val s = state.get()) {
  is CoroutineState.Complete<*> -> if (coroutineContext[Job.Key]?.isActive == false) {
    throw CancellationException()
  } else {
    s.exception?.let { throw it } ?: s.value as T
  }
  else -> awaitSuspend()
}

  • awaitSuspend()continuation 是调用 await() 的协程,完成回调注册到的是要等待的协程。

测试:

suspend fun main() {
  val job = launch {
    async {
      delay(5000)
      10
    }.await().also { println("result: $it") }
  }
  launch {
    delay(1000)
    job.cancel()
  }
  job.join()
  println("END")
}

// 效果:程序只等待1秒就结束了,没有打印 result

分析:job 是一个协程,其内部又启动了一个协程并等待结果。

主协程等待 job。但是 job 大约 1 秒后在另一个协程中被取消了,所以不再等待 5 秒的结果直接退出。主协程得以继续执行打印 END 后退出。

异常处理

</> 源码

协程内部当然可能抛异常,事实上前面实现取消的时候已经抛了取消异常。那这些异常最终去哪了?别忘了异常也算是一种退出,所以会触发 Continuation.resumeWith() 前面的实现中我们把 Result 里的异常存入了 Complete 状态,就...没有然后了。

原生的 try-catch 依然可以在协程体内部捕捉异常,但不能跨协程。

异常处理器

类似完成与取消回调,我们来定义一个异常回调,或者叫异常处理器。

fun interface CoroutineExceptionHandler : CoroutineContext.Element {
  companion object Key : CoroutineContext.Key<CoroutineExceptionHandler>
  override val key: CoroutineContext.Key<*>
    get() = Key

  fun handleException(context: CoroutineContext, exception: Throwable)
}

为啥这次不放在状态中了?因为没必要呀,之前这么干是因为取消和完成都与状态有关,为了避免并发问题才这么做。异常处理没有这个问题。放在上下文中用起来更方便。

注意不是所有协程的异常都应该交给处理器,比如:

  1. 对于带有返回值的协程,如果返回值不被使用这个协程就没啥意义,所以即使有异常也应该在获取结果时抛出,而不是转给处理器。
  2. 取消异常比较特殊,它只是一个标记,不用分发下去。

针对第一点,在 AbstractCoroutine 默认不处理异常:

// AbstractCoroutine
/** 处理协程异常。已经处理返回 true */
open fun handleException(e: Throwable) = false

在普通协程 SimpleCoroutine 中处理:

// SimpleCoroutine
override fun handleException(e: Throwable): Boolean {
  context[CoroutineExceptionHandler.Key]?.handleException(context, e)
    ?: Thread.currentThread().let { t ->
      t.uncaughtExceptionHandler.uncaughtException(t, e)
    }
  return true
}

具体来说:

  • 如果设置了异常处理器则转交。
  • 否则交给线程的异常处理器。

刚才总结的第二点,取消异常比较特殊,不应该分发下去。所以在分发异常之前做一个小小的判断:

private fun tryHandleException(e:Throwable):Boolean{
  if ( e is CancellationException)
  return false
  return handleException(e)
}

最后在 resumeWith() 回调中来处理异常:

override fun resumeWith(result: Result<T>) {
  // ...
  newState.notifyCompletion(result)
  (newState as CoroutineState.Complete<T>).exception?.also { e ->
    tryHandleException(e) // 如果有异常就处理它
  }
}

使用

与调度器不同,异常处理器不需要默认添加(默认行为是转交线程的异常处理)。如果要使用,把它添加到上下文就行了:

suspend fun main() {
  val handler = object : CoroutineExceptionHandler {
    override fun handleException(context: CoroutineContext, exception: Throwable) {
      println("Catch exception: $exception")
    }
  }
  launch(handler) {
    delay(1000)
    throw RuntimeException("test")
  }.join()
}
// 打印 Catch exception: java.lang.RuntimeException: test
// 如果不设置处理器,对应线性会异常:Exception in thread "Thread-2"