5
5
6
6
package org.rust.lang.core.psi
7
7
8
+ import org.rust.ide.inspections.RsFieldInitShorthandInspection
8
9
import org.rust.lang.core.psi.ext.*
9
10
import org.rust.lang.core.resolve.KnownItems
11
+ import org.rust.lang.core.resolve.processLocalVariables
10
12
import org.rust.lang.core.types.ty.*
11
13
import org.rust.lang.core.types.type
12
14
@@ -19,32 +21,63 @@ class RsDefaultValueBuilder(
19
21
private val defaultValue: RsExpr
20
22
get() = psiFactory.createExpression(" ()" )
21
23
22
- fun buildFor (ty : Ty ): RsExpr {
24
+ private fun buildForSmartPtr (ty : TyAdt , bindings : Map <String , RsPatBinding >): RsExpr {
25
+ val item = ty.item
26
+ val name = ty.item.name!!
27
+ val parameter = ty.typeParameterValues[item.typeParameters[0 ]]!!
28
+ return psiFactory.createAssocFunctionCall(name, " new" , listOf (buildFor(parameter, bindings)))
29
+ }
30
+
31
+ fun buildFor (ty : Ty , bindings : Map <String , RsPatBinding >): RsExpr {
23
32
return when (ty) {
24
33
is TyBool -> psiFactory.createExpression(" false" )
25
34
is TyInteger -> psiFactory.createExpression(" 0" )
26
35
is TyFloat -> psiFactory.createExpression(" 0.0" )
27
36
is TyChar -> psiFactory.createExpression(" ''" )
28
37
is TyReference -> when (ty.referenced) {
29
38
is TyStr -> psiFactory.createExpression(" \"\" " )
30
- else -> psiFactory.createRefExpr(buildFor(ty.referenced), listOf (ty.mutability))
39
+ else -> psiFactory.createRefExpr(buildFor(ty.referenced, bindings ), listOf (ty.mutability))
31
40
}
32
41
is TyAdt -> {
42
+ val smartPointers = listOf (
43
+ items.Box ,
44
+ items.Rc ,
45
+ items.Arc ,
46
+ items.Cell ,
47
+ items.RefCell ,
48
+ items.UnsafeCell ,
49
+ items.Mutex
50
+ )
51
+
33
52
val item = ty.item
53
+ if (item in smartPointers) {
54
+ return buildForSmartPtr(ty, bindings)
55
+ }
56
+
57
+ var default = this .defaultValue
58
+ if (item.implLookup.isDefault(ty)) {
59
+ default = psiFactory.createAssocFunctionCall(" Default" , " default" , listOf ())
60
+ }
61
+
34
62
val name = item.name!! // `!!` is because it isn't possible to acquire TyAdt with anonymous item
35
63
when (item) {
36
64
items.Option -> psiFactory.createExpression(" None" )
37
65
items.String -> psiFactory.createExpression(" \"\" .to_string()" )
38
66
items.Vec -> psiFactory.createExpression(" vec![]" )
39
67
is RsStructItem -> if (item.kind == RsStructKind .STRUCT && item.canBeInstantiatedIn(mod)) {
68
+ if (item.implLookup.isDefault(ty)) {
69
+ return default
70
+ }
71
+
40
72
when {
41
73
item.blockFields != null -> {
42
74
val structLiteral = psiFactory.createStructLiteral(name)
43
75
if (recursive) {
44
76
fillStruct(
45
77
structLiteral.structLiteralBody,
46
78
item.namedFields,
47
- item.namedFields
79
+ item.namedFields,
80
+ bindings
48
81
)
49
82
}
50
83
structLiteral
@@ -53,7 +86,7 @@ class RsDefaultValueBuilder(
53
86
val argExprs = if (recursive) {
54
87
item.positionalFields
55
88
.map { it.typeReference.type }
56
- .map { buildFor(it) }
89
+ .map { buildFor(it, bindings ) }
57
90
} else {
58
91
emptyList()
59
92
}
@@ -62,23 +95,27 @@ class RsDefaultValueBuilder(
62
95
else -> psiFactory.createExpression(name)
63
96
}
64
97
} else {
65
- defaultValue
98
+ default
66
99
}
67
100
is RsEnumItem -> {
101
+ if (item.implLookup.isDefault(ty)) {
102
+ return default
103
+ }
104
+
68
105
val variantWithoutFields = item.enumBody
69
106
?.enumVariantList
70
107
?.find { it.isFieldless }
71
108
?.name
72
109
variantWithoutFields?.let { psiFactory.createExpression(" $name ::$it " ) }
73
- ? : defaultValue
110
+ ? : default
74
111
}
75
- else -> defaultValue
112
+ else -> default
76
113
}
77
114
}
78
115
is TySlice , is TyArray -> psiFactory.createExpression(" []" )
79
116
is TyTuple -> {
80
117
val text = ty.types.joinToString(prefix = " (" , separator = " , " , postfix = " )" ) { tupleElement ->
81
- buildFor(tupleElement).text
118
+ buildFor(tupleElement, bindings ).text
82
119
}
83
120
psiFactory.createExpression(text)
84
121
}
@@ -89,13 +126,16 @@ class RsDefaultValueBuilder(
89
126
fun fillStruct (
90
127
structLiteral : RsStructLiteralBody ,
91
128
declaredFields : List <RsFieldDecl >,
92
- fieldsToAdd : List <RsFieldDecl >
129
+ fieldsToAdd : List <RsFieldDecl >,
130
+ bindings : Map <String , RsPatBinding >
93
131
): List <RsStructLiteralField > {
94
132
val forceMultiLine = structLiteral.structLiteralFieldList.isEmpty() && fieldsToAdd.size > 2
95
133
96
134
val addedFields = mutableListOf<RsStructLiteralField >()
97
135
for (fieldDecl in fieldsToAdd) {
98
- val field = specializedCreateStructLiteralField(fieldDecl) ? : continue
136
+ val field = findLocalBinding(fieldDecl, bindings)
137
+ ? : specializedCreateStructLiteralField(fieldDecl, bindings)
138
+ ? : continue
99
139
val addBefore = findPlaceToAdd(field, structLiteral.structLiteralFieldList, declaredFields)
100
140
val added = if (addBefore == null ) {
101
141
ensureTrailingComma(structLiteral.structLiteralFieldList)
@@ -114,6 +154,43 @@ class RsDefaultValueBuilder(
114
154
return addedFields
115
155
}
116
156
157
+ private fun findLocalBinding (fieldDecl : RsFieldDecl , bindings : Map <String , RsPatBinding >): RsStructLiteralField ? {
158
+ val name = fieldDecl.name ? : return null
159
+ val type = fieldDecl.typeReference?.type ? : return null
160
+
161
+ val binding = bindings[name] ? : return null
162
+ return when {
163
+ type == binding.type -> {
164
+ val field = psiFactory.createStructLiteralField(name, psiFactory.createExpression(name))
165
+ RsFieldInitShorthandInspection .applyShorthandInit(field)
166
+ field
167
+ }
168
+ isRefContainer(type, binding.type) -> {
169
+ val expr = buildReference(type, psiFactory.createExpression(name))
170
+ psiFactory.createStructLiteralField(name, expr)
171
+ }
172
+ else -> null
173
+ }
174
+ }
175
+
176
+ private fun isRefContainer (container : Ty , type : Ty ): Boolean {
177
+ return when (container) {
178
+ type -> true
179
+ is TyReference -> isRefContainer(container.referenced, type)
180
+ else -> false
181
+ }
182
+ }
183
+
184
+ private fun buildReference (type : Ty , expr : RsExpr ): RsExpr {
185
+ return when (type) {
186
+ is TyReference -> {
187
+ val inner = type.referenced
188
+ psiFactory.createRefExpr(buildReference(inner, expr), listOf (type.mutability))
189
+ }
190
+ else -> expr
191
+ }
192
+ }
193
+
117
194
private fun findPlaceToAdd (
118
195
fieldToAdd : RsStructLiteralField ,
119
196
existingFields : List <RsStructLiteralField >,
@@ -154,10 +231,22 @@ class RsDefaultValueBuilder(
154
231
return null
155
232
}
156
233
157
- private fun specializedCreateStructLiteralField (fieldDecl : RsFieldDecl ): RsStructLiteralField ? {
234
+ private fun specializedCreateStructLiteralField (fieldDecl : RsFieldDecl , bindings : Map < String , RsPatBinding > ): RsStructLiteralField ? {
158
235
val fieldName = fieldDecl.name ? : return null
159
236
val fieldType = fieldDecl.typeReference?.type ? : return null
160
- val fieldLiteral = buildFor(fieldType)
237
+ val fieldLiteral = buildFor(fieldType, bindings )
161
238
return psiFactory.createStructLiteralField(fieldName, fieldLiteral)
162
239
}
240
+
241
+ companion object {
242
+ fun getVisibleBindings (place : RsElement ): Map <String , RsPatBinding > {
243
+ val bindings = HashMap <String , RsPatBinding >()
244
+ processLocalVariables(place) { variable ->
245
+ variable.name?.let {
246
+ bindings[it] = variable
247
+ }
248
+ }
249
+ return bindings
250
+ }
251
+ }
163
252
}
0 commit comments