Skip to content

Commit

Permalink
ALL: Initial support for generics
Browse files Browse the repository at this point in the history
  • Loading branch information
MineGame159 committed Feb 21, 2024
1 parent d5ae248 commit d11956c
Show file tree
Hide file tree
Showing 56 changed files with 2,627 additions and 630 deletions.
13 changes: 9 additions & 4 deletions cmd/cmd/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/fatih/color"
"github.com/spf13/cobra"
"log"
"strings"
"time"
)

Expand Down Expand Up @@ -76,20 +77,24 @@ func generateEntrypoint(project *workspace.Project) *ir.Module {
mainBlock := main.Block("")

if function != nil {
f := function.Underlying()
var fbMain ir.Value

if ast.IsPrimitive(function.Returns, ast.I32) {
fbMain = m.Declare(function.MangledName(), &ir.FuncType{Returns: ir.I32})
var name strings.Builder
f.MangledName(&name)

if ast.IsPrimitive(f.Returns(), ast.I32) {
fbMain = m.Declare(name.String(), &ir.FuncType{Returns: ir.I32})
} else {
fbMain = m.Declare(function.MangledName(), &ir.FuncType{})
fbMain = m.Declare(name.String(), &ir.FuncType{})
}

call := mainBlock.Add(&ir.CallInst{
Callee: fbMain,
Args: nil,
})

if ast.IsPrimitive(function.Returns, ast.I32) {
if ast.IsPrimitive(f.Returns(), ast.I32) {
mainBlock.Add(&ir.RetInst{Value: call})
} else {
mainBlock.Add(&ir.RetInst{Value: &ir.IntConst{
Expand Down
9 changes: 6 additions & 3 deletions cmd/cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func testCmd(_ *cobra.Command, _ []string) {
_, _ = namespaceStyle.Printf("%s.", part)
}

if s := test.Method(); s != nil {
_, _ = namespaceStyle.Printf("%s.", s.Name)
if receiver := test.Receiver(); receiver != nil {
_, _ = namespaceStyle.Printf("%s.", receiver.Underlying().Name)
}

_, _ = testStyle.Print(test.TestName())
Expand Down Expand Up @@ -166,10 +166,13 @@ func generateTestsEntrypoint(tests []*ast.Func) *ir.Module {
mainBlock := main.Block("entry")

for i, test := range tests {
var name strings.Builder
test.MangledName(&name)

mainBlock.Add(&ir.CallInst{
Callee: run,
Args: []ir.Value{
m.Declare(test.MangledName(), testType),
m.Declare(name.String(), testType),
&ir.IntConst{Typ: ir.I32, Value: ir.Unsigned(uint64(i))},
},
})
Expand Down
133 changes: 84 additions & 49 deletions cmd/lsp/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,23 @@ func getCompletions(project *workspace.Project, file *ast.File, pos core.Pos) *p
}

c := completions{}
resolver := ast.NewCombinedResolver(baseResolver)
combinedResolver := ast.NewCombinedResolver(baseResolver)

for _, decl := range file.Decls {
if using, ok := decl.(*ast.Using); ok {
if r := baseResolver.GetResolver(using.Name); r != nil {
resolver.Add(r)
combinedResolver.Add(r)
}
}
}

// Leaf
leaf := ast.GetLeaf(file, pos)

if leaf != nil {
if !ast.IsNil(leaf) {
var resolver ast.Resolver = combinedResolver
buildResolverStack(&resolver, leaf)

if isInFunctionBody(pos, leaf) {
switch parent := leaf.Parent().(type) {
case *ast.Resolvable:
Expand All @@ -53,6 +56,12 @@ func getCompletions(project *workspace.Project, file *ast.File, pos core.Pos) *p
} else {
// Non leaf
node := ast.Get(file, pos)
if ast.IsNil(node) {
return c.get()
}

var resolver ast.Resolver = combinedResolver
buildResolverStack(&resolver, node)

if isInFunctionBody(pos, node) {
switch node := node.(type) {
Expand All @@ -75,9 +84,10 @@ func getCompletions(project *workspace.Project, file *ast.File, pos core.Pos) *p

case *ast.StructInitializer:
if isAfterCst(pos, node, scanner.LeftBrace, false) {
if s, ok := ast.As[*ast.Struct](node.Type); ok {
for _, field := range s.Fields {
c.addNode(protocol.CompletionItemKindField, field.Name, printType(field.Type))
if s, ok := ast.As[ast.StructType](node.Type); ok {
for i := 0; i < s.FieldCount(); i++ {
field := s.FieldIndex(i)
c.addNode(protocol.CompletionItemKindField, field.Name(), printType(field.Type()))
}
}
}
Expand Down Expand Up @@ -119,13 +129,13 @@ func getTypeCompletions(root ast.RootResolver, resolver ast.Resolver, c *complet

case *ast.Struct:
for _, field := range node.Fields {
if field.Type == nil && isAfterNode(pos, field.Name) {
if field.Type() == nil && isAfterNode(pos, field.Name()) {
getGlobalCompletions(resolver, c, true)
}
}

case *ast.Field:
if isAfterNode(pos, node.Name) {
if isAfterNode(pos, node.Name()) {
getGlobalCompletions(resolver, c, true)
}

Expand Down Expand Up @@ -183,7 +193,7 @@ func getNamespaceCompletions(root ast.RootResolver, c *completions, pos core.Pos

if resolver != nil {
for _, child := range resolver.GetChildren() {
c.add(protocol.CompletionItemKindModule, child, "")
c.add(protocol.CompletionItemKindModule, child, "", false)
}
}
}
Expand All @@ -196,17 +206,19 @@ func getMemberCompletions(resolver ast.Resolver, c *completions, member *ast.Mem
getResolverCompletions(c, resolver, false)

case ast.TypeResultKind, ast.ValueResultKind:
if s, ok := asThroughPointer[*ast.Struct](member.Value.Result().Type); ok {
fields := s.Fields
static := false
if s, ok := asThroughPointer[ast.StructType](member.Value.Result().Type); ok {
static := member.Value.Result().Kind == ast.TypeResultKind

if member.Value.Result().Kind == ast.TypeResultKind {
fields = s.StaticFields
static = true
}

for _, field := range fields {
c.addNode(protocol.CompletionItemKindField, field.Name, printType(field.Type))
if static {
for i := 0; i < s.StaticFieldCount(); i++ {
field := s.StaticFieldIndex(i)
c.addNode(protocol.CompletionItemKindField, field.Name(), printType(field.Type()))
}
} else {
for i := 0; i < s.FieldCount(); i++ {
field := s.FieldIndex(i)
c.addNode(protocol.CompletionItemKindField, field.Name(), printType(field.Type()))
}
}

for _, method := range resolver.GetMethods(s, static) {
Expand Down Expand Up @@ -235,8 +247,8 @@ func getIdentifierCompletions(resolver ast.Resolver, c *completions, pos core.Po
names := utils.NewSet[string]()

// This
if s := function.Method(); s != nil {
c.add(protocol.CompletionItemKindVariable, "this", printType(s))
if s := function.Receiver(); s != nil {
c.add(protocol.CompletionItemKindVariable, "this", printType(s), false)
}

// Parameters
Expand Down Expand Up @@ -286,42 +298,42 @@ func getStmtCompletions(c *completions, node ast.Node) {
}

if ok {
c.add(protocol.CompletionItemKindKeyword, "var", "")
c.add(protocol.CompletionItemKindSnippet, "if", "")
c.add(protocol.CompletionItemKindSnippet, "while", "")
c.add(protocol.CompletionItemKindSnippet, "for", "")
c.add(protocol.CompletionItemKindKeyword, "return", "")
c.add(protocol.CompletionItemKindKeyword, "break", "")
c.add(protocol.CompletionItemKindKeyword, "continue", "")
c.add(protocol.CompletionItemKindKeyword, "var", "", false)
c.add(protocol.CompletionItemKindSnippet, "if", "", false)
c.add(protocol.CompletionItemKindSnippet, "while", "", false)
c.add(protocol.CompletionItemKindSnippet, "for", "", false)
c.add(protocol.CompletionItemKindKeyword, "return", "", false)
c.add(protocol.CompletionItemKindKeyword, "break", "", false)
c.add(protocol.CompletionItemKindKeyword, "continue", "", false)
}
}

func getGlobalCompletions(resolver ast.Resolver, c *completions, symbolsOnlyTypes bool) {
// Primitive types
c.add(protocol.CompletionItemKindStruct, "void", "")
c.add(protocol.CompletionItemKindStruct, "bool", "")
c.add(protocol.CompletionItemKindStruct, "void", "", false)
c.add(protocol.CompletionItemKindStruct, "bool", "", false)

c.add(protocol.CompletionItemKindStruct, "u8", "")
c.add(protocol.CompletionItemKindStruct, "u16", "")
c.add(protocol.CompletionItemKindStruct, "u32", "")
c.add(protocol.CompletionItemKindStruct, "u64", "")
c.add(protocol.CompletionItemKindStruct, "u8", "", false)
c.add(protocol.CompletionItemKindStruct, "u16", "", false)
c.add(protocol.CompletionItemKindStruct, "u32", "", false)
c.add(protocol.CompletionItemKindStruct, "u64", "", false)

c.add(protocol.CompletionItemKindStruct, "i8", "")
c.add(protocol.CompletionItemKindStruct, "i16", "")
c.add(protocol.CompletionItemKindStruct, "i32", "")
c.add(protocol.CompletionItemKindStruct, "i64", "")
c.add(protocol.CompletionItemKindStruct, "i8", "", false)
c.add(protocol.CompletionItemKindStruct, "i16", "", false)
c.add(protocol.CompletionItemKindStruct, "i32", "", false)
c.add(protocol.CompletionItemKindStruct, "i64", "", false)

c.add(protocol.CompletionItemKindStruct, "f32", "")
c.add(protocol.CompletionItemKindStruct, "f64", "")
c.add(protocol.CompletionItemKindStruct, "f32", "", false)
c.add(protocol.CompletionItemKindStruct, "f64", "", false)

if !symbolsOnlyTypes {
// Builtin identifiers
c.add(protocol.CompletionItemKindKeyword, "true", "bool")
c.add(protocol.CompletionItemKindKeyword, "false", "bool")
c.add(protocol.CompletionItemKindKeyword, "true", "bool", false)
c.add(protocol.CompletionItemKindKeyword, "false", "bool", false)

c.add(protocol.CompletionItemKindFunction, "sizeof", "(<type>) u32")
c.add(protocol.CompletionItemKindFunction, "alignof", "(<type>) u32")
c.add(protocol.CompletionItemKindFunction, "typeof", "(<expression>) u32")
c.add(protocol.CompletionItemKindFunction, "sizeof", "(<type>) u32", false)
c.add(protocol.CompletionItemKindFunction, "alignof", "(<type>) u32", false)
c.add(protocol.CompletionItemKindFunction, "typeof", "(<expression>) u32", false)
}

// Language defined types and functions
Expand All @@ -330,7 +342,7 @@ func getGlobalCompletions(resolver ast.Resolver, c *completions, symbolsOnlyType

func getResolverCompletions(c *completions, resolver ast.Resolver, symbolsOnlyTypes bool) {
for _, child := range resolver.GetChildren() {
c.add(protocol.CompletionItemKindModule, child, "")
c.add(protocol.CompletionItemKindModule, child, "", false)
}

c.symbolsOnlyTypes = symbolsOnlyTypes
Expand All @@ -339,6 +351,9 @@ func getResolverCompletions(c *completions, resolver ast.Resolver, symbolsOnlyTy

func (c *completions) VisitSymbol(node ast.Node) {
switch node := node.(type) {
case *ast.Generic:
c.addNode(protocol.CompletionItemKindTypeParameter, node, "")

case *ast.Struct:
c.addNode(protocol.CompletionItemKindStruct, node.Name, "")

Expand Down Expand Up @@ -429,11 +444,20 @@ var commitCharacters = []string{".", ";"}

func (c *completions) addNode(kind protocol.CompletionItemKind, name ast.Node, detail string) {
if !ast.IsNil(name) {
c.add(kind, name.String(), detail)
generics := false

switch node := name.Parent().(type) {
case *ast.Struct:
generics = len(node.GenericParams) > 0
case *ast.Func:
generics = len(node.GenericParams) > 0
}

c.add(kind, name.String(), detail, generics)
}
}

func (c *completions) add(kind protocol.CompletionItemKind, name, detail string) {
func (c *completions) add(kind protocol.CompletionItemKind, name, detail string, generics bool) {
item := protocol.CompletionItem{
Kind: kind,
Label: name,
Expand All @@ -442,8 +466,19 @@ func (c *completions) add(kind protocol.CompletionItemKind, name, detail string)
}

switch kind {
case protocol.CompletionItemKindStruct:
if generics {
item.InsertText = name + "![$1]"
item.InsertTextFormat = protocol.InsertTextFormatSnippet
}

case protocol.CompletionItemKindFunction, protocol.CompletionItemKindMethod:
item.InsertText = name + "($1)"
if generics {
item.InsertText = name + "![$1]($2)"
} else {
item.InsertText = name + "($1)"
}

item.InsertTextFormat = protocol.InsertTextFormatSnippet

case protocol.CompletionItemKindSnippet:
Expand Down
9 changes: 5 additions & 4 deletions cmd/lsp/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ func getDefinition(node ast.Node, pos core.Pos) []protocol.Location {

// Get definition based on the leaf node
switch node := node.(type) {
case *ast.Identifier:
return getDefinitionExprResult(node.Result())
case *ast.Resolvable:
return newDefinition(node.Resolved())
case *ast.Token:
Expand All @@ -32,12 +30,15 @@ func getDefinitionToken(token *ast.Token) []protocol.Location {
return newDefinition(parent.Resolved())
}

case *ast.Identifier:
return getDefinitionExprResult(parent.Result())

case *ast.Member:
return getDefinitionExprResult(parent.Result())

case *ast.InitField:
if s, ok := ast.As[*ast.Struct](parent.Parent().(*ast.StructInitializer).Type); ok {
if _, field := s.GetField(token.String()); field != nil {
if s, ok := ast.As[ast.StructType](parent.Parent().(*ast.StructInitializer).Type); ok {
if field := s.FieldName(token.String()); field != nil {
return newDefinition(field)
}
}
Expand Down
1 change: 1 addition & 0 deletions cmd/lsp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func (h *handler) Initialize(_ context.Context, params *protocol.InitializeParam
protocol.SemanticTokenEnumMember,
protocol.SemanticTokenNamespace,
protocol.SemanticTokenInterface,
protocol.SemanticTokenTypeParameter,
},
TokenModifiers: []protocol.SemanticTokenModifiers{},
},
Expand Down
19 changes: 19 additions & 0 deletions cmd/lsp/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@ import (
"fireball/core/scanner"
)

func buildResolverStack(resolver *ast.Resolver, node ast.Node) {
if !ast.IsNil(node.Parent()) {
buildResolverStack(resolver, node.Parent())
}

switch node := node.(type) {
case *ast.Struct:
*resolver = ast.NewGenericResolver(*resolver, node.GenericParams)

case *ast.Impl:
if s, ok := ast.As[*ast.Struct](node.Type); ok && len(s.GenericParams) > 0 {
*resolver = ast.NewGenericResolver(*resolver, s.GenericParams)
}

case *ast.Func:
*resolver = ast.NewGenericResolver(*resolver, node.GenericParams)
}
}

func printType(type_ ast.Type) string {
return ast.PrintTypeOptions(type_, ast.TypePrintOptions{ParamNames: true})
}
Expand Down
Loading

0 comments on commit d11956c

Please sign in to comment.