Skip to content

Commit

Permalink
CORE: Uhh fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MineGame159 committed Feb 22, 2024
1 parent d11956c commit 508f2e6
Show file tree
Hide file tree
Showing 14 changed files with 164 additions and 53 deletions.
2 changes: 1 addition & 1 deletion cmd/cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func testCmd(_ *cobra.Command, _ []string) {
_, _ = namespaceStyle.Printf("%s.", part)
}

if receiver := test.Receiver(); receiver != nil {
if receiver, ok := test.Receiver().(ast.StructType); ok {
_, _ = namespaceStyle.Printf("%s.", receiver.Underlying().Name)
}

Expand Down
3 changes: 3 additions & 0 deletions cmd/lsp/implementations.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (

func getImplementations(node ast.Node, pos core.Pos, resolver ast.Resolver) []protocol.Location {
leaf := ast.GetLeaf(node, pos)
if ast.IsNil(leaf) {
return nil
}

builder := implementationBuilder{}

Expand Down
59 changes: 47 additions & 12 deletions core/ast/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,36 @@ func (s *Struct) Specialize(types []Type) StructType {
fields: fields,
}

for i := 0; i < len(staticFields); i++ {
staticFields[i].struct_ = spec
}

for i := 0; i < len(fields); i++ {
fields[i].struct_ = spec
}

s.Specializations = append(s.Specializations, spec)
return spec
}

func specializeFields(s *Struct, types []Type, fields []*Field) []SpecializedField {
specFields := make([]SpecializedField, len(fields))
specFields := make([]SpecializedField, 0, len(fields))

for i, field := range fields {
type_ := specialize(s.GenericParams, types, field.Type_)
for _, field := range fields {
if IsNil(field.Type()) {
continue
}

type_ := specialize(s.GenericParams, types, field.Type())

if type_ == nil {
type_ = field.Type_
type_ = field.Type()
}

specFields[i] = SpecializedField{
specFields = append(specFields, SpecializedField{
wrapper: wrapper[*Field]{wrapped: field},
type_: type_,
}
})
}

return specFields
Expand All @@ -70,7 +82,7 @@ func (f *Func) Specialize(types []Type) FuncType {
return spec
}

func specializeFunc(receiver StructType, specializations *[]*SpecializedFunc, f SpecializableFunc, types []Type) (FuncType, bool) {
func specializeFunc(receiver Type, specializations *[]*SpecializedFunc, f SpecializableFunc, types []Type) (FuncType, bool) {
// Check cache
for _, spec := range *specializations {
if slices.EqualFunc(spec.Types, types, typesEquals) {
Expand Down Expand Up @@ -241,6 +253,14 @@ func shallowCopyStruct(s StructType, types []Type) *SpecializedStruct {
fields: fields,
}

for i := 0; i < len(staticFields); i++ {
staticFields[i].struct_ = spec
}

for i := 0; i < len(fields); i++ {
fields[i].struct_ = spec
}

s.Underlying().Specializations = append(s.Underlying().Specializations, spec)
return spec
}
Expand Down Expand Up @@ -281,19 +301,26 @@ func shallowCopyFuncType(f *Func, types []Type) *SpecializedFunc {
type SpecializedField struct {
wrapper[*Field]

type_ Type
struct_ StructType
type_ Type
}

func (s *SpecializedField) Clone() Node {
return &SpecializedField{
wrapper: s.wrapper,
struct_: s.struct_,
type_: s.type_,
}
}

func (s *SpecializedField) Underlying() *Field {
return s.wrapped
}

func (s *SpecializedField) Struct() StructType {
return s.struct_
}

func (s *SpecializedField) Name() *Token {
return s.Underlying().Name_
}
Expand Down Expand Up @@ -493,7 +520,7 @@ func (p *PartiallySpecializedFunc) Underlying() *Func {
return p.wrapped
}

func (p *PartiallySpecializedFunc) Receiver() StructType {
func (p *PartiallySpecializedFunc) Receiver() Type {
return p.receiver
}

Expand Down Expand Up @@ -536,7 +563,7 @@ func (p *PartiallySpecializedFunc) Specialize(types []Type) FuncType {
type SpecializedFunc struct {
wrapper[*Func]

receiver StructType
receiver Type
Types []Type

params []SpecializedParam
Expand Down Expand Up @@ -577,7 +604,7 @@ func (s *SpecializedFunc) Underlying() *Func {
return s.wrapped
}

func (s *SpecializedFunc) Receiver() StructType {
func (s *SpecializedFunc) Receiver() Type {
return s.receiver
}

Expand All @@ -595,7 +622,15 @@ func (s *SpecializedFunc) Returns() Type {

func (s *SpecializedFunc) MangledName(name *strings.Builder) {
// Base
receiver := s.Receiver()
var receiver StructType

if s.Receiver() != nil {
if s, ok := s.Receiver().(StructType); ok {
receiver = s
} else {
panic("ast.SpecializedFunc.MangledName() - Receiver is not a StructType, this shouldn't happen")
}
}

if receiver == nil {
receiver = s.Underlying().Struct()
Expand Down
4 changes: 3 additions & 1 deletion core/ast/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ type FieldLike interface {

Underlying() *Field

Struct() StructType

Name() *Token
Type() Type
}
Expand Down Expand Up @@ -44,7 +46,7 @@ type FuncType interface {
Type

Underlying() *Func
Receiver() StructType
Receiver() Type

ParameterCount() int
ParameterIndex(index int) SpecializedParam
Expand Down
10 changes: 9 additions & 1 deletion core/ast/interfaces_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ func (f *Field) Underlying() *Field {
return f
}

func (f *Field) Struct() StructType {
return f.Parent().(StructType)
}

func (f *Field) Name() *Token {
return f.Name_
}
Expand Down Expand Up @@ -72,11 +76,15 @@ func (f *Func) Underlying() *Func {
return f
}

func (f *Func) Receiver() StructType {
func (f *Func) Receiver() Type {
if impl, ok := f.Parent().(*Impl); ok && !f.IsStatic() {
return impl.Type.(*Struct)
}

if inter, ok := f.Parent().(*Interface); ok {
return inter
}

return nil
}

Expand Down
4 changes: 4 additions & 0 deletions core/ast/types_manual.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ func (r *Resolvable) Equals(other Type) bool {
panic("ast.Resolvable.Equals() - Not resolved")
}

if IsNil(other) {
return false
}

return r.Resolved().Equals(other.Resolved())
}

Expand Down
6 changes: 6 additions & 0 deletions core/codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ func Emit(ctx *Context, path string, root ast.RootResolver, file *ast.File) *ir.
if len(s.GenericParams) > 0 {
for _, spec := range s.Specializations {
for _, method := range spec.Methods {
var sp specializer
sp.prepare(method.Underlying(), s.GenericParams)

sp.specialize(spec.Types)
c.defineOrDeclare(method)

sp.finish()
}
}

Expand Down
20 changes: 10 additions & 10 deletions core/codegen/declarations.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,26 +96,26 @@ func (c *codegen) genFunc(f ast.FuncType) {
c.resolver = ast.NewGenericResolver(c.resolver, f.Underlying().GenericParams)
}

// Add this variable
if receiver := f.Receiver(); receiver != nil {
name := scanner.Token{Kind: scanner.Identifier, Lexeme: "this"}
node := cst.Node{Kind: cst.TokenNode, Token: name, Range: decl.Name.Cst().Range}

c.scopes.addVariable(ast.NewToken(node, name), receiver, function.Typ.Params[0], 1)
}

// Copy parameters
// Classify
funcAbi := abi.GetFuncAbi(decl)
returnArgs := funcAbi.Classify(f.Returns(), nil)

index := 0

if len(returnArgs) == 1 && returnArgs[0].Class == abi.Memory {
index++
}
if decl.Receiver() != nil {

// Add this variable
if receiver := f.Receiver(); receiver != nil {
name := scanner.Token{Kind: scanner.Identifier, Lexeme: "this"}
node := cst.Node{Kind: cst.TokenNode, Token: name, Range: decl.Name.Cst().Range}

c.scopes.addVariable(ast.NewToken(node, name), receiver, function.Typ.Params[index], 1)
index++
}

// Copy parameters
paramI := index

for i := 0; i < f.ParameterCount(); i++ {
Expand Down
2 changes: 1 addition & 1 deletion core/codegen/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ func (c *codegen) VisitMember(expr *ast.Member) {
if node.Underlying().IsStatic() {
c.exprResult = c.getStaticVariable(node)
} else {
struct_ := expr.Value.Result().Type.(ast.StructType)
struct_ := node.Struct()
fields, _ := abi.GetStructLayout(struct_.Underlying()).Fields(abi.GetTargetAbi(), struct_)
_, i := getField(fields, node.Name())

Expand Down
16 changes: 16 additions & 0 deletions core/llvm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ func (w *textWriter) writeInstruction(inst ir.Inst) {
w.writeString(", ")
}

var argTP *ir.PointerType
var prevArgTByVal ir.Type

if i < len(type_.Params) {
param := type_.Params[i]

Expand All @@ -394,9 +397,22 @@ func (w *textWriter) writeInstruction(inst ir.Inst) {
w.writeString("metadata ")
}
}

if paramT, ok := param.Type().(*ir.PointerType); ok {
if argT, ok := arg.Type().(*ir.PointerType); ok {
argTP = argT
prevArgTByVal = argT.ByVal

argT.ByVal = paramT.ByVal
}
}
}

w.writeValue(arg)

if argTP != nil {
argTP.ByVal = prevArgTByVal
}
}

w.isArgument = false
Expand Down
32 changes: 5 additions & 27 deletions core/typeresolver/typeresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,30 +109,8 @@ func (t *typeResolver) visitType(type_ ast.Type) {
}

if resolved != nil {
// Visit children
type_.AcceptChildren(t)

// Specialize struct if needed
if s, ok := ast.As[*ast.Struct](resolved); ok {
if len(resolvable.GenericArgs) != len(s.GenericParams) {
if resolvable.GenericArgs != nil {
errorSlice(t, resolvable.GenericArgs, "Got '%d' generic arguments but struct takes '%d'", len(resolvable.GenericArgs), len(s.GenericParams))
} else {
t.error(resolvable.Parts[len(resolvable.Parts)-1], "Got '%d' generic arguments but struct takes '%d'", len(resolvable.GenericArgs), len(s.GenericParams))
}
}

if len(resolvable.GenericArgs) != 0 {
resolved = s.Specialize(resolvable.GenericArgs)
}
} else if len(resolvable.GenericArgs) != 0 {
errorSlice(t, resolvable.GenericArgs, "This type doesn't have any generic parameters")
}

// Store resolved type
resolvable.Type = resolved

return
} else {
// Report an error
str := strings.Builder{}
Expand All @@ -145,7 +123,7 @@ func (t *typeResolver) visitType(type_ ast.Type) {
str.WriteString(part.String())
}

t.error(resolvable, "Unknown type '%s'", str.String())
errorNode(t.reporter, resolvable, "Unknown type '%s'", str.String())
resolvable.Type = &ast.Primitive{Kind: ast.Void}

if t.expr != nil {
Expand Down Expand Up @@ -189,23 +167,23 @@ func (t *typeResolver) VisitNode(node ast.Node) {

// Utils

func (t *typeResolver) error(node ast.Node, format string, args ...any) {
func errorNode(reporter utils.Reporter, node ast.Node, format string, args ...any) {
if ast.IsNil(node) {
return
}

t.reporter.Report(utils.Diagnostic{
reporter.Report(utils.Diagnostic{
Kind: utils.ErrorKind,
Range: node.Cst().Range,
Message: fmt.Sprintf(format, args...),
})
}

func errorSlice[T ast.Node](t *typeResolver, nodes []T, format string, args ...any) {
func errorSlice[T ast.Node](reporter utils.Reporter, nodes []T, format string, args ...any) {
start := nodes[0].Cst().Range.Start
end := nodes[len(nodes)-1].Cst().Range.End

t.reporter.Report(utils.Diagnostic{
reporter.Report(utils.Diagnostic{
Kind: utils.ErrorKind,
Range: core.Range{
Start: start,
Expand Down
Loading

0 comments on commit 508f2e6

Please sign in to comment.