Skip to content

Commit 1bf2263

Browse files
Kobzolrrevenantt
authored andcommitted
INT: add smart pointers, Default impl and local variable support for
AddStructFieldsFix
1 parent 851ecfe commit 1bf2263

File tree

9 files changed

+372
-19
lines changed

9 files changed

+372
-19
lines changed

src/main/kotlin/org/rust/ide/annotator/fixes/AddStructFieldsFix.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ import com.intellij.psi.PsiElement
1212
import com.intellij.psi.PsiFile
1313
import org.rust.ide.annotator.calculateMissingFields
1414
import org.rust.lang.core.psi.RsDefaultValueBuilder
15+
import org.rust.lang.core.psi.RsPatBinding
1516
import org.rust.lang.core.psi.RsPsiFactory
1617
import org.rust.lang.core.psi.RsStructLiteral
18+
import org.rust.lang.core.psi.ext.RsElement
1719
import org.rust.lang.core.psi.ext.RsFieldsOwner
1820
import org.rust.lang.core.psi.ext.fields
1921
import org.rust.lang.core.resolve.knownItems
22+
import org.rust.lang.core.resolve.processLocalVariables
2023
import org.rust.lang.core.resolve.ref.deepResolve
2124
import org.rust.openapiext.buildAndRunTemplate
2225
import org.rust.openapiext.createSmartPointer
@@ -50,7 +53,13 @@ class AddStructFieldsFix(
5053
val body = structLiteral.structLiteralBody
5154
val fieldsToAdd = calculateMissingFields(body, decl)
5255
val defaultValueBuilder = RsDefaultValueBuilder(decl.knownItems, body.containingMod, RsPsiFactory(project), recursive)
53-
val addedFields = defaultValueBuilder.fillStruct(body, decl.fields, fieldsToAdd)
56+
57+
val addedFields = defaultValueBuilder.fillStruct(
58+
body,
59+
decl.fields,
60+
fieldsToAdd,
61+
RsDefaultValueBuilder.getVisibleBindings(startElement)
62+
)
5463
editor?.buildAndRunTemplate(body, addedFields.mapNotNull { it.expr?.createSmartPointer() })
5564
}
5665
}

src/main/kotlin/org/rust/ide/inspections/RsFieldInitShorthandInspection.kt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,18 @@ class RsFieldInitShorthandInspection : RsLocalInspectionTool() {
2828
override fun getFamilyName(): String = "Use initialization shorthand"
2929

3030
override fun applyFix(project: Project, descriptor: ProblemDescriptor) {
31-
val field = descriptor.psiElement as RsStructLiteralField
32-
field.expr?.delete()
33-
field.colon?.delete()
31+
applyShorthandInit(descriptor.psiElement as RsStructLiteralField)
3432
}
3533
}
3634
)
3735
}
3836
}
37+
38+
companion object {
39+
fun applyShorthandInit(field: RsStructLiteralField)
40+
{
41+
field.expr?.delete()
42+
field.colon?.delete()
43+
}
44+
}
3945
}

src/main/kotlin/org/rust/ide/inspections/fixes/InitializeWithDefaultValueFix.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import org.rust.lang.core.psi.*
1414
import org.rust.lang.core.psi.ext.RsElement
1515
import org.rust.lang.core.psi.ext.ancestorOrSelf
1616
import org.rust.lang.core.resolve.knownItems
17+
import org.rust.lang.core.resolve.processLocalVariables
1718
import org.rust.lang.core.types.declaration
1819
import org.rust.lang.core.types.type
1920
import org.rust.openapiext.buildAndRunTemplate
@@ -29,7 +30,8 @@ class InitializeWithDefaultValueFix(element: RsElement) : LocalQuickFixAndIntent
2930
val declaration = patBinding.ancestorOrSelf<RsLetDecl>() ?: return
3031
val semicolon = declaration.semicolon ?: return
3132
val psiFactory = RsPsiFactory(project)
32-
val initExpr = RsDefaultValueBuilder(declaration.knownItems, declaration.containingMod, psiFactory).buildFor(patBinding.type)
33+
val initExpr = RsDefaultValueBuilder(declaration.knownItems, declaration.containingMod, psiFactory, true)
34+
.buildFor(patBinding.type, RsDefaultValueBuilder.getVisibleBindings(startElement as RsElement))
3335

