1
1
package love.forte.plugin.suspendtrans.ir
2
2
3
- import love.forte.plugin.suspendtrans.*
3
+ import love.forte.plugin.suspendtrans.SuspendTransformConfiguration
4
+ import love.forte.plugin.suspendtrans.SuspendTransformUserData
5
+ import love.forte.plugin.suspendtrans.SuspendTransformUserDataKey
6
+ import love.forte.plugin.suspendtrans.fqn
4
7
import love.forte.plugin.suspendtrans.utils.*
5
8
import org.jetbrains.kotlin.backend.common.IrElementTransformerVoidWithContext
6
9
import org.jetbrains.kotlin.backend.common.extensions.FirIncompatiblePluginAPI
@@ -9,16 +12,16 @@ import org.jetbrains.kotlin.descriptors.CallableDescriptor
9
12
import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor
10
13
import org.jetbrains.kotlin.ir.IrStatement
11
14
import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
12
- import org.jetbrains.kotlin.ir.builders.irBlockBody
13
- import org.jetbrains.kotlin.ir.builders.irCall
14
- import org.jetbrains.kotlin.ir.builders.irGet
15
- import org.jetbrains.kotlin.ir.builders.irReturn
15
+ import org.jetbrains.kotlin.ir.builders.*
16
16
import org.jetbrains.kotlin.ir.declarations.*
17
17
import org.jetbrains.kotlin.ir.expressions.IrBody
18
+ import org.jetbrains.kotlin.ir.expressions.IrCall
19
+ import org.jetbrains.kotlin.ir.expressions.IrTypeOperator
20
+ import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl
18
21
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
19
- import org.jetbrains.kotlin.ir.types.isSubtypeOfClass
20
- import org.jetbrains.kotlin.ir.types.typeWith
21
- import org.jetbrains.kotlin.ir.util.*
22
+ import org.jetbrains.kotlin.ir.types.*
23
+ import org.jetbrains.kotlin.ir.util.isAnnotationWithEqualFqName
24
+ import org.jetbrains.kotlin.ir.util.primaryConstructor
22
25
import org.jetbrains.kotlin.name.ClassId
23
26
import org.jetbrains.kotlin.name.FqName
24
27
@@ -92,7 +95,8 @@ class SuspendTransformTransformer(
92
95
93
96
private fun resolveFunctionBodyByDescriptor (declaration : IrFunction , descriptor : CallableDescriptor ): IrFunction ? {
94
97
val userData = descriptor.getUserData(SuspendTransformUserDataKey ) ? : return null
95
- val callableFunction = pluginContext.referenceFunctions(userData.transformer.transformFunctionInfo.toCallableId()).firstOrNull()
98
+ val callableFunction =
99
+ pluginContext.referenceFunctions(userData.transformer.transformFunctionInfo.toCallableId()).firstOrNull()
96
100
? : throw IllegalStateException (" Transform function ${userData.transformer.transformFunctionInfo} not found" )
97
101
98
102
val generatedOriginFunction = resolveFunctionBody(declaration, userData.originFunction, callableFunction)
@@ -112,7 +116,7 @@ class SuspendTransformTransformer(
112
116
currentAnnotations.any { a -> a.isAnnotationWithEqualFqName(name) }
113
117
addAll(currentAnnotations)
114
118
115
- val syntheticFunctionIncludes = userData.transformer.originFunctionIncludeAnnotations
119
+ val syntheticFunctionIncludes = userData.transformer.originFunctionIncludeAnnotations
116
120
117
121
syntheticFunctionIncludes.forEach { include ->
118
122
val classId = include.classInfo.toClassId()
@@ -205,24 +209,69 @@ private fun generateTransformBodyForFunction(
205
209
// println(transformTargetFunctionCall.owner.valueParameters)
206
210
val owner = transformTargetFunctionCall.owner
207
211
208
- if (owner.valueParameters.size > 1 ) {
209
- val secondType = owner.valueParameters[1 ].type
210
- val coroutineScopeTypeName = " kotlinx.coroutines.CoroutineScope" .fqn
211
- val coroutineScopeTypeClassId = ClassId .topLevel(" kotlinx.coroutines.CoroutineScope" .fqn)
212
- val coroutineScopeTypeNameUnsafe = coroutineScopeTypeName.toUnsafe()
213
- if (secondType.isClassType(coroutineScopeTypeNameUnsafe)) {
214
- function.dispatchReceiverParameter?.also { dispatchReceiverParameter ->
215
- context.referenceClass(coroutineScopeTypeClassId)?.also { coroutineScopeRef ->
216
- if (dispatchReceiverParameter.type.isSubtypeOfClass(coroutineScopeRef)) {
217
- // put 'this' to second arg
218
- putValueArgument(1 , irGet(dispatchReceiverParameter))
219
- }
220
- }
221
- }
222
- }
212
+ // CoroutineScope
213
+ val ownerValueParameters = owner.valueParameters
223
214
215
+ if (ownerValueParameters.size > 1 ) {
216
+ for (index in 1 .. ownerValueParameters.lastIndex) {
217
+ val valueParameter = ownerValueParameters[index]
218
+ val type = valueParameter.type
219
+ tryResolveCoroutineScopeValueParameter(type, context, function, owner, this @irBlockBody, index)
220
+ }
224
221
}
225
222
226
223
})
227
224
}
228
225
}
226
+
227
+ private val coroutineScopeTypeName = " kotlinx.coroutines.CoroutineScope" .fqn
228
+ private val coroutineScopeTypeClassId = ClassId .topLevel(" kotlinx.coroutines.CoroutineScope" .fqn)
229
+ private val coroutineScopeTypeNameUnsafe = coroutineScopeTypeName.toUnsafe()
230
+
231
+ /* *
232
+ * 解析类型为 CoroutineScope 的参数。
233
+ * 如果当前参数类型为 CoroutineScope:
234
+ * - 如果当前 receiver 即为 CoroutineScope 类型,将其填充
235
+ * - 如果当前 receiver 不是 CoroutineScope 类型,但是此参数可以为 null,
236
+ * 则使用 safe-cast 将 receiver 转化为 CoroutineScope ( `dispatcher as? CoroutineScope` )
237
+ * - 其他情况忽略此参数(适用于此参数有默认值的情况)
238
+ */
239
+ private fun IrCall.tryResolveCoroutineScopeValueParameter (
240
+ type : IrType ,
241
+ context : IrPluginContext ,
242
+ function : IrFunction ,
243
+ owner : IrSimpleFunction ,
244
+ builderWithScope : IrBuilderWithScope ,
245
+ index : Int
246
+ ) {
247
+ if (! type.isClassType(coroutineScopeTypeNameUnsafe)) {
248
+ return
249
+ }
250
+
251
+ function.dispatchReceiverParameter?.also { dispatchReceiverParameter ->
252
+ context.referenceClass(coroutineScopeTypeClassId)?.also { coroutineScopeRef ->
253
+ if (dispatchReceiverParameter.type.isSubtypeOfClass(coroutineScopeRef)) {
254
+ // put 'this' to the arg
255
+ putValueArgument(index, builderWithScope.irGet(dispatchReceiverParameter))
256
+ } else {
257
+ val scopeType = coroutineScopeRef.defaultType
258
+
259
+ val scopeParameter = owner.valueParameters.getOrNull(1 )
260
+
261
+ if (scopeParameter?.type?.isNullable() == true ) {
262
+ val irSafeAs = IrTypeOperatorCallImpl (
263
+ startOffset,
264
+ endOffset,
265
+ scopeType,
266
+ IrTypeOperator .SAFE_CAST ,
267
+ scopeType,
268
+ builderWithScope.irGet(dispatchReceiverParameter)
269
+ )
270
+
271
+ putValueArgument(index, irSafeAs)
272
+ }
273
+ // irAs(irGet(dispatchReceiverParameter), coroutineScopeRef.defaultType)
274
+ }
275
+ }
276
+ }
277
+ }
0 commit comments