From 508f2e6a30228fded9c489bd44ea418daafa177b Mon Sep 17 00:00:00 2001 From: MineGame159 Date: Thu, 22 Feb 2024 19:45:47 +0100 Subject: [PATCH] CORE: Uhh fixes --- cmd/cmd/test.go | 2 +- cmd/lsp/implementations.go | 3 ++ core/ast/generics.go | 59 ++++++++++++++++++++++------ core/ast/interfaces.go | 4 +- core/ast/interfaces_impl.go | 10 ++++- core/ast/types_manual.go | 4 ++ core/codegen/codegen.go | 6 +++ core/codegen/declarations.go | 20 +++++----- core/codegen/expressions.go | 2 +- core/llvm/instructions.go | 16 ++++++++ core/typeresolver/typeresolver.go | 32 +++------------ core/typeresolver/typespecializer.go | 45 +++++++++++++++++++++ core/workspace/file.go | 5 +++ core/workspace/project.go | 9 +++++ 14 files changed, 164 insertions(+), 53 deletions(-) create mode 100644 core/typeresolver/typespecializer.go diff --git a/cmd/cmd/test.go b/cmd/cmd/test.go index e00b3d8..40978f8 100644 --- a/cmd/cmd/test.go +++ b/cmd/cmd/test.go @@ -72,7 +72,7 @@ func testCmd(_ *cobra.Command, _ []string) { _, _ = namespaceStyle.Printf("%s.", part) } - if receiver := test.Receiver(); receiver != nil { + if receiver, ok := test.Receiver().(ast.StructType); ok { _, _ = namespaceStyle.Printf("%s.", receiver.Underlying().Name) } diff --git a/cmd/lsp/implementations.go b/cmd/lsp/implementations.go index acbcce3..2a252fc 100644 --- a/cmd/lsp/implementations.go +++ b/cmd/lsp/implementations.go @@ -9,6 +9,9 @@ import ( func getImplementations(node ast.Node, pos core.Pos, resolver ast.Resolver) []protocol.Location { leaf := ast.GetLeaf(node, pos) + if ast.IsNil(leaf) { + return nil + } builder := implementationBuilder{} diff --git a/core/ast/generics.go b/core/ast/generics.go index 90bf26f..fe23566 100644 --- a/core/ast/generics.go +++ b/core/ast/generics.go @@ -32,24 +32,36 @@ func (s *Struct) Specialize(types []Type) StructType { fields: fields, } + for i := 0; i < len(staticFields); i++ { + staticFields[i].struct_ = spec + } + + for i := 0; i < len(fields); i++ { + fields[i].struct_ = spec + } + s.Specializations = append(s.Specializations, spec) return spec } func specializeFields(s *Struct, types []Type, fields []*Field) []SpecializedField { - specFields := make([]SpecializedField, len(fields)) + specFields := make([]SpecializedField, 0, len(fields)) - for i, field := range fields { - type_ := specialize(s.GenericParams, types, field.Type_) + for _, field := range fields { + if IsNil(field.Type()) { + continue + } + + type_ := specialize(s.GenericParams, types, field.Type()) if type_ == nil { - type_ = field.Type_ + type_ = field.Type() } - specFields[i] = SpecializedField{ + specFields = append(specFields, SpecializedField{ wrapper: wrapper[*Field]{wrapped: field}, type_: type_, - } + }) } return specFields @@ -70,7 +82,7 @@ func (f *Func) Specialize(types []Type) FuncType { return spec } -func specializeFunc(receiver StructType, specializations *[]*SpecializedFunc, f SpecializableFunc, types []Type) (FuncType, bool) { +func specializeFunc(receiver Type, specializations *[]*SpecializedFunc, f SpecializableFunc, types []Type) (FuncType, bool) { // Check cache for _, spec := range *specializations { if slices.EqualFunc(spec.Types, types, typesEquals) { @@ -241,6 +253,14 @@ func shallowCopyStruct(s StructType, types []Type) *SpecializedStruct { fields: fields, } + for i := 0; i < len(staticFields); i++ { + staticFields[i].struct_ = spec + } + + for i := 0; i < len(fields); i++ { + fields[i].struct_ = spec + } + s.Underlying().Specializations = append(s.Underlying().Specializations, spec) return spec } @@ -281,12 +301,15 @@ func shallowCopyFuncType(f *Func, types []Type) *SpecializedFunc { type SpecializedField struct { wrapper[*Field] - type_ Type + struct_ StructType + type_ Type } func (s *SpecializedField) Clone() Node { return &SpecializedField{ wrapper: s.wrapper, + struct_: s.struct_, + type_: s.type_, } } @@ -294,6 +317,10 @@ func (s *SpecializedField) Underlying() *Field { return s.wrapped } +func (s *SpecializedField) Struct() StructType { + return s.struct_ +} + func (s *SpecializedField) Name() *Token { return s.Underlying().Name_ } @@ -493,7 +520,7 @@ func (p *PartiallySpecializedFunc) Underlying() *Func { return p.wrapped } -func (p *PartiallySpecializedFunc) Receiver() StructType { +func (p *PartiallySpecializedFunc) Receiver() Type { return p.receiver } @@ -536,7 +563,7 @@ func (p *PartiallySpecializedFunc) Specialize(types []Type) FuncType { type SpecializedFunc struct { wrapper[*Func] - receiver StructType + receiver Type Types []Type params []SpecializedParam @@ -577,7 +604,7 @@ func (s *SpecializedFunc) Underlying() *Func { return s.wrapped } -func (s *SpecializedFunc) Receiver() StructType { +func (s *SpecializedFunc) Receiver() Type { return s.receiver } @@ -595,7 +622,15 @@ func (s *SpecializedFunc) Returns() Type { func (s *SpecializedFunc) MangledName(name *strings.Builder) { // Base - receiver := s.Receiver() + var receiver StructType + + if s.Receiver() != nil { + if s, ok := s.Receiver().(StructType); ok { + receiver = s + } else { + panic("ast.SpecializedFunc.MangledName() - Receiver is not a StructType, this shouldn't happen") + } + } if receiver == nil { receiver = s.Underlying().Struct() diff --git a/core/ast/interfaces.go b/core/ast/interfaces.go index 0742119..bdd3fda 100644 --- a/core/ast/interfaces.go +++ b/core/ast/interfaces.go @@ -7,6 +7,8 @@ type FieldLike interface { Underlying() *Field + Struct() StructType + Name() *Token Type() Type } @@ -44,7 +46,7 @@ type FuncType interface { Type Underlying() *Func - Receiver() StructType + Receiver() Type ParameterCount() int ParameterIndex(index int) SpecializedParam diff --git a/core/ast/interfaces_impl.go b/core/ast/interfaces_impl.go index 7096b05..eff2e64 100644 --- a/core/ast/interfaces_impl.go +++ b/core/ast/interfaces_impl.go @@ -8,6 +8,10 @@ func (f *Field) Underlying() *Field { return f } +func (f *Field) Struct() StructType { + return f.Parent().(StructType) +} + func (f *Field) Name() *Token { return f.Name_ } @@ -72,11 +76,15 @@ func (f *Func) Underlying() *Func { return f } -func (f *Func) Receiver() StructType { +func (f *Func) Receiver() Type { if impl, ok := f.Parent().(*Impl); ok && !f.IsStatic() { return impl.Type.(*Struct) } + if inter, ok := f.Parent().(*Interface); ok { + return inter + } + return nil } diff --git a/core/ast/types_manual.go b/core/ast/types_manual.go index 6dc97d5..83fe315 100644 --- a/core/ast/types_manual.go +++ b/core/ast/types_manual.go @@ -65,6 +65,10 @@ func (r *Resolvable) Equals(other Type) bool { panic("ast.Resolvable.Equals() - Not resolved") } + if IsNil(other) { + return false + } + return r.Resolved().Equals(other.Resolved()) } diff --git a/core/codegen/codegen.go b/core/codegen/codegen.go index 8afbc44..0ab0a0d 100644 --- a/core/codegen/codegen.go +++ b/core/codegen/codegen.go @@ -85,7 +85,13 @@ func Emit(ctx *Context, path string, root ast.RootResolver, file *ast.File) *ir. if len(s.GenericParams) > 0 { for _, spec := range s.Specializations { for _, method := range spec.Methods { + var sp specializer + sp.prepare(method.Underlying(), s.GenericParams) + + sp.specialize(spec.Types) c.defineOrDeclare(method) + + sp.finish() } } diff --git a/core/codegen/declarations.go b/core/codegen/declarations.go index 3b69764..bd9091d 100644 --- a/core/codegen/declarations.go +++ b/core/codegen/declarations.go @@ -96,26 +96,26 @@ func (c *codegen) genFunc(f ast.FuncType) { c.resolver = ast.NewGenericResolver(c.resolver, f.Underlying().GenericParams) } - // Add this variable - if receiver := f.Receiver(); receiver != nil { - name := scanner.Token{Kind: scanner.Identifier, Lexeme: "this"} - node := cst.Node{Kind: cst.TokenNode, Token: name, Range: decl.Name.Cst().Range} - - c.scopes.addVariable(ast.NewToken(node, name), receiver, function.Typ.Params[0], 1) - } - - // Copy parameters + // Classify funcAbi := abi.GetFuncAbi(decl) returnArgs := funcAbi.Classify(f.Returns(), nil) index := 0 + if len(returnArgs) == 1 && returnArgs[0].Class == abi.Memory { index++ } - if decl.Receiver() != nil { + + // Add this variable + if receiver := f.Receiver(); receiver != nil { + name := scanner.Token{Kind: scanner.Identifier, Lexeme: "this"} + node := cst.Node{Kind: cst.TokenNode, Token: name, Range: decl.Name.Cst().Range} + + c.scopes.addVariable(ast.NewToken(node, name), receiver, function.Typ.Params[index], 1) index++ } + // Copy parameters paramI := index for i := 0; i < f.ParameterCount(); i++ { diff --git a/core/codegen/expressions.go b/core/codegen/expressions.go index dbfbfc0..f8b1c06 100644 --- a/core/codegen/expressions.go +++ b/core/codegen/expressions.go @@ -761,7 +761,7 @@ func (c *codegen) VisitMember(expr *ast.Member) { if node.Underlying().IsStatic() { c.exprResult = c.getStaticVariable(node) } else { - struct_ := expr.Value.Result().Type.(ast.StructType) + struct_ := node.Struct() fields, _ := abi.GetStructLayout(struct_.Underlying()).Fields(abi.GetTargetAbi(), struct_) _, i := getField(fields, node.Name()) diff --git a/core/llvm/instructions.go b/core/llvm/instructions.go index fd4963a..0f080e0 100644 --- a/core/llvm/instructions.go +++ b/core/llvm/instructions.go @@ -386,6 +386,9 @@ func (w *textWriter) writeInstruction(inst ir.Inst) { w.writeString(", ") } + var argTP *ir.PointerType + var prevArgTByVal ir.Type + if i < len(type_.Params) { param := type_.Params[i] @@ -394,9 +397,22 @@ func (w *textWriter) writeInstruction(inst ir.Inst) { w.writeString("metadata ") } } + + if paramT, ok := param.Type().(*ir.PointerType); ok { + if argT, ok := arg.Type().(*ir.PointerType); ok { + argTP = argT + prevArgTByVal = argT.ByVal + + argT.ByVal = paramT.ByVal + } + } } w.writeValue(arg) + + if argTP != nil { + argTP.ByVal = prevArgTByVal + } } w.isArgument = false diff --git a/core/typeresolver/typeresolver.go b/core/typeresolver/typeresolver.go index 51a8767..5d77ede 100644 --- a/core/typeresolver/typeresolver.go +++ b/core/typeresolver/typeresolver.go @@ -109,30 +109,8 @@ func (t *typeResolver) visitType(type_ ast.Type) { } if resolved != nil { - // Visit children - type_.AcceptChildren(t) - - // Specialize struct if needed - if s, ok := ast.As[*ast.Struct](resolved); ok { - if len(resolvable.GenericArgs) != len(s.GenericParams) { - if resolvable.GenericArgs != nil { - errorSlice(t, resolvable.GenericArgs, "Got '%d' generic arguments but struct takes '%d'", len(resolvable.GenericArgs), len(s.GenericParams)) - } else { - t.error(resolvable.Parts[len(resolvable.Parts)-1], "Got '%d' generic arguments but struct takes '%d'", len(resolvable.GenericArgs), len(s.GenericParams)) - } - } - - if len(resolvable.GenericArgs) != 0 { - resolved = s.Specialize(resolvable.GenericArgs) - } - } else if len(resolvable.GenericArgs) != 0 { - errorSlice(t, resolvable.GenericArgs, "This type doesn't have any generic parameters") - } - // Store resolved type resolvable.Type = resolved - - return } else { // Report an error str := strings.Builder{} @@ -145,7 +123,7 @@ func (t *typeResolver) visitType(type_ ast.Type) { str.WriteString(part.String()) } - t.error(resolvable, "Unknown type '%s'", str.String()) + errorNode(t.reporter, resolvable, "Unknown type '%s'", str.String()) resolvable.Type = &ast.Primitive{Kind: ast.Void} if t.expr != nil { @@ -189,23 +167,23 @@ func (t *typeResolver) VisitNode(node ast.Node) { // Utils -func (t *typeResolver) error(node ast.Node, format string, args ...any) { +func errorNode(reporter utils.Reporter, node ast.Node, format string, args ...any) { if ast.IsNil(node) { return } - t.reporter.Report(utils.Diagnostic{ + reporter.Report(utils.Diagnostic{ Kind: utils.ErrorKind, Range: node.Cst().Range, Message: fmt.Sprintf(format, args...), }) } -func errorSlice[T ast.Node](t *typeResolver, nodes []T, format string, args ...any) { +func errorSlice[T ast.Node](reporter utils.Reporter, nodes []T, format string, args ...any) { start := nodes[0].Cst().Range.Start end := nodes[len(nodes)-1].Cst().Range.End - t.reporter.Report(utils.Diagnostic{ + reporter.Report(utils.Diagnostic{ Kind: utils.ErrorKind, Range: core.Range{ Start: start, diff --git a/core/typeresolver/typespecializer.go b/core/typeresolver/typespecializer.go new file mode 100644 index 0000000..08fd9aa --- /dev/null +++ b/core/typeresolver/typespecializer.go @@ -0,0 +1,45 @@ +package typeresolver + +import ( + "fireball/core/ast" + "fireball/core/utils" +) + +type typeSpecializer struct { + reporter utils.Reporter +} + +func Specialize(reporter utils.Reporter, file *ast.File) { + s := typeSpecializer{reporter: reporter} + + s.VisitNode(file) +} + +func (t *typeSpecializer) visitResolvable(resolvable *ast.Resolvable) { + // Specialize struct if needed + if s, ok := ast.As[*ast.Struct](resolvable.Type); ok { + if len(resolvable.GenericArgs) != len(s.GenericParams) { + if resolvable.GenericArgs != nil { + errorSlice(t.reporter, resolvable.GenericArgs, "Got '%d' generic arguments but struct takes '%d'", len(resolvable.GenericArgs), len(s.GenericParams)) + } else { + errorNode(t.reporter, resolvable.Parts[len(resolvable.Parts)-1], "Got '%d' generic arguments but struct takes '%d'", len(resolvable.GenericArgs), len(s.GenericParams)) + } + } + + if len(resolvable.GenericArgs) != 0 { + resolvable.Type = s.Specialize(resolvable.GenericArgs) + } + } else if len(resolvable.GenericArgs) != 0 { + errorSlice(t.reporter, resolvable.GenericArgs, "This type doesn't have any generic parameters") + } +} + +// ast.Visitor + +func (t *typeSpecializer) VisitNode(node ast.Node) { + if resolvable, ok := node.(*ast.Resolvable); ok { + t.visitResolvable(resolvable) + } + + node.AcceptChildren(t) +} diff --git a/core/workspace/file.go b/core/workspace/file.go index 9bf7c69..e6dfc07 100644 --- a/core/workspace/file.go +++ b/core/workspace/file.go @@ -65,6 +65,11 @@ func (f *File) SetText(text string, parse bool) { } } + // Specialize types + for _, file := range f.Project.Files { + typeresolver.Specialize(file, file.Ast) + } + // Check for _, file := range f.Project.Files { if resolver := f.Project.getNamespace(file.Ast); resolver != nil { diff --git a/core/workspace/project.go b/core/workspace/project.go index 7004b47..cbf1068 100644 --- a/core/workspace/project.go +++ b/core/workspace/project.go @@ -213,6 +213,11 @@ func (p *Project) LoadFiles() error { } } + // Specialize types + for _, file := range p.Files { + typeresolver.Specialize(file, file.Ast) + } + // Check for _, file := range p.Files { if resolver := p.getNamespace(file.Ast); resolver != nil { @@ -268,6 +273,10 @@ func (p *Project) RemoveFile(path string) bool { } } + for _, file := range p.Files { + typeresolver.Specialize(file, file.Ast) + } + for _, file := range p.Files { if resolver := p.getNamespace(file.Ast); resolver != nil { checker.Check(file, resolver, file.Ast)