From e774f4371ad57a1f91c5300e5311cfa6ca01aa15 Mon Sep 17 00:00:00 2001 From: MineGame159 Date: Wed, 24 Jan 2024 17:42:39 +0100 Subject: [PATCH] CORE: Rewrite handling of casts CORE: Add int promotion (implicit casts) --- .gitignore | 2 +- core/ast/casts.go | 105 ++++++++++++++++ core/ast/cst2ast/types.go | 7 +- core/ast/types.go | 1 - core/ast/types_manual.go | 74 ++--------- core/checker/checker.go | 10 ++ core/checker/expressions.go | 230 ++++++++++++++--------------------- core/checker/statements.go | 28 ++--- core/codegen/codegen.go | 94 +++++++++++++- core/codegen/declarations.go | 2 + core/codegen/expressions.go | 207 ++++++++----------------------- core/codegen/instructions.go | 4 + core/codegen/statements.go | 17 ++- gen/main.go | 1 - tests/src/casts.fb | 72 +++++++++++ tests/src/implicit_casts.fb | 36 ++++++ 16 files changed, 502 insertions(+), 388 deletions(-) create mode 100644 core/ast/casts.go create mode 100644 tests/src/casts.fb create mode 100644 tests/src/implicit_casts.fb diff --git a/.gitignore b/.gitignore index 3866115..e396430 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ example/build tests/build -fireball +/fireball diff --git a/core/ast/casts.go b/core/ast/casts.go new file mode 100644 index 0000000..b2f8765 --- /dev/null +++ b/core/ast/casts.go @@ -0,0 +1,105 @@ +package ast + +type CastKind uint8 + +const ( + None CastKind = iota + + Truncate + Extend + + Int2Float + Float2Int +) + +func GetCast(from, to Type) (CastKind, bool) { + if from.Equals(to) { + return None, true + } + + switch from := from.Resolved().(type) { + // Primitive -> ... + case *Primitive: + switch to := to.Resolved().(type) { + // Primitive -> Primitive + case *Primitive: + if (IsInteger(from.Kind) && IsInteger(to.Kind)) || (IsFloating(from.Kind) && IsFloating(to.Kind)) { + if to.Size() > from.Size() { + return Extend, true + } else if to.Size() < from.Size() { + return Truncate, true + } else { + return None, true + } + } else if IsInteger(from.Kind) && IsFloating(to.Kind) { + return Int2Float, true + } else if IsFloating(from.Kind) && IsInteger(to.Kind) { + return Float2Int, true + } + + // Primitive (Integer) -> Enum + case *Enum: + if IsInteger(from.Kind) { + if to.Size() > from.Size() { + return Extend, true + } else if to.Size() < from.Size() { + return Truncate, true + } else { + return None, true + } + } + } + + // Pointer -> Pointer, Func + case *Pointer: + switch to.Resolved().(type) { + case *Pointer, *Func: + return None, true + } + + // Enum -> Primitive (Integer) + case *Enum: + if to, ok := As[*Primitive](to); ok && IsInteger(to.Kind) { + if to.Size() > from.Size() { + return Extend, true + } else if to.Size() < from.Size() { + return Truncate, true + } else { + return None, true + } + } + } + + return None, false +} + +func GetImplicitCast(from, to Type) (CastKind, bool) { + if from.Equals(to) { + return None, true + } + + switch from := from.Resolved().(type) { + // Primitive -> Primitive + case *Primitive: + if to, ok := As[*Primitive](to); ok { + // Primitive (smaller integer / floating) -> Primitive (bigger integer / floating) + if ((IsInteger(from.Kind) && IsInteger(to.Kind)) || (IsFloating(from.Kind) && IsFloating(to.Kind))) && to.Size() > from.Size() { + return Extend, true + } + + // Primitive (same integer) -> Primitive (same floating) + // TODO: Allow converting a smaller integer to the next bigger floating (eg. i16 -> f32) + if IsInteger(from.Kind) && IsFloating(to.Kind) && to.Size() == from.Size() { + return Int2Float, true + } + } + + // Pointer -> Pointer (*void) + case *Pointer: + if to, ok := As[*Pointer](to); ok && IsPrimitive(to.Pointee, Void) { + return None, true + } + } + + return None, false +} diff --git a/core/ast/cst2ast/types.go b/core/ast/cst2ast/types.go index ae4ad38..f07acb9 100644 --- a/core/ast/cst2ast/types.go +++ b/core/ast/cst2ast/types.go @@ -3,6 +3,7 @@ package cst2ast import ( "fireball/core/ast" "fireball/core/cst" + "fireball/core/scanner" "strconv" ) @@ -23,10 +24,12 @@ func (c *converter) convertType(node cst.Node) ast.Type { } func (c *converter) convertIdentifierType(node cst.Node) ast.Type { - if len(node.Children) == 1 { + identifier := node.Get(scanner.Identifier) + + if identifier != nil { kind := ast.Unknown - switch node.Children[0].Token.Lexeme { + switch identifier.Token.Lexeme { case "void": kind = ast.Void case "bool": diff --git a/core/ast/types.go b/core/ast/types.go index 47ddacf..e32263f 100644 --- a/core/ast/types.go +++ b/core/ast/types.go @@ -24,7 +24,6 @@ type Type interface { Align() uint32 Equals(other Type) bool - CanAssignTo(other Type) bool Resolved() Type diff --git a/core/ast/types_manual.go b/core/ast/types_manual.go index 0baa9b3..b8875d3 100644 --- a/core/ast/types_manual.go +++ b/core/ast/types_manual.go @@ -54,10 +54,6 @@ func (p *Primitive) Equals(other Type) bool { return IsPrimitive(other, p.Kind) } -func (p *Primitive) CanAssignTo(other Type) bool { - return IsPrimitive(other, p.Kind) -} - // Pointer func (p *Pointer) Size() uint32 { @@ -76,14 +72,6 @@ func (p *Pointer) Equals(other Type) bool { return false } -func (p *Pointer) CanAssignTo(other Type) bool { - if p2, ok := As[*Pointer](other); ok { - return IsPrimitive(p2.Pointee, Void) || typesEquals(p.Pointee, p2.Pointee) - } - - return false -} - // Array func (a *Array) Size() uint32 { @@ -102,10 +90,6 @@ func (a *Array) Equals(other Type) bool { return false } -func (a *Array) CanAssignTo(other Type) bool { - return a.Equals(other) -} - // Resolvable func (r *Resolvable) Size() uint32 { @@ -132,14 +116,6 @@ func (r *Resolvable) Equals(other Type) bool { return r.Resolved().Equals(other.Resolved()) } -func (r *Resolvable) CanAssignTo(other Type) bool { - if r.Resolved() == nil { - panic("ast.Resolvable.Equals() - Not resolved") - } - - return r.Resolved().CanAssignTo(other.Resolved()) -} - // Struct func (s *Struct) Size() uint32 { @@ -163,23 +139,7 @@ func (s *Struct) Align() uint32 { } func (s *Struct) Equals(other Type) bool { - if s2, ok := As[*Struct](other); ok { - return tokensEquals(s.Name, s2.Name) && slices.EqualFunc(s.Fields, s2.Fields, fieldEquals) - } - - return false -} - -func fieldEquals(v1, v2 *Field) bool { - return tokensEquals(v1.Name, v2.Name) && typesEquals(v1.Type, v2.Type) -} - -func (s *Struct) CanAssignTo(other Type) bool { - if s2, ok := As[*Struct](other); ok { - return slices.EqualFunc(s.Fields, s2.Fields, fieldEquals) - } - - return false + return s == other.Resolved() } func (s *Struct) Resolved() Type { @@ -201,19 +161,7 @@ func (e *Enum) Align() uint32 { } func (e *Enum) Equals(other Type) bool { - if e2, ok := As[*Enum](other); ok { - return typesEquals(e.ActualType, e2.ActualType) && slices.EqualFunc(e.Cases, e2.Cases, enumCaseEquals) - } - - return false -} - -func enumCaseEquals(v1, v2 *EnumCase) bool { - return tokensEquals(v1.Name, v2.Name) && v1.ActualValue == v2.ActualValue -} - -func (e *Enum) CanAssignTo(other Type) bool { - return e.Equals(other) + return e == other.Resolved() } func (e *Enum) Resolved() Type { @@ -236,7 +184,11 @@ func (f *Func) Align() uint32 { func (f *Func) Equals(other Type) bool { if f2, ok := As[*Func](other); ok { - return tokensEquals(f.Name, f2.Name) && typesEquals(f.Returns, f2.Returns) && slices.EqualFunc(f.Params, f2.Params, paramEquals) + if f.Name != nil && f2.Name != nil { + return f == f2 + } + + return typesEquals(f.Returns, f2.Returns) && slices.EqualFunc(f.Params, f2.Params, paramEquals) } return false @@ -246,14 +198,6 @@ func paramEquals(v1, v2 *Param) bool { return typesEquals(v1.Type, v2.Type) } -func (f *Func) CanAssignTo(other Type) bool { - if f2, ok := As[*Func](other); ok { - return typesEquals(f.Returns, f2.Returns) && slices.EqualFunc(f.Params, f2.Params, paramEquals) - } - - return false -} - func (f *Func) Resolved() Type { return f } @@ -332,10 +276,6 @@ func (t *typePrinter) VisitNode(node Node) { // Utils -func tokensEquals(t1, t2 *Token) bool { - return (t1 == nil && t2 == nil) || (t1 != nil && t2 != nil && t1.String() == t2.String()) -} - func typesEquals(t1, t2 Type) bool { return (t1 == nil && t2 == nil) || (t1 != nil && t2 != nil && t1.Equals(t2)) } diff --git a/core/checker/checker.go b/core/checker/checker.go index 7ab5b81..d9fc908 100644 --- a/core/checker/checker.go +++ b/core/checker/checker.go @@ -146,6 +146,16 @@ func (c *checker) expectPrimitiveValue(expr ast.Expr, kind ast.PrimitiveKind) { } } +func (c *checker) checkRequired(required ast.Type, expr ast.Expr) { + if required == nil || expr == nil || expr.Result().Kind == ast.InvalidResultKind { + return + } + + if _, ok := ast.GetImplicitCast(expr.Result().Type, required); !ok { + c.error(expr, "Expected a '%s' but got a '%s'", ast.PrintType(required), ast.PrintType(expr.Result().Type)) + } +} + // ast.Visit func (c *checker) VisitNode(node ast.Node) { diff --git a/core/checker/expressions.go b/core/checker/expressions.go index 7fdd2c2..733de2b 100644 --- a/core/checker/expressions.go +++ b/core/checker/expressions.go @@ -155,9 +155,7 @@ func (c *checker) VisitStructInitializer(expr *ast.StructInitializer) { continue } - if !initField.Value.Result().Type.CanAssignTo(field.Type) { - c.error(initField.Value, "Expected a '%s' but got '%s'", ast.PrintType(field.Type), ast.PrintType(initField.Value.Result().Type)) - } + c.checkRequired(field.Type, initField.Value) } } @@ -198,10 +196,7 @@ func (c *checker) VisitArrayInitializer(expr *ast.ArrayInitializer) { if type_ == nil { type_ = value.Result().Type } else { - if !value.Result().Type.CanAssignTo(type_) { - c.error(value, "Expected a '%s' but got '%s'", ast.PrintType(type_), ast.PrintType(value.Result().Type)) - ok = false - } + c.checkRequired(type_, value) } } @@ -383,6 +378,8 @@ func (c *checker) VisitBinary(expr *ast.Binary) { return } + expr.Result().SetInvalid() + if expr.Left.Result().Kind == ast.InvalidResultKind || expr.Right.Result().Kind == ast.InvalidResultKind { return // // Do not cascade errors } @@ -401,95 +398,21 @@ func (c *checker) VisitBinary(expr *ast.Binary) { } if !ok { - expr.Result().SetInvalid() return } - leftType := expr.Left.Result().Type - rightType := expr.Right.Result().Type - - // Check based on the operator - if scanner.IsArithmetic(expr.Operator.Token().Kind) { - // Arithmetic - if left, ok := ast.As[*ast.Primitive](leftType); ok { - if right, ok := ast.As[*ast.Primitive](rightType); ok { - if ast.IsNumber(left.Kind) && ast.IsNumber(right.Kind) && left.Equals(right) { - expr.Result().SetValue(leftType, 0, nil) - return - } - } - } - - c.error(expr, "Expected two equal number types") - expr.Result().SetInvalid() - } else if scanner.IsEquality(expr.Operator.Token().Kind) { - // Equality - valid := false - - if leftType.Equals(rightType) { - // left type == right type - valid = true - } else if left, ok := ast.As[*ast.Primitive](leftType); ok { - // integer == integer || floating == floating - if right, ok := ast.As[*ast.Primitive](rightType); ok { - if (ast.IsInteger(left.Kind) && ast.IsInteger(right.Kind)) || (ast.IsFloating(left.Kind) && ast.IsFloating(right.Kind)) { - valid = true - } - } - } else if left, ok := ast.As[*ast.Pointer](leftType); ok { - if right, ok := ast.As[*ast.Pointer](rightType); ok { - // *void == *? || *? == *void - if ast.IsPrimitive(left.Pointee, ast.Void) || ast.IsPrimitive(right.Pointee, ast.Void) { - valid = true - } - } - } - - if !valid { - c.error(expr, "Cannot check equality for '%s' and '%s'", ast.PrintType(leftType), ast.PrintType(rightType)) - expr.Result().SetInvalid() - } else { - expr.Result().SetValue(&ast.Primitive{Kind: ast.Bool}, 0, nil) - } - } else if scanner.IsComparison(expr.Operator.Token().Kind) { - // Comparison - if left, ok := ast.As[*ast.Primitive](leftType); ok { - if right, ok := ast.As[*ast.Primitive](rightType); ok { - if !ast.IsNumber(left.Kind) || !ast.IsNumber(right.Kind) || !left.Equals(right) { - c.error(expr, "Expected two equal number types") - expr.Result().SetInvalid() - - return - } - } - } - - expr.Result().SetValue(&ast.Primitive{Kind: ast.Bool}, 0, nil) - } else if scanner.IsBitwise(expr.Operator.Token().Kind) { - // Bitwise - if left, ok := ast.As[*ast.Primitive](leftType); ok { - if right, ok := ast.As[*ast.Primitive](rightType); ok { - if left.Equals(right) && ast.IsInteger(left.Kind) { - expr.Result().SetValue(leftType, 0, nil) - return - } - } - } - - c.error(expr, "Expected two equal integer types") - expr.Result().SetInvalid() - } else { - // Error - panic("checker.VisitBinary() - Invalid operator kind") - } + // Check + c.checkBinary(expr, expr.Left, expr.Right, expr.Operator, false) } func (c *checker) VisitLogical(expr *ast.Logical) { expr.AcceptChildren(c) // Check expressions - c.expectPrimitiveValue(expr.Left, ast.Bool) - c.expectPrimitiveValue(expr.Right, ast.Bool) + type_ := ast.Primitive{Kind: ast.Bool} + + c.checkRequired(&type_, expr.Left) + c.checkRequired(&type_, expr.Right) // Set type expr.Result().SetValue(&ast.Primitive{Kind: ast.Bool}, 0, nil) @@ -589,42 +512,13 @@ func (c *checker) VisitAssignment(expr *ast.Assignment) { // Check type if expr.Operator.Token().Kind == scanner.Equal { // Equal - if !expr.Value.Result().Type.CanAssignTo(expr.Assignee.Result().Type) { - c.error(expr.Value, "Expected a '%s' but got '%s'", ast.PrintType(expr.Assignee.Result().Type), ast.PrintType(expr.Value.Result().Type)) - } + c.checkRequired(expr.Assignee.Result().Type, expr.Value) } else { - if scanner.IsArithmetic(expr.Operator.Token().Kind) { - // Arithmetic - valid := false - - if assignee, ok := ast.As[*ast.Primitive](expr.Assignee.Result().Type); ok { - if value, ok := ast.As[*ast.Primitive](expr.Value.Result().Type); ok { - if ast.IsNumber(assignee.Kind) && ast.IsNumber(value.Kind) && assignee.Equals(value) { - valid = true - } - } - } - - if !valid { - c.error(expr.Value, "Expected two equal number types") - } - } else if scanner.IsBitwise(expr.Operator.Token().Kind) { - // Bitwise - valid := false - - if left, ok := ast.As[*ast.Primitive](expr.Assignee.Result().Type); ok { - if right, ok := ast.As[*ast.Primitive](expr.Value.Result().Type); ok { - if left.Equals(right) && ast.IsInteger(left.Kind) { - valid = true - } - } - } + // Binary + c.checkBinary(expr, expr.Assignee, expr.Value, expr.Operator, true) - if !valid { - c.error(expr.Value, "Expected two equal integer types") - } - } else { - panic("checker.VisitAssignment() - Invalid operator") + if expr.Result().Kind == ast.InvalidResultKind { + panic("checker.VisitAssignment() - Not implemented") } } } @@ -651,19 +545,8 @@ func (c *checker) VisitCast(expr *ast.Cast) { expr.Result().SetValue(expr.Target, 0, nil) // Check type - if ast.IsPrimitive(expr.Value.Result().Type, ast.Void) || ast.IsPrimitive(expr.Target, ast.Void) { - // void - c.error(expr, "Cannot cast to or from type 'void'") - } else if _, ok := ast.As[*ast.Enum](expr.Value.Result().Type); ok { - // enum to non integer - if to, ok := ast.As[*ast.Primitive](expr.Target); !ok || !ast.IsInteger(to.Kind) { - c.error(expr, "Can only cast enums to integers, not '%s'", ast.PrintType(to)) - } - } else if _, ok := ast.As[*ast.Enum](expr.Target); ok { - // non integer to enum - if from, ok := ast.As[*ast.Primitive](expr.Value.Result().Type); !ok || !ast.IsInteger(from.Kind) { - c.error(expr, "Can only cast to enums from integers, not '%s'", ast.PrintType(from)) - } + if _, ok := ast.GetCast(expr.Value.Result().Type, expr.Target); !ok { + c.error(expr, "Cannot cast type '%s' to type '%s'", ast.PrintType(expr.Value.Result().Type), ast.PrintType(expr.Target)) } } @@ -721,9 +604,7 @@ func (c *checker) VisitCall(expr *ast.Call) { continue } - if !arg.Result().Type.CanAssignTo(param.Type) { - c.error(arg, "Argument with type '%s' cannot be assigned to a parameter with type '%s'", ast.PrintType(arg.Result().Type), ast.PrintType(param.Type)) - } + c.checkRequired(param.Type, arg) } } @@ -961,6 +842,81 @@ func parentWantsFunction(expr ast.Expr) bool { } } +func (c *checker) checkBinary(expr, left, right ast.Expr, operator *ast.Token, assignment bool) { + // Implicitly cast between left and right types + leftType := left.Result().Type + rightType := right.Result().Type + + castType := leftType + _, castOk := ast.GetImplicitCast(rightType, leftType) + + if !castOk { + if assignment { + c.error(expr, "Expected a '%s' but got a '%s'", ast.PrintType(leftType), ast.PrintType(rightType)) + return + } + + castType = rightType + _, castOk = ast.GetImplicitCast(leftType, rightType) + } + + // Arithmetic + if scanner.IsArithmetic(operator.Token().Kind) { + if left, ok := ast.As[*ast.Primitive](leftType); ok { + if right, ok := ast.As[*ast.Primitive](rightType); ok { + if ast.IsNumber(left.Kind) && ast.IsNumber(right.Kind) && castOk { + expr.Result().SetValue(castType, 0, nil) + return + } + } + } + + c.error(expr, "Operator '%s' cannot be applied to '%s' and '%s'", operator.String(), ast.PrintType(leftType), ast.PrintType(rightType)) + return + } + + // Equality + if !assignment && scanner.IsEquality(operator.Token().Kind) { + if castOk { + expr.Result().SetValue(&ast.Primitive{Kind: ast.Bool}, 0, nil) + return + } + + c.error(expr, "Operator '%s' cannot be applied to '%s' and '%s'", operator.String(), ast.PrintType(leftType), ast.PrintType(rightType)) + return + } + + // Comparison + if scanner.IsComparison(operator.Token().Kind) { + if left, ok := ast.As[*ast.Primitive](leftType); ok { + if right, ok := ast.As[*ast.Primitive](rightType); ok { + if ast.IsNumber(left.Kind) && ast.IsNumber(right.Kind) && castOk { + expr.Result().SetValue(&ast.Primitive{Kind: ast.Bool}, 0, nil) + return + } + } + } + + c.error(expr, "Operator '%s' cannot be applied to '%s' and '%s'", operator.String(), ast.PrintType(leftType), ast.PrintType(rightType)) + return + } + + // Bitwise + if !assignment && scanner.IsBitwise(operator.Token().Kind) { + if left, ok := ast.As[*ast.Primitive](leftType); ok { + if right, ok := ast.As[*ast.Primitive](rightType); ok { + if ast.IsInteger(left.Kind) && ast.IsInteger(right.Kind) && castOk { + expr.Result().SetValue(castType, 0, nil) + return + } + } + } + + c.error(expr, "Operator '%s' cannot be applied to '%s' and '%s'", operator.String(), ast.PrintType(leftType), ast.PrintType(rightType)) + return + } +} + func (c *checker) checkMalloc(expr ast.Expr) { function := c.resolver.GetFunction("malloc") diff --git a/core/checker/statements.go b/core/checker/statements.go index dd92f7e..bbd9d88 100644 --- a/core/checker/statements.go +++ b/core/checker/statements.go @@ -44,9 +44,7 @@ func (c *checker) VisitVar(stmt *ast.Var) { stmt.ActualType = stmt.Value.Result().Type } } else { - if stmt.Value != nil && !stmt.Value.Result().Type.CanAssignTo(stmt.ActualType) { - c.error(stmt.Value, "Initializer with type '%s' cannot be assigned to a variable with type '%s'", ast.PrintType(stmt.Value.Result().Type), ast.PrintType(stmt.ActualType)) - } + c.checkRequired(stmt.ActualType, stmt.Value) } } } @@ -71,7 +69,8 @@ func (c *checker) VisitVar(stmt *ast.Var) { func (c *checker) VisitIf(stmt *ast.If) { stmt.AcceptChildren(c) - c.expectPrimitiveValue(stmt.Condition, ast.Bool) + required := ast.Primitive{Kind: ast.Bool} + c.checkRequired(&required, stmt.Condition) } func (c *checker) VisitWhile(stmt *ast.While) { @@ -80,7 +79,8 @@ func (c *checker) VisitWhile(stmt *ast.While) { c.loopDepth-- // Check condition value - c.expectPrimitiveValue(stmt.Condition, ast.Bool) + required := ast.Primitive{Kind: ast.Bool} + c.checkRequired(&required, stmt.Condition) } func (c *checker) VisitFor(stmt *ast.For) { @@ -94,31 +94,27 @@ func (c *checker) VisitFor(stmt *ast.For) { c.popScope() // Check condition value - c.expectPrimitiveValue(stmt.Condition, ast.Bool) + required := ast.Primitive{Kind: ast.Bool} + c.checkRequired(&required, stmt.Condition) } func (c *checker) VisitReturn(stmt *ast.Return) { stmt.AcceptChildren(c) // Check return value - var type_ ast.Type - var errorNode ast.Node - if stmt.Value != nil { if stmt.Value.Result().Kind != ast.ValueResultKind { c.error(stmt.Value, "Invalid value") return } - type_ = stmt.Value.Result().Type - errorNode = stmt.Value + c.checkRequired(c.function.Returns, stmt.Value) } else { - type_ = &ast.Primitive{Kind: ast.Void} - errorNode = stmt - } + type_ := ast.Primitive{Kind: ast.Void} - if !type_.CanAssignTo(c.function.Returns) { - c.error(errorNode, "Cannot return type '%s' from a function with return type '%s'", ast.PrintType(type_), ast.PrintType(c.function.Returns)) + if !c.function.Returns.Equals(&type_) { + c.error(stmt, "Expected a '%s' but got a 'void'", ast.PrintType(c.function.Returns)) + } } } diff --git a/core/codegen/codegen.go b/core/codegen/codegen.go index 449dfc5..4256094 100644 --- a/core/codegen/codegen.go +++ b/core/codegen/codegen.go @@ -18,8 +18,9 @@ type codegen struct { allocas map[ast.Node]exprValue - function *ir.Func - block *ir.Block + astFunction *ast.Func + function *ir.Func + block *ir.Block loopStart *ir.Block loopEnd *ir.Block @@ -186,6 +187,95 @@ func (c *codegen) loadExpr(expr ast.Expr) exprValue { return c.load(c.acceptExpr(expr), expr.Result().Type) } +func (c *codegen) implicitCast(required ast.Type, value exprValue, valueType ast.Type) exprValue { + if kind, ok := ast.GetImplicitCast(valueType, required); ok && kind != ast.None { + return c.cast(value, valueType, required, nil) + } + + return value +} + +func (c *codegen) implicitCastLoadExpr(required ast.Type, expr ast.Expr) exprValue { + return c.implicitCast(required, c.loadExpr(expr), expr.Result().Type) +} + +func (c *codegen) cast(value exprValue, from, to ast.Type, location ast.Node) exprValue { + kind, ok := ast.GetCast(from, to) + if !ok { + panic("codegen.convertAstCastKind() - ast.GetCast() returned false") + } + + return c.convertCast(value, kind, from, to, location) +} + +func (c *codegen) convertCast(value exprValue, kind ast.CastKind, from, to ast.Type, location ast.Node) exprValue { + if kind == ast.None { + return value + } + + value = c.load(value, from) + toIr := c.types.get(to) + + switch kind { + case ast.Truncate: + result := c.block.Add(&ir.TruncInst{ + Value: value.v, + Typ: toIr, + }) + + c.setLocationMeta(result, location) + return exprValue{v: result} + + case ast.Extend: + var result ir.MetaValue + + if ast.IsFloating(to.Resolved().(*ast.Primitive).Kind) { + result = c.block.Add(&ir.FExtInst{ + Value: value.v, + Typ: toIr, + }) + } else { + signed := ast.IsSigned(to.Resolved().(*ast.Primitive).Kind) + + if from, ok := ast.As[*ast.Primitive](from); !ok || !ast.IsSigned(from.Kind) { + signed = false + } + + result = c.block.Add(&ir.ExtInst{ + SignExtend: signed, + Value: value.v, + Typ: toIr, + }) + } + + c.setLocationMeta(result, location) + return exprValue{v: result} + + case ast.Int2Float: + result := c.block.Add(&ir.I2FInst{ + Signed: ast.IsSigned(from.Resolved().(*ast.Primitive).Kind), + Value: value.v, + Typ: toIr, + }) + + c.setLocationMeta(result, location) + return exprValue{v: result} + + case ast.Float2Int: + result := c.block.Add(&ir.F2IInst{ + Signed: ast.IsSigned(to.Resolved().(*ast.Primitive).Kind), + Value: value.v, + Typ: toIr, + }) + + c.setLocationMeta(result, location) + return exprValue{v: result} + + default: + panic("codegen.convertAstCastKind() - Not implemented") + } +} + // Static / Global variables func (c *codegen) getStaticVariable(field *ast.Field) exprValue { diff --git a/core/codegen/declarations.go b/core/codegen/declarations.go index e74ce87..ab36504 100644 --- a/core/codegen/declarations.go +++ b/core/codegen/declarations.go @@ -29,6 +29,7 @@ func (c *codegen) VisitFunc(decl *ast.Func) { function := c.functions[decl] // Setup state + c.astFunction = decl c.function = function c.beginBlock(function.Block("entry")) @@ -77,6 +78,7 @@ func (c *codegen) VisitFunc(decl *ast.Func) { c.block = nil c.function = nil + c.astFunction = nil } func (c *codegen) VisitGlobalVar(_ *ast.GlobalVar) {} diff --git a/core/codegen/expressions.go b/core/codegen/expressions.go index b672e13..12cc2da 100644 --- a/core/codegen/expressions.go +++ b/core/codegen/expressions.go @@ -156,9 +156,11 @@ func (c *codegen) VisitStructInitializer(expr *ast.StructInitializer) { var result ir.Value = &ir.ZeroInitConst{Typ: type_} - for _, field := range expr.Fields { - element := c.loadExpr(field.Value) - i, _ := struct_.GetField(field.Name.String()) + for _, initField := range expr.Fields { + _, field := struct_.GetField(initField.Name.String()) + + element := c.implicitCastLoadExpr(field.Type, initField.Value) + i, _ := struct_.GetField(initField.Name.String()) r := c.block.Add(&ir.InsertValueInst{ Value: result, @@ -166,7 +168,7 @@ func (c *codegen) VisitStructInitializer(expr *ast.StructInitializer) { Indices: []uint32{uint32(i)}, }) - c.setLocationMeta(r, field) + c.setLocationMeta(r, initField) result = r } @@ -196,12 +198,13 @@ func (c *codegen) VisitStructInitializer(expr *ast.StructInitializer) { } func (c *codegen) VisitArrayInitializer(expr *ast.ArrayInitializer) { + baseType := expr.Result().Type.(*ast.Array).Base type_ := c.types.get(expr.Result().Type) var result ir.Value = &ir.ZeroInitConst{Typ: type_} for i, value := range expr.Values { - element := c.loadExpr(value) + element := c.implicitCastLoadExpr(baseType, value) r := c.block.Add(&ir.InsertValueInst{ Value: result, @@ -218,23 +221,11 @@ func (c *codegen) VisitArrayInitializer(expr *ast.ArrayInitializer) { } func (c *codegen) VisitAllocateArray(expr *ast.AllocateArray) { - count := c.loadExpr(expr.Count) - mallocFunc := c.resolver.GetFunction("malloc") malloc := c.getFunction(mallocFunc) - a, _ := ast.As[*ast.Primitive](expr.Count.Result().Type) - b, _ := ast.As[*ast.Primitive](mallocFunc.Params[0].Type) - - c.castPrimitiveToPrimitive( - count, - expr.Count.Result().Type, - mallocFunc.Params[0].Type, - a.Kind, - b.Kind, - expr, - ) - count = c.exprResult + count := c.loadExpr(expr.Count) + count = c.cast(count, expr.Count.Result().Type, mallocFunc.Params[0].Type, expr) pointer := c.block.Add(&ir.CallInst{ Callee: malloc.v, @@ -379,15 +370,14 @@ func (c *codegen) VisitUnary(expr *ast.Unary) { } func (c *codegen) VisitBinary(expr *ast.Binary) { - left := c.acceptExpr(expr.Left) - right := c.acceptExpr(expr.Right) - - c.exprResult = c.binary(expr.Operator, left, right, expr.Left.Result().Type) + c.exprResult = c.binaryLoad(expr.Left, expr.Right, expr.Operator) } func (c *codegen) VisitLogical(expr *ast.Logical) { - left := c.loadExpr(expr.Left) - right := c.loadExpr(expr.Right) + type_ := ast.Primitive{Kind: ast.Bool} + + left := c.implicitCastLoadExpr(&type_, expr.Left) + right := c.implicitCastLoadExpr(&type_, expr.Right) switch expr.Operator.Token().Kind { case scanner.Or: @@ -509,7 +499,7 @@ func (c *codegen) VisitAssignment(expr *ast.Assignment) { assignee := c.acceptExpr(expr.Assignee) // Value - value := c.loadExpr(expr.Value) + value := c.implicitCastLoadExpr(expr.Assignee.Result().Type, expr.Value) if expr.Operator.Token().Kind != scanner.Equal { value = c.binary( @@ -534,132 +524,7 @@ func (c *codegen) VisitAssignment(expr *ast.Assignment) { func (c *codegen) VisitCast(expr *ast.Cast) { value := c.acceptExpr(expr.Value) - if from, ok := ast.As[*ast.Primitive](expr.Value.Result().Type); ok { - if to, ok := ast.As[*ast.Primitive](expr.Result().Type); ok { - // primitive to primitive - c.castPrimitiveToPrimitive(value, from, to, from.Kind, to.Kind, expr) - return - } - } - - if from, ok := ast.As[*ast.Enum](expr.Value.Result().Type); ok { - if to, ok := ast.As[*ast.Primitive](expr.Result().Type); ok { - // enum to integer - fromT, _ := ast.As[*ast.Primitive](from.Type) - - c.castPrimitiveToPrimitive(value, from, to, fromT.Kind, to.Kind, expr) - return - } - } - - if from, ok := ast.As[*ast.Primitive](expr.Value.Result().Type); ok { - if to, ok := expr.Result().Type.(*ast.Enum); ok { - // integer to enum - toT, _ := ast.As[*ast.Primitive](to.Type) - - c.castPrimitiveToPrimitive(value, from, to, from.Kind, toT.Kind, expr) - return - } - } - - if _, ok := ast.As[*ast.Pointer](expr.Value.Result().Type); ok { - if _, ok := ast.As[*ast.Pointer](expr.Result().Type); ok { - // pointer to pointer - c.exprResult = value - return - } - - if _, ok := ast.As[*ast.Func](expr.Result().Type); ok { - // pointer to function pointer - c.exprResult = value - return - } - } - - // Error - panic("codegen.VisitCast() - Invalid cast") -} - -func (c *codegen) castPrimitiveToPrimitive(value exprValue, from, to ast.Type, fromKind, toKind ast.PrimitiveKind, location ast.Node) { - if fromKind == toKind || (ast.EqualsPrimitiveCategory(fromKind, toKind) && ast.GetBitSize(fromKind) == ast.GetBitSize(toKind)) { - c.exprResult = value - return - } - - value = c.load(value, from) - - if (ast.IsInteger(fromKind) || ast.IsFloating(fromKind)) && toKind == ast.Bool { - // integer / floating to bool - var result ir.MetaValue - - if ast.IsFloating(fromKind) { - result = c.block.Add(&ir.FCmpInst{ - Kind: ir.Ne, - Ordered: false, - Left: value.v, - Right: ir.False, - }) - } else { - result = c.block.Add(&ir.ICmpInst{ - Kind: ir.Ne, - Signed: ast.IsSigned(fromKind), - Left: value.v, - Right: ir.False, - }) - } - - c.setLocationMeta(result, location) - c.exprResult = exprValue{v: result} - } else { - type_ := c.types.get(to) - var result ir.MetaValue - - if (ast.IsInteger(fromKind) || fromKind == ast.Bool) && ast.IsInteger(toKind) { - // integer / bool to integer - if from.Size() > to.Size() { - result = c.block.Add(&ir.TruncInst{ - Value: value.v, - Typ: type_, - }) - } else { - result = c.block.Add(&ir.ExtInst{ - SignExtend: false, - Value: value.v, - Typ: type_, - }) - } - } else if ast.IsFloating(fromKind) && ast.IsFloating(toKind) { - // floating to floating - if from.Size() > to.Size() { - result = c.block.Add(&ir.TruncInst{ - Value: value.v, - Typ: type_, - }) - } else { - result = c.block.Add(&ir.FExtInst{ - Value: value.v, - Typ: type_, - }) - } - } else if (ast.IsInteger(fromKind) || fromKind == ast.Bool) && ast.IsFloating(toKind) { - // integer / bool to floating - result = c.block.Add(&ir.I2FInst{ - Signed: ast.IsSigned(fromKind), - Value: value.v, - Typ: type_, - }) - } else if ast.IsFloating(fromKind) && ast.IsInteger(toKind) { - // floating to integer - result = c.block.Add(&ir.F2IInst{ - Signed: ast.IsSigned(toKind), - Value: value.v, - Typ: type_, - }) - } - - c.setLocationMeta(result, location) - c.exprResult = exprValue{v: result} - } + c.exprResult = c.cast(value, expr.Value.Result().Type, expr.Target, expr) } func (c *codegen) VisitTypeCall(expr *ast.TypeCall) { @@ -709,12 +574,13 @@ func (c *codegen) VisitCall(expr *ast.Call) { } for i, arg := range expr.Args { - index := i if function.Method() != nil { - index++ + args[i+1] = c.loadExpr(arg).v + } else if i >= len(function.Params) { + args[i] = c.loadExpr(arg).v + } else { + args[i] = c.implicitCastLoadExpr(function.Params[i].Type, arg).v } - - args[index] = c.loadExpr(arg).v } // Intrinsic @@ -902,6 +768,35 @@ func (c *codegen) memberLoad(type_ ast.Type, value exprValue) (exprValue, *ast.S // Utils +func (c *codegen) binaryLoad(left, right ast.Expr, operator *ast.Token) exprValue { + // Left -> Right + cast, castOk := ast.GetImplicitCast(left.Result().Type, right.Result().Type) + + if castOk { + to := right.Result().Type + + left := c.convertCast(c.loadExpr(left), cast, left.Result().Type, to, operator) + right := c.loadExpr(right) + + return c.binary(operator, left, right, to) + } + + // Right -> Left + cast, castOk = ast.GetImplicitCast(right.Result().Type, left.Result().Type) + + if castOk { + to := left.Result().Type + + left := c.loadExpr(left) + right := c.convertCast(c.loadExpr(right), cast, right.Result().Type, to, operator) + + return c.binary(operator, left, right, to) + } + + // Invalid + panic("codegen.binaryLoad() - Not implemented") +} + func (c *codegen) binary(op ast.Node, left exprValue, right exprValue, type_ ast.Type) exprValue { left = c.load(left, type_) right = c.load(right, type_) diff --git a/core/codegen/instructions.go b/core/codegen/instructions.go index 339f3c6..79cab5c 100644 --- a/core/codegen/instructions.go +++ b/core/codegen/instructions.go @@ -23,6 +23,10 @@ func (c *codegen) alloca(type_ ast.Type, name string, node ast.Node) ir.Value { } func (c *codegen) setLocationMeta(value ir.MetaValue, node ast.Node) { + if node == nil { + return + } + meta := &ir.LocationMeta{ Scope: c.scopes.getMeta(), } diff --git a/core/codegen/statements.go b/core/codegen/statements.go index ce67666..1425777 100644 --- a/core/codegen/statements.go +++ b/core/codegen/statements.go @@ -26,7 +26,7 @@ func (c *codegen) VisitVar(stmt *ast.Var) { // Initializer if stmt.Value != nil { - initializer := c.loadExpr(stmt.Value) + initializer := c.implicitCastLoadExpr(stmt.ActualType, stmt.Value) store := c.block.Add(&ir.StoreInst{ Pointer: pointer.v, @@ -49,7 +49,9 @@ func (c *codegen) VisitIf(stmt *ast.If) { } // Condition - condition := c.loadExpr(stmt.Condition) + required := ast.Primitive{Kind: ast.Bool} + condition := c.implicitCastLoadExpr(&required, stmt.Condition) + c.block.Add(&ir.BrInst{Condition: condition.v, True: then, False: else_}) // Then @@ -83,7 +85,10 @@ func (c *codegen) VisitWhile(stmt *ast.While) { // Condition c.beginBlock(c.loopStart) - condition := c.acceptExpr(stmt.Condition) + + required := ast.Primitive{Kind: ast.Bool} + condition := c.implicitCastLoadExpr(&required, stmt.Condition) + c.block.Add(&ir.BrInst{Condition: condition.v, True: body, False: c.loopEnd}) // Body @@ -124,7 +129,9 @@ func (c *codegen) VisitFor(stmt *ast.For) { c.beginBlock(c.loopStart) if stmt.Condition != nil { - condition := c.loadExpr(stmt.Condition) + required := ast.Primitive{Kind: ast.Bool} + condition := c.implicitCastLoadExpr(&required, stmt.Condition) + c.block.Add(&ir.BrInst{Condition: condition.v, True: body, False: c.loopEnd}) } @@ -157,7 +164,7 @@ func (c *codegen) VisitReturn(stmt *ast.Return) { ) } else { // Other - value := c.loadExpr(stmt.Value) + value := c.implicitCastLoadExpr(c.astFunction.Returns, stmt.Value) c.setLocationMeta( c.block.Add(&ir.RetInst{Value: value.v}), diff --git a/gen/main.go b/gen/main.go index 641d007..28c8a03 100644 --- a/gen/main.go +++ b/gen/main.go @@ -74,7 +74,6 @@ func genVisitor(w *Writer, group Group) { w.write("Align() uint32") w.write("") w.write("Equals(other Type) bool") - w.write("CanAssignTo(other Type) bool") w.write("") w.write("Resolved() Type") w.write("") diff --git a/tests/src/casts.fb b/tests/src/casts.fb new file mode 100644 index 0000000..47cbb16 --- /dev/null +++ b/tests/src/casts.fb @@ -0,0 +1,72 @@ +namespace Tests.Casts; + +enum Animal { + Dog, + Cat = 9, +} + +#[Test] +func primitive2primitive_extend() bool { + var a = 5 as i32; + return (a as i64) == (5 as i64); +} + +#[Test] +func primitive2primitive_truncate() bool { + var a = 5 as i32; + return (a as i16) == (5 as i16); +} + +#[Test] +func primitive2primitive_none() bool { + var a = 5 as i32; + return (a as u32) == (5 as u32); +} + +#[Test] +func primitive2primitive_float2double() bool { + return (5f as f64) == 5.0; +} + +#[Test] +func primitive2primitive_double2float() bool { + return (5.0 as f32) == 5f; +} + +#[Test] +func primitive2primitive_int2float() bool { + return (5 as f32) == 5f; +} + +#[Test] +func primitive2primitive_float2int() bool { + return (5f as i32) == 5; +} + +#[Test] +func primitive2enum() bool { + var a = 9; + return (a as Animal) == Animal.Cat; +} + +#[Test] +func pointer2pointer() bool { + var a *i32; + var _b = a as *f64; + + return true; +} + +#[Test] +func pointer2func() bool { + var a *i32; + var _b = a as fn () void; + + return true; +} + +#[Test] +func enum2primitive() bool { + var a = Animal.Cat; + return (a as i32) == 9; +} diff --git a/tests/src/implicit_casts.fb b/tests/src/implicit_casts.fb new file mode 100644 index 0000000..610f564 --- /dev/null +++ b/tests/src/implicit_casts.fb @@ -0,0 +1,36 @@ +namespace Tests.Casts.Implicit; + +#[Test] +func pointer2pointer() bool { + var a *i32; + var _b *void = a; + + return true; +} + +#[Test] +func intPromotion() bool { + var a = 5 as i16; + + var b = 1; + b = a; + + var c = a + b; + return c == 10; +} + +#[Test] +func int2float() bool { + var a = 5; + var b f32 = a; + + return b == 5f; +} + +#[Test] +func long2double() bool { + var a i64 = 5; + var b f64 = a; + + return b == 5.0; +}