Skip to content

Commit 0a97c6c

Browse files
committed
stdlib: DeepRecursiveFunction KT-31741
Introduces coroutine-based framework to execute deeply recursive functions that utilities the heap and thus avoid StackOverflowError. Fixes KT-31741
1 parent 151890d commit 0a97c6c

File tree

6 files changed

+389
-0
lines changed

6 files changed

+389
-0
lines changed

libraries/stdlib/js-v1/src/kotlin/coroutines/intrinsics/IntrinsicsJs.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ public actual inline fun <R, T> (suspend R.() -> T).startCoroutineUninterceptedO
4545
completion: Continuation<T>
4646
): Any? = this.asDynamic()(receiver, completion, false)
4747

48+
@InlineOnly
49+
internal actual inline fun <R, P, T> (suspend R.(P) -> T).startCoroutineUninterceptedOrReturn(
50+
receiver: R,
51+
param: P,
52+
completion: Continuation<T>
53+
): Any? = this.asDynamic()(receiver, param, completion, false)
4854

4955
/**
5056
* Creates unintercepted coroutine without receiver and with result type [T].

libraries/stdlib/jvm/src/kotlin/coroutines/intrinsics/IntrinsicsJvm.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ public actual inline fun <R, T> (suspend R.() -> T).startCoroutineUninterceptedO
5050
completion: Continuation<T>
5151
): Any? = (this as Function2<R, Continuation<T>, Any?>).invoke(receiver, completion)
5252

53+
@InlineOnly
54+
internal actual inline fun <R, P, T> (suspend R.(P) -> T).startCoroutineUninterceptedOrReturn(
55+
receiver: R,
56+
param: P,
57+
completion: Continuation<T>
58+
): Any? = (this as Function3<R, P, Continuation<T>, Any?>).invoke(receiver, param, completion)
5359

5460
// JVM declarations
5561

libraries/stdlib/src/kotlin/coroutines/CoroutinesIntrinsicsH.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ public expect inline fun <R, T> (suspend R.() -> T).startCoroutineUninterceptedO
4444
completion: Continuation<T>
4545
): Any?
4646

47+
// Internal version that support arity-2 suspending functions, might be made public in the future if needed
48+
internal expect inline fun <R, P, T> (suspend R.(P) -> T).startCoroutineUninterceptedOrReturn(
49+
receiver: R,
50+
param: P,
51+
completion: Continuation<T>
52+
): Any?
53+
4754
@SinceKotlin("1.3")
4855
public expect fun <T> (suspend () -> T).createCoroutineUnintercepted(
4956
completion: Continuation<T>
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/*
2+
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
3+
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
4+
*/
5+
6+
package kotlin
7+
8+
import kotlin.coroutines.*
9+
import kotlin.coroutines.intrinsics.*
10+
11+
/**
12+
* Defines deep recursive function that keeps its stack on the heap,
13+
* which allows very deep recursive computations that do not use the actual call stack.
14+
* To initiate a call to this deep recursive function use its [invoke] function.
15+
* As a rule of thumb, it should be used if recursion goes deeper than a thousand calls.
16+
*
17+
* The [DeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R].
18+
* The [block] of code defines the body of a recursive function. In this block
19+
* [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
20+
* to the declared function. Other instances of [DeepRecursiveFunction] can be called
21+
* in this scope with `callRecursive` extension, too.
22+
*
23+
* For example, take a look at the following recursive tree class and a deeply
24+
* recursive instance of this tree with 100K nodes:
25+
*
26+
* ```
27+
* class Tree(val left: Tree? = null, val right: Tree? = null)
28+
* val deepTree = generateSequence(Tree()) { Tree(it) }.take(100_000).last()
29+
* ```
30+
*
31+
* A regular recursive function can be defined to compute a depth of a tree:
32+
*
33+
* ```
34+
* fun depth(t: Tree?): Int =
35+
* if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
36+
* println(depth(deepTree)) // StackOverflowError
37+
* ```
38+
*
39+
* If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because of deep recursion.
40+
* However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, and then
41+
* it successfully computes [`depth(deepTree)`][DeepRecursiveFunction.invoke] expression:
42+
*
43+
* ```
44+
* val depth = DeepRecursiveFunction<Tree?, Int> { t ->
45+
* if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
46+
* }
47+
* println(depth(deepTree)) // Ok
48+
* ```
49+
*
50+
* Deep recursive functions can also mutually call each other using a heap for the stack via
51+
* [callRecursive][DeepRecursiveScope.callRecursive] extension. For example, the
52+
* following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree.
53+
*
54+
* ```
55+
* val mutualRecursion = object {
56+
* val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
57+
* if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
58+
* }
59+
* val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
60+
* if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
61+
* }
62+
* }
63+
* ```
64+
*
65+
* @param [T] the function parameter type.
66+
* @param [R] the function result type.
67+
* @param block the function body.
68+
*/
69+
@SinceKotlin("1.4")
70+
public class DeepRecursiveFunction<T, R>(
71+
internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R
72+
)
73+
74+
/**
75+
* Initiates a call to this deep recursive function, forming a root of the call tree.
76+
*
77+
* This operator should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for
78+
* initial recursive invocation. From inside of [DeepRecursiveScope] use
79+
* [callRecursive][DeepRecursiveScope.callRecursive].
80+
*/
81+
@SinceKotlin("1.4")
82+
public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R =
83+
DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()
84+
85+
/**
86+
* A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] methods to
87+
* recursively call this function or another [DeepRecursiveFunction] putting the call activation frame on the heap.
88+
*
89+
* @param [T] function parameter type.
90+
* @param [R] function result type.
91+
*/
92+
@RestrictsSuspension
93+
@SinceKotlin("1.4")
94+
public sealed class DeepRecursiveScope<T, R> {
95+
/**
96+
* Makes recursive call to this [DeepRecursiveFunction] function putting the call activation frame on the heap,
97+
* as opposed to the actual call stack that is used by a regular recursive call.
98+
*/
99+
public abstract suspend fun callRecursive(value: T): R
100+
101+
/**
102+
* Makes call to the specified [DeepRecursiveFunction] function putting the call activation frame on the heap,
103+
* as opposed to the actual call stack that is used by a regular call.
104+
*/
105+
public abstract suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S
106+
107+
@Deprecated(
108+
level = DeprecationLevel.ERROR,
109+
message =
110+
"'invoke' should not be called from DeepRecursiveScope. " +
111+
"Use 'callRecursive' to do recursion in the heap instead of the call stack.",
112+
replaceWith = ReplaceWith("this.callRecursive(value)")
113+
)
114+
@Suppress("UNUSED_PARAMETER")
115+
public operator fun DeepRecursiveFunction<*, *>.invoke(value: Any?): Nothing =
116+
throw UnsupportedOperationException("Should not be called from DeepRecursiveScope")
117+
}
118+
119+
// ================== Implementation ==================
120+
121+
private typealias DeepRecursiveFunctionBlock = suspend DeepRecursiveScope<*, *>.(Any?) -> Any?
122+
123+
private val UNDEFINED_RESULT = Result.success(COROUTINE_SUSPENDED)
124+
125+
@Suppress("UNCHECKED_CAST")
126+
private class DeepRecursiveScopeImpl<T, R>(
127+
block: suspend DeepRecursiveScope<T, R>.(T) -> R,
128+
value: T
129+
) : DeepRecursiveScope<T, R>(), Continuation<R> {
130+
// Active function block
131+
private var function: DeepRecursiveFunctionBlock = block as DeepRecursiveFunctionBlock
132+
133+
// Value to call function with
134+
private var value: Any? = value
135+
136+
// Continuation of the current call
137+
private var cont: Continuation<Any?>? = this as Continuation<Any?>
138+
139+
// Completion result (completion of the whole call stack)
140+
private var result: Result<Any?> = UNDEFINED_RESULT
141+
142+
override val context: CoroutineContext
143+
get() = EmptyCoroutineContext
144+
145+
override fun resumeWith(result: Result<R>) {
146+
this.cont = null
147+
this.result = result
148+
}
149+
150+
override suspend fun callRecursive(value: T): R = suspendCoroutineUninterceptedOrReturn { cont ->
151+
// calling the same function that is currently active
152+
this.cont = cont as Continuation<Any?>
153+
this.value = value
154+
COROUTINE_SUSPENDED
155+
}
156+
157+
override suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S = suspendCoroutineUninterceptedOrReturn { cont ->
158+
// calling another recursive function
159+
val function = block as DeepRecursiveFunctionBlock
160+
with(this@DeepRecursiveScopeImpl) {
161+
val currentFunction = this.function
162+
if (function !== currentFunction) {
163+
// calling a different function -- create a trampoline to restore function ref
164+
this.function = function
165+
this.cont = crossFunctionCompletion(currentFunction, cont as Continuation<Any?>)
166+
} else {
167+
// calling the same function -- direct
168+
this.cont = cont as Continuation<Any?>
169+
}
170+
this.value = value
171+
}
172+
COROUTINE_SUSPENDED
173+
}
174+
175+
private fun crossFunctionCompletion(
176+
currentFunction: DeepRecursiveFunctionBlock,
177+
cont: Continuation<Any?>
178+
): Continuation<Any?> = Continuation(EmptyCoroutineContext) {
179+
this.function = currentFunction
180+
// When going back from a trampoline we cannot just call cont.resume (stack usage!)
181+
// We delegate the cont.resumeWith(it) call to runCallLoop
182+
this.cont = cont
183+
this.result = it
184+
}
185+
186+
@Suppress("UNCHECKED_CAST")
187+
fun runCallLoop(): R {
188+
while (true) {
189+
// Note: cont is set to null in DeepRecursiveScopeImpl.resumeWith when the whole computation completes
190+
val result = this.result
191+
val cont = this.cont
192+
?: return (result as Result<R>).getOrThrow() // done -- final result
193+
// The order of comparison is important here for that case of rogue class with broken equals
194+
if (UNDEFINED_RESULT == result) {
195+
// call "function" with "value" using "cont" as completion
196+
val r = try {
197+
// This is block.startCoroutine(this, value, cont)
198+
function.startCoroutineUninterceptedOrReturn(this, value, cont)
199+
} catch (e: Throwable) {
200+
cont.resumeWithException(e)
201+
continue
202+
}
203+
// If the function returns without suspension -- calls its continuation immediately
204+
if (r !== COROUTINE_SUSPENDED)
205+
cont.resume(r as R)
206+
} else {
207+
// we returned from a crossFunctionCompletion trampoline -- call resume here
208+
this.result = UNDEFINED_RESULT // reset result back
209+
cont.resumeWith(result)
210+
}
211+
}
212+
}
213+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
3+
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
4+
*/
5+
6+
package test.utils
7+
8+
import kotlin.test.*
9+
10+
class DeepRecursiveTest {
11+
@Test
12+
fun testSimpleReturn() {
13+
// just returns a value without any recursive calls
14+
val ok = DeepRecursiveFunction<Int, String> { i -> "Ok$i" }
15+
assertEquals("Ok42", ok(42))
16+
}
17+
18+
@Test
19+
fun testDeepTreeDepth() {
20+
val n = 100_000
21+
assertEquals(n, depth(deepTree(n)))
22+
}
23+
24+
@Test
25+
fun testBinaryTreeDepth() {
26+
val k = 15
27+
assertEquals(k, depth(binaryTree(k)))
28+
}
29+
30+
private class MutualRec {
31+
val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
32+
if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
33+
}
34+
35+
val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
36+
if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
37+
}
38+
}
39+
40+
@Test
41+
fun testDeepTreeOddEvenNodesMutual() {
42+
val n = 100_000
43+
val dt = deepTree(n)
44+
val rec = MutualRec()
45+
assertEquals(n / 2, rec.even(dt))
46+
assertEquals(n / 2, rec.odd(dt))
47+
}
48+
49+
@Test
50+
fun testBinaryTreeOddEvenNodesMutual() {
51+
val k = 15
52+
val bt = binaryTree(k)
53+
val rec = MutualRec()
54+
assertEquals(21845, rec.even(bt))
55+
assertEquals(10922, rec.odd(bt))
56+
}
57+
58+
private class MutualAndDirectMixRec {
59+
val b: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i -> "b$i" }
60+
61+
val a: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i ->
62+
when (i) {
63+
// mix callRecursive calls to other function and in this context
64+
0 -> b.callRecursive(1) + callRecursive(2) + aa().callRecursive(3)
65+
else -> "a$i"
66+
}
67+
}
68+
69+
fun aa() = a
70+
}
71+
72+
@Test
73+
fun testMutualAndDirectMix() {
74+
// mix of callRecursion on this scope and on other DRF
75+
val rec = MutualAndDirectMixRec()
76+
val s = rec.a.invoke(0)
77+
assertEquals("b1a2a3", s)
78+
}
79+
80+
private class EqualToAnythingClassRec {
81+
var nullCount = 0
82+
83+
val a: DeepRecursiveFunction<Tree?, EqualToAnything> = DeepRecursiveFunction { t ->
84+
if (t == null) EqualToAnything(nullCount++) else b.callRecursive(t.left)
85+
}
86+
87+
val b: DeepRecursiveFunction<Tree?, EqualToAnything> = DeepRecursiveFunction { t ->
88+
if (t == null) EqualToAnything(nullCount++) else a.callRecursive(t.left)
89+
}
90+
}
91+
92+
@Test
93+
fun testEqualToAnythingClass() {
94+
// Mutually recursive tail calls & broken equals
95+
val rec = EqualToAnythingClassRec()
96+
val result = rec.a.invoke(deepTree(100))
97+
assertEquals(1, rec.nullCount)
98+
assertEquals(0, result.i)
99+
}
100+
101+
@Test
102+
fun testBadClass() {
103+
val compute = object {
104+
val a: DeepRecursiveFunction<Bad, Bad> = DeepRecursiveFunction { v -> Bad(v.i + 1) }
105+
val b: DeepRecursiveFunction<Bad, Bad> = DeepRecursiveFunction { v ->
106+
when (v.i) {
107+
0 -> callRecursive(Bad(1))
108+
1 -> Bad(a.callRecursive(Bad(19)).i + callRecursive(Bad(2)).i)
109+
2 -> Bad(a.callRecursive(Bad(20)).i + 1)
110+
else -> error("Cannot happen")
111+
}
112+
}
113+
}
114+
assertEquals(42, compute.b(Bad(0)).i)
115+
}
116+
117+
private class Tree(val left: Tree? = null, val right: Tree? = null)
118+
119+
private fun deepTree(n: Int) = generateSequence(Tree()) { Tree(it) }.take(n).last()
120+
121+
private fun binaryTree(k: Int): Tree? =
122+
if (k == 0) null else Tree(binaryTree(k - 1), binaryTree(k - 1))
123+
124+
private val depth = DeepRecursiveFunction<Tree?, Int> { t ->
125+
if (t == null) 0 else maxOf(
126+
callRecursive(t.left),
127+
callRecursive(t.right)
128+
) + 1
129+
}
130+
131+
// It is equals to any other class
132+
private class EqualToAnything(val i: Int) {
133+
override fun equals(other: Any?): Boolean = true
134+
override fun toString(): String = "OK"
135+
}
136+
137+
// Throws exception on all object methods
138+
private class Bad(val i: Int) {
139+
override fun equals(other: Any?): Boolean = error("BAD")
140+
override fun hashCode(): Int = error("BAD")
141+
override fun toString(): String = error("BAD")
142+
}
143+
}

0 commit comments

Comments
 (0)