3436
if (declaration.eq == null) {
3537
declaration.addBefore(psiFactory.createEq(), semicolon)

src/main/kotlin/org/rust/lang/core/psi/RsDefaultValueBuilder.kt

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
package org.rust.lang.core.psi
77

8+
import org.rust.ide.inspections.RsFieldInitShorthandInspection
89
import org.rust.lang.core.psi.ext.*
910
import org.rust.lang.core.resolve.KnownItems
11+
import org.rust.lang.core.resolve.processLocalVariables
1012
import org.rust.lang.core.types.ty.*
1113
import org.rust.lang.core.types.type
1214

@@ -19,32 +21,63 @@ class RsDefaultValueBuilder(
1921
private val defaultValue: RsExpr
2022
get() = psiFactory.createExpression("()")
2123

22-
fun buildFor(ty: Ty): RsExpr {
24+
private fun buildForSmartPtr(ty: TyAdt, bindings: Map<String, RsPatBinding>): RsExpr {
25+
val item = ty.item
26+
val name = ty.item.name!!
27+
val parameter = ty.typeParameterValues[item.typeParameters[0]]!!
28+
return psiFactory.createAssocFunctionCall(name, "new", listOf(buildFor(parameter, bindings)))
29+
}
30+
31+
fun buildFor(ty: Ty, bindings: Map<String, RsPatBinding>): RsExpr {
2332
return when (ty) {
2433
is TyBool -> psiFactory.createExpression("false")
2534
is TyInteger -> psiFactory.createExpression("0")
2635
is TyFloat -> psiFactory.createExpression("0.0")
2736
is TyChar -> psiFactory.createExpression("''")
2837
is TyReference -> when (ty.referenced) {
2938
is TyStr -> psiFactory.createExpression("\"\"")
30-
else -> psiFactory.createRefExpr(buildFor(ty.referenced), listOf(ty.mutability))
39+
else -> psiFactory.createRefExpr(buildFor(ty.referenced, bindings), listOf(ty.mutability))
3140
}
3241
is TyAdt -> {
42+
val smartPointers = listOf(
43+
items.Box,
44+
items.Rc,
45+
items.Arc,
46+
items.Cell,
47+
items.RefCell,
48+
items.UnsafeCell,
49+
items.Mutex
50+
)
51+
3352
val item = ty.item
53+
if (item in smartPointers) {
54+
return buildForSmartPtr(ty, bindings)
55+
}
56+
57+
var default = this.defaultValue
58+
if (item.implLookup.isDefault(ty)) {
59+
default = psiFactory.createAssocFunctionCall("Default", "default", listOf())
60+
}
61+
3462
val name = item.name!! // `!!` is because it isn't possible to acquire TyAdt with anonymous item
3563
when (item) {
3664
items.Option -> psiFactory.createExpression("None")
3765
items.String -> psiFactory.createExpression("\"\".to_string()")
3866
items.Vec -> psiFactory.createExpression("vec![]")
3967
is RsStructItem -> if (item.kind == RsStructKind.STRUCT && item.canBeInstantiatedIn(mod)) {
68+
if (item.implLookup.isDefault(ty)) {
69+
return default
70+
}
71+
4072
when {
4173
item.blockFields != null -> {
4274
val structLiteral = psiFactory.createStructLiteral(name)
4375
if (recursive) {
4476
fillStruct(
4577
structLiteral.structLiteralBody,
4678
item.namedFields,
47-
item.namedFields
79+
item.namedFields,
80+
bindings
4881
)
4982
}
5083
structLiteral
@@ -53,7 +86,7 @@ class RsDefaultValueBuilder(
5386
val argExprs = if (recursive) {
5487
item.positionalFields
5588
.map { it.typeReference.type }
56-
.map { buildFor(it) }
89+
.map { buildFor(it, bindings) }
5790
} else {
5891
emptyList()
5992
}
@@ -62,23 +95,27 @@ class RsDefaultValueBuilder(
6295
else -> psiFactory.createExpression(name)
6396
}
6497
} else {
65-
defaultValue
98+
default
6699
}
67100
is RsEnumItem -> {
101+
if (item.implLookup.isDefault(ty)) {
102+
return default
103+
}
104+
68105
val variantWithoutFields = item.enumBody
69106
?.enumVariantList
70107
?.find { it.isFieldless }
71108
?.name
72109
variantWithoutFields?.let { psiFactory.createExpression("$name::$it") }
73-
?: defaultValue
110+
?: default
74111
}
75-
else -> defaultValue
112+
else -> default
76113
}
77114
}
78115
is TySlice, is TyArray -> psiFactory.createExpression("[]")
79116
is TyTuple -> {
80117
val text = ty.types.joinToString(prefix = "(", separator = ", ", postfix = ")") { tupleElement ->
81-
buildFor(tupleElement).text
118+
buildFor(tupleElement, bindings).text
82119
}
83120
psiFactory.createExpression(text)
84121
}
@@ -89,13 +126,16 @@ class RsDefaultValueBuilder(
89126
fun fillStruct(
90127
structLiteral: RsStructLiteralBody,
91128
declaredFields: List<RsFieldDecl>,
92-
fieldsToAdd: List<RsFieldDecl>
129+
fieldsToAdd: List<RsFieldDecl>,
130+
bindings: Map<String, RsPatBinding>
93131
): List<RsStructLiteralField> {
94132
val forceMultiLine = structLiteral.structLiteralFieldList.isEmpty() && fieldsToAdd.size > 2
95133

96134
val addedFields = mutableListOf<RsStructLiteralField>()
97135
for (fieldDecl in fieldsToAdd) {
98-
val field = specializedCreateStructLiteralField(fieldDecl) ?: continue
136+
val field = findLocalBinding(fieldDecl, bindings)
137+
?: specializedCreateStructLiteralField(fieldDecl, bindings)
138+
?: continue
99139
val addBefore = findPlaceToAdd(field, structLiteral.structLiteralFieldList, declaredFields)
100140
val added = if (addBefore == null) {
101141
ensureTrailingComma(structLiteral.structLiteralFieldList)
@@ -114,6 +154,43 @@ class RsDefaultValueBuilder(
114154
return addedFields
115155
}
116156

157+
private fun findLocalBinding(fieldDecl: RsFieldDecl, bindings: Map<String, RsPatBinding>): RsStructLiteralField? {
158+
val name = fieldDecl.name ?: return null
159+
val type = fieldDecl.typeReference?.type ?: return null
160+
161+
val binding = bindings[name] ?: return null
162+
return when {
163+
type == binding.type -> {
164+
val field = psiFactory.createStructLiteralField(name, psiFactory.createExpression(name))
165+
RsFieldInitShorthandInspection.applyShorthandInit(field)
166+
field
167+
}
168+
isRefContainer(type, binding.type) -> {
169+
val expr = buildReference(type, psiFactory.createExpression(name))
170+
psiFactory.createStructLiteralField(name, expr)
171+
}
172+
else -> null
173+
}
174+
}
175+
176+
private fun isRefContainer(container: Ty, type: Ty): Boolean {
177+
return when (container) {
178+
type -> true
179+
is TyReference -> isRefContainer(container.referenced, type)
180+
else -> false
181+
}
182+
}
183+
184+
private fun buildReference(type: Ty, expr: RsExpr): RsExpr {
185+
return when (type) {
186+
is TyReference -> {
187+
val inner = type.referenced
188+
psiFactory.createRefExpr(buildReference(inner, expr), listOf(type.mutability))
189+
}
190+
else -> expr
191+
}
192+
}
193+
117194
private fun findPlaceToAdd(
118195
fieldToAdd: RsStructLiteralField,
119196
existingFields: List<RsStructLiteralField>,
@@ -154,10 +231,22 @@ class RsDefaultValueBuilder(
154231
return null
155232
}
156233

157-
private fun specializedCreateStructLiteralField(fieldDecl: RsFieldDecl): RsStructLiteralField? {
234+
private fun specializedCreateStructLiteralField(fieldDecl: RsFieldDecl, bindings: Map<String, RsPatBinding>): RsStructLiteralField? {
158235
val fieldName = fieldDecl.name ?: return null
159236
val fieldType = fieldDecl.typeReference?.type ?: return null
160-
val fieldLiteral = buildFor(fieldType)
237+
val fieldLiteral = buildFor(fieldType, bindings)
161238
return psiFactory.createStructLiteralField(fieldName, fieldLiteral)
162239
}
240+
241+
companion object {
242+
fun getVisibleBindings(place: RsElement): Map<String, RsPatBinding> {
243+
val bindings = HashMap<String, RsPatBinding>()
244+
processLocalVariables(place) { variable ->
245+
variable.name?.let {
246+
bindings[it] = variable
247+
}
248+
}
249+
return bindings
250+
}
251+
}
163252
}

src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,7 @@ class ImplLookup(
725725
fun isClone(ty: Ty): Boolean = ty.isTraitImplemented(items.Clone)
726726
fun isSized(ty: Ty): Boolean = ty.isTraitImplemented(items.Sized)
727727
fun isDebug(ty: Ty): Boolean = ty.isTraitImplemented(items.Debug)
728+
fun isDefault(ty: Ty): Boolean = ty.isTraitImplemented(items.Default)
728729
fun isPartialEq(ty: Ty, rhsType: Ty = ty): Boolean = ty.isTraitImplemented(items.PartialEq, rhsType)
729730
fun isIntoIterator(ty: Ty): Boolean = ty.isTraitImplemented(items.IntoIterator)
730731

src/main/kotlin/org/rust/lang/core/resolve/KnownItems.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,13 @@ class KnownItems(
107107
// Some old versions of stdlib contain `Ord` trait without lang attribute
108108
val Ord: RsTraitItem? get() = findItem("core::cmp::Ord")
109109
val Debug: RsTraitItem? get() = findLangItem("debug_trait")
110-
val Box: RsStructItem? get() = findLangItem("owned_box", "alloc")
110+
val Box: RsStructOrEnumItemElement? get() = findLangItem("owned_box", "alloc")
111+
val Rc: RsStructOrEnumItemElement? get() = findItem("alloc::rc::Rc")
112+
val Arc: RsStructOrEnumItemElement? get() = findItem("alloc::sync::Arc") ?: findItem("alloc::arc::Arc")
113+
val Cell: RsStructOrEnumItemElement? get() = findItem("core::cell::Cell")
114+
val RefCell: RsStructOrEnumItemElement? get() = findItem("core::cell::RefCell")
115+
val UnsafeCell: RsStructOrEnumItemElement? get() = findItem("core::cell::UnsafeCell")
116+
val Mutex: RsStructOrEnumItemElement? get() = findItem("std::sync::mutex::Mutex")
111117
}
112118

113119
interface KnownItemsLookup {

src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ fun processLocalVariables(place: RsElement, processor: (RsPatBinding) -> Unit) {
593593
processLexicalDeclarations(scope, cameFrom, VALUES) { v ->
594594
val el = v.element
595595
if (el is RsPatBinding) processor(el)
596-
true
596+
false
597597
}
598598
}
599599
}

0 commit comments

Comments
 (0)