Skip to content

Commit

Permalink
CORE: Improve performance of resolver.GetSymbols()
Browse files Browse the repository at this point in the history
  • Loading branch information
MineGame159 committed Jan 13, 2024
1 parent 8b4c5a0 commit 9bff8af
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 46 deletions.
61 changes: 33 additions & 28 deletions cmd/lsp/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,49 +90,49 @@ func getTypeCompletions(root ast.RootResolver, resolver ast.Resolver, c *complet
case *ast.Struct:
for _, field := range node.Fields {
if field.Type == nil && isAfterNode(pos, field.Name) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}
}

case *ast.Field:
if isAfterNode(pos, node.Name) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}

case *ast.Impl:
if isAfterCst(pos, node, scanner.Impl, true) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}

case *ast.Enum:
if isAfterCst(pos, node, scanner.Colon, true) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}

case *ast.Func:
for _, param := range node.Params {
if param.Type == nil && isAfterNode(pos, param.Name) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}
}

if isAfterCst(pos, node, scanner.RightParen, true) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}

case *ast.Param:
if isAfterNode(pos, node.Name) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}

case *ast.GlobalVar:
if isAfterNode(pos, node.Name) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}

case ast.Type:
if !isComplexType(node) {
getGlobalCompletions(resolver, c, false)
getGlobalCompletions(resolver, c, true)
}
}
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func getMemberCompletions(resolver ast.Resolver, c *completions, member *ast.Mem
switch member.Value.Result().Kind {
case ast.ResolverResultKind:
resolver := member.Value.Result().Resolver()
getResolverCompletions(c, resolver, true)
getResolverCompletions(c, resolver, false)

case ast.TypeResultKind, ast.ValueResultKind:
if s, ok := asThroughPointer[*ast.Struct](member.Value.Result().Type); ok {
Expand Down Expand Up @@ -189,7 +189,7 @@ func getMemberCompletions(resolver ast.Resolver, c *completions, member *ast.Mem

func getIdentifierCompletions(resolver ast.Resolver, c *completions, pos core.Pos, node ast.Node) {
// Types and global functions
getGlobalCompletions(resolver, c, true)
getGlobalCompletions(resolver, c, false)

// Variables
function := ast.GetParent[*ast.Func](node)
Expand Down Expand Up @@ -259,7 +259,7 @@ func getStmtCompletions(c *completions, node ast.Node) {
}
}

func getGlobalCompletions(resolver ast.Resolver, c *completions, functions bool) {
func getGlobalCompletions(resolver ast.Resolver, c *completions, symbolsOnlyTypes bool) {
// Primitive types
c.add(protocol.CompletionItemKindStruct, "void", "")
c.add(protocol.CompletionItemKindStruct, "bool", "")
Expand All @@ -285,31 +285,34 @@ func getGlobalCompletions(resolver ast.Resolver, c *completions, functions bool)
c.add(protocol.CompletionItemKindFunction, "alignof", "(<type>) u32")

// Language defined types and functions
getResolverCompletions(c, resolver, functions)
getResolverCompletions(c, resolver, symbolsOnlyTypes)
}

func getResolverCompletions(c *completions, resolver ast.Resolver, functions bool) {
func getResolverCompletions(c *completions, resolver ast.Resolver, symbolsOnlyTypes bool) {
for _, child := range resolver.GetChildren() {
c.add(protocol.CompletionItemKindModule, child, "")
}

for _, node := range resolver.GetSymbols() {
switch node := node.(type) {
case *ast.Struct:
c.addNode(protocol.CompletionItemKindStruct, node.Name, "")
c.symbolsOnlyTypes = symbolsOnlyTypes
resolver.GetSymbols(c)
}

case *ast.Enum:
c.addNode(protocol.CompletionItemKindEnum, node.Name, "")
func (c *completions) VisitSymbol(node ast.Node) {
switch node := node.(type) {
case *ast.Struct:
c.addNode(protocol.CompletionItemKindStruct, node.Name, "")

case *ast.Func:
if functions {
c.addNode(protocol.CompletionItemKindFunction, node.Name, printType(node))
}
case *ast.Enum:
c.addNode(protocol.CompletionItemKindEnum, node.Name, "")

case *ast.GlobalVar:
if functions {
c.addNode(protocol.CompletionItemKindVariable, node.Name, printType(node.Type))
}
case *ast.Func:
if !c.symbolsOnlyTypes {
c.addNode(protocol.CompletionItemKindFunction, node.Name, printType(node))
}

case *ast.GlobalVar:
if !c.symbolsOnlyTypes {
c.addNode(protocol.CompletionItemKindVariable, node.Name, printType(node.Type))
}
}
}
Expand Down Expand Up @@ -361,6 +364,8 @@ func isInFunctionBody(pos core.Pos, node ast.Node) bool {
// Completions

type completions struct {
symbolsOnlyTypes bool

items []protocol.CompletionItem
}

Expand Down
14 changes: 7 additions & 7 deletions core/ast/resolver.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package ast

type SymbolVisitor interface {
VisitSymbol(node Node)
}

type Resolver interface {
GetChild(name string) Resolver
GetType(name string) Type
Expand All @@ -11,7 +15,7 @@ type Resolver interface {
GetMethods(type_ Type, static bool) []*Func

GetChildren() []string
GetSymbols() []Node
GetSymbols(visitor SymbolVisitor)
}

type RootResolver interface {
Expand Down Expand Up @@ -104,12 +108,8 @@ func (c *CombinedResolver) GetChildren() []string {
return children
}

func (c *CombinedResolver) GetSymbols() []Node {
var symbols []Node

func (c *CombinedResolver) GetSymbols(visitor SymbolVisitor) {
for _, resolver := range c.resolvers {
symbols = append(symbols, resolver.GetSymbols()...)
resolver.GetSymbols(visitor)
}

return symbols
}
20 changes: 15 additions & 5 deletions core/checker/declarations.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,13 @@ func (c *checker) checkNameCollision(decl ast.Decl, name *ast.Token) {
}

// Symbols
for _, node := range c.resolver.GetSymbols() {
if node == decl {
continue
}
symbols := symbolGetter{except: decl}
c.resolver.GetSymbols(&symbols)

if symbols.node != nil {
var name2 *ast.Token

switch node := node.(type) {
switch node := symbols.node.(type) {
case *ast.Struct:
name2 = node.Name
case *ast.Enum:
Expand All @@ -296,3 +295,14 @@ func (c *checker) checkNameCollision(decl ast.Decl, name *ast.Token) {
}
}
}

type symbolGetter struct {
except ast.Node
node ast.Node
}

func (s *symbolGetter) VisitSymbol(node ast.Node) {
if node != s.except && s.node == nil {
s.node = node
}
}
8 changes: 2 additions & 6 deletions core/workspace/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,13 @@ func (n *namespace) GetChildren() []string {
return children
}

func (n *namespace) GetSymbols() []ast.Node {
var symbols []ast.Node

func (n *namespace) GetSymbols(visitor ast.SymbolVisitor) {
for _, file := range n.files {
for _, decl := range file.Ast.Decls {
switch decl.(type) {
case *ast.Struct, *ast.Enum, *ast.Func, *ast.GlobalVar:
symbols = append(symbols, decl)
visitor.VisitSymbol(decl)
}
}
}

return symbols
}

0 comments on commit 9bff8af

Please sign in to comment.