From 263e4c9bc31cff1a66dc381ef1c024d740d616f2 Mon Sep 17 00:00:00 2001 From: MineGame159 Date: Fri, 26 Jan 2024 23:33:41 +0100 Subject: [PATCH] CORE: Add 'is' operator --- core/ast/cst2ast/expressions.go | 7 +++- core/ast/expressions.go | 26 ++++++++---- core/checker/expressions.go | 21 ++++++++-- core/codegen/expressions.go | 73 ++++++++++++++++++++++++++++----- core/codegen/vtables.go | 33 ++++++++++++--- core/cst/expressions.go | 6 +-- core/scanner/scanner.go | 2 + core/scanner/token.go | 17 +++++++- gen/ast.go | 1 + tests/src/interfaces.fb | 8 ++++ vscode/fireball.tmGrammar.json | 2 +- 11 files changed, 162 insertions(+), 34 deletions(-) diff --git a/core/ast/cst2ast/expressions.go b/core/ast/cst2ast/expressions.go index a0770c9..28bc656 100644 --- a/core/ast/cst2ast/expressions.go +++ b/core/ast/cst2ast/expressions.go @@ -82,7 +82,7 @@ func (c *converter) convertBinaryExpr(node cst.Node) ast.Expr { if node.Contains(scanner.Dot) { return c.convertMemberExpr(node) } - if node.Contains(scanner.As) { + if node.ContainsAny(scanner.CastOperators) { return c.convertCastExpr(node) } @@ -162,17 +162,20 @@ func (c *converter) convertIndexExpr(node cst.Node) ast.Expr { func (c *converter) convertCastExpr(node cst.Node) ast.Expr { var value ast.Expr + var operator *ast.Token var target ast.Type for _, child := range node.Children { if child.Kind.IsExpr() { value = c.convertExpr(child) + } else if child.Token.Kind.IsAny(scanner.CastOperators) { + operator = c.convertToken(child) } else if child.Kind.IsType() { target = c.convertType(child) } } - if c := ast.NewCast(node, value, target); c != nil { + if c := ast.NewCast(node, value, operator, target); c != nil { return c } diff --git a/core/ast/expressions.go b/core/ast/expressions.go index 9dee687..f8387a6 100644 --- a/core/ast/expressions.go +++ b/core/ast/expressions.go @@ -732,26 +732,31 @@ type Cast struct { cst cst.Node parent Node - Value Expr - Target Type + Value Expr + Operator *Token + Target Type result ExprResult } -func NewCast(node cst.Node, value Expr, target Type) *Cast { - if value == nil && target == nil { +func NewCast(node cst.Node, value Expr, operator *Token, target Type) *Cast { + if value == nil && operator == nil && target == nil { return nil } c := &Cast{ - cst: node, - Value: value, - Target: target, + cst: node, + Value: value, + Operator: operator, + Target: target, } if value != nil { value.SetParent(c) } + if operator != nil { + operator.SetParent(c) + } if target != nil { target.SetParent(c) } @@ -787,6 +792,9 @@ func (c *Cast) AcceptChildren(visitor Visitor) { if c.Value != nil { visitor.VisitNode(c.Value) } + if c.Operator != nil { + visitor.VisitNode(c.Operator) + } if c.Target != nil { visitor.VisitNode(c.Target) } @@ -801,6 +809,10 @@ func (c *Cast) Clone() Node { c2.Value = c.Value.Clone().(Expr) c2.Value.SetParent(c2) } + if c.Operator != nil { + c2.Operator = c.Operator.Clone().(*Token) + c2.Operator.SetParent(c2) + } if c.Target != nil { c2.Target = c.Target.Clone().(Type) c2.Target.SetParent(c2) diff --git a/core/checker/expressions.go b/core/checker/expressions.go index 63a9bc8..5274ef9 100644 --- a/core/checker/expressions.go +++ b/core/checker/expressions.go @@ -550,11 +550,24 @@ func (c *checker) VisitCast(expr *ast.Cast) { return } - expr.Result().SetValue(expr.Target, 0, nil) + // Check based on the operator + switch expr.Operator.Token().Kind { + case scanner.As: + if _, ok := ast.GetCast(expr.Value.Result().Type, expr.Target); !ok { + c.error(expr, "Cannot cast type '%s' to type '%s'", ast.PrintType(expr.Value.Result().Type), ast.PrintType(expr.Target)) + } - // Check type - if _, ok := ast.GetCast(expr.Value.Result().Type, expr.Target); !ok { - c.error(expr, "Cannot cast type '%s' to type '%s'", ast.PrintType(expr.Value.Result().Type), ast.PrintType(expr.Target)) + expr.Result().SetValue(expr.Target, 0, nil) + + case scanner.Is: + if _, ok := ast.As[*ast.Interface](expr.Value.Result().Type); !ok { + c.error(expr.Value, "Runtime type checking is only supported for interfaces") + } + + expr.Result().SetValue(&ast.Primitive{Kind: ast.Bool}, 0, nil) + + default: + panic("checker.VisitCast() - Not implemented") } } diff --git a/core/codegen/expressions.go b/core/codegen/expressions.go index f36451e..00d3d3a 100644 --- a/core/codegen/expressions.go +++ b/core/codegen/expressions.go @@ -522,9 +522,57 @@ func (c *codegen) VisitAssignment(expr *ast.Assignment) { } func (c *codegen) VisitCast(expr *ast.Cast) { - value := c.acceptExpr(expr.Value) + switch expr.Operator.Token().Kind { + case scanner.As: + value := c.acceptExpr(expr.Value) + c.exprResult = c.cast(value, expr.Value.Result().Type, expr.Target, expr) + + case scanner.Is: + value := c.loadExpr(expr.Value) + + // Get vtable pointer + vtablePtr := c.block.Add(&ir.ExtractValueInst{ + Value: value.v, + Indices: []uint32{uint32(0)}, + }) + + // Get type id + typ := c.vtables.getType(expr.Value.Result().Type.Resolved().(*ast.Interface)) + typPtr := &ir.PointerType{Pointee: typ} + + typeIdPtr := c.block.Add(&ir.GetElementPtrInst{ + PointerTyp: typPtr, + Typ: typ, + Pointer: vtablePtr, + Indices: []ir.Value{ + &ir.IntConst{Typ: ir.I32, Value: ir.Unsigned(0)}, + &ir.IntConst{Typ: ir.I32, Value: ir.Unsigned(0)}, + }, + Inbounds: true, + }) + + typeId := c.block.Add(&ir.LoadInst{ + Typ: ir.I32, + Pointer: typeIdPtr, + }) - c.exprResult = c.cast(value, expr.Value.Result().Type, expr.Target, expr) + // Compare + result := c.block.Add(&ir.ICmpInst{ + Kind: ir.Eq, + Signed: false, + Left: typeId, + Right: &ir.IntConst{ + Typ: ir.I32, + Value: ir.Unsigned(uint64(c.ctx.GetTypeID(expr.Target))), + }, + }) + + c.setLocationMeta(result, expr.Operator) + c.exprResult = exprValue{v: result} + + default: + panic("codegen.VisitCast() - Not implemented") + } } func (c *codegen) VisitTypeCall(expr *ast.TypeCall) { @@ -731,41 +779,44 @@ func (c *codegen) VisitMember(expr *ast.Member) { value = c.load(value, expr.Value.Result().Type) } + // Get vtable pointer vtablePtr := c.block.Add(&ir.ExtractValueInst{ Value: value.v, Indices: []uint32{uint32(0)}, }) + // Get function _, index := inter.GetMethod(expr.Name.String()) - void := ast.Primitive{Kind: ast.Void} - fnPtr := ast.Pointer{Pointee: &void} - - typ := ast.Array{Base: &fnPtr, Count: uint32(len(inter.Methods))} - typPtr := ast.Pointer{Pointee: &typ} + typ := c.vtables.getType(inter) + typPtr := &ir.PointerType{Pointee: typ} functionPtr := c.block.Add(&ir.GetElementPtrInst{ - PointerTyp: c.types.get(&typPtr), - Typ: c.types.get(&typ), + PointerTyp: typPtr, + Typ: typ, Pointer: vtablePtr, Indices: []ir.Value{ &ir.IntConst{Typ: ir.I32, Value: ir.Unsigned(0)}, + &ir.IntConst{Typ: ir.I32, Value: ir.Unsigned(1)}, &ir.IntConst{Typ: ir.I32, Value: ir.Unsigned(uint64(index))}, }, Inbounds: true, }) + fnPtr := typ.Fields[1].(*ir.ArrayType).Base + function := c.block.Add(&ir.LoadInst{ - Typ: c.types.get(&fnPtr), + Typ: fnPtr, Pointer: functionPtr, - Align: fnPtr.Align(), }) + // Get data pointer dataPtr := c.block.Add(&ir.ExtractValueInst{ Value: value.v, Indices: []uint32{uint32(1)}, }) + // Return c.exprResult = exprValue{v: function} c.this = exprValue{v: dataPtr, addressable: true} diff --git a/core/codegen/vtables.go b/core/codegen/vtables.go index 80dabfe..47fefe3 100644 --- a/core/codegen/vtables.go +++ b/core/codegen/vtables.go @@ -34,14 +34,22 @@ func (v *vtables) get(type_, inter ast.Type) ir.Value { methods[i] = v.c.getFunction(method).v } + typ := v.getType(inter.Resolved().(*ast.Interface)) + value := v.c.module.Constant( getVtableName(type_, inter), - &ir.ArrayConst{ - Typ: &ir.ArrayType{ - Count: uint32(len(methods)), - Base: &ir.PointerType{}, + &ir.StructConst{ + Typ: typ, + Fields: []ir.Value{ + &ir.IntConst{ + Typ: ir.I32, + Value: ir.Unsigned(uint64(v.c.ctx.GetTypeID(type_))), + }, + &ir.ArrayConst{ + Typ: typ.Fields[1], + Values: methods, + }, }, - Values: methods, }, ) @@ -54,6 +62,21 @@ func (v *vtables) get(type_, inter ast.Type) ir.Value { return value } +func (v *vtables) getType(inter *ast.Interface) *ir.StructType { + funcPtrArrayType := &ir.ArrayType{ + Count: uint32(len(inter.Methods)), + Base: &ir.PointerType{}, + } + + return &ir.StructType{ + Name: "", + Fields: []ir.Type{ + ir.I32, + funcPtrArrayType, + }, + } +} + func getVtableName(type_, inter ast.Type) string { sb := strings.Builder{} sb.WriteString("__fb_vtable__") diff --git a/core/cst/expressions.go b/core/cst/expressions.go index f6a6005..1f99640 100644 --- a/core/cst/expressions.go +++ b/core/cst/expressions.go @@ -194,7 +194,7 @@ func parseStructFieldExpr(p *parser) Node { func parseInfixExprPratt(p *parser, op scanner.TokenKind, lhs Node, rightPower int) Node { switch op { - case scanner.As: + case scanner.As, scanner.Is: p.begin(BinaryExprNode) p.childAdd(lhs) @@ -326,8 +326,8 @@ func init() { infix(false, scanner.Ampersand) // ==, != infix(false, scanner.EqualEqual, scanner.BangEqual) - // >, <=, >, >=, as - infix(false, scanner.Less, scanner.LessEqual, scanner.Greater, scanner.GreaterEqual, scanner.As) + // >, <=, >, >=, as, is + infix(false, scanner.Less, scanner.LessEqual, scanner.Greater, scanner.GreaterEqual, scanner.As, scanner.Is) // <<, >> infix(false, scanner.LessLess, scanner.GreaterGreater) // +, - diff --git a/core/scanner/scanner.go b/core/scanner/scanner.go index 6feed9b..68d9931 100644 --- a/core/scanner/scanner.go +++ b/core/scanner/scanner.go @@ -194,6 +194,8 @@ func (s *Scanner) identifierKind() TokenKind { return s.checkKeyword(2, "pl", Impl) case 'n': return s.checkKeyword(2, "terface", Interface) + case 's': + return Is } } case 'n': diff --git a/core/scanner/token.go b/core/scanner/token.go index 9a913f1..d5cf2c2 100644 --- a/core/scanner/token.go +++ b/core/scanner/token.go @@ -1,6 +1,9 @@ package scanner -import "fireball/core" +import ( + "fireball/core" + "slices" +) type TokenKind uint8 @@ -68,6 +71,7 @@ const ( While For As + Is Static Func Fn @@ -93,6 +97,10 @@ const ( Eof ) +func (t TokenKind) IsAny(kinds []TokenKind) bool { + return slices.Contains(kinds, t) +} + var LogicalOperators = []TokenKind{ And, Or, @@ -115,6 +123,11 @@ var AssignmentOperators = []TokenKind{ GreaterGreaterEqual, } +var CastOperators = []TokenKind{ + As, + Is, +} + type Token struct { Kind TokenKind Lexeme string @@ -304,6 +317,8 @@ func TokenKindStr(kind TokenKind) string { return "'for'" case As: return "'as'" + case Is: + return "'is'" case Static: return "'static'" case Func: diff --git a/gen/ast.go b/gen/ast.go index df01055..2dd3487 100644 --- a/gen/ast.go +++ b/gen/ast.go @@ -198,6 +198,7 @@ var expressions = Group{ node( "Cast", field("value", type_("Expr")), + field("operator", type_("Token")), field("target", type_("Type")), ), node( diff --git a/tests/src/interfaces.fb b/tests/src/interfaces.fb index 12a2609..250eda8 100644 --- a/tests/src/interfaces.fb +++ b/tests/src/interfaces.fb @@ -73,6 +73,14 @@ func checkConcrete() bool { return s == &f && s != &s; } +#[Test("is")] +func _is() bool { + var f = Foo {}; + var s Something = &f; + + return (s is Foo) && !(s is i32); +} + func check(s Something, number i32) bool { return s.getNumber() == number; } diff --git a/vscode/fireball.tmGrammar.json b/vscode/fireball.tmGrammar.json index 3f366e0..4067aae 100644 --- a/vscode/fireball.tmGrammar.json +++ b/vscode/fireball.tmGrammar.json @@ -120,7 +120,7 @@ "name": "string.quoted.double.fb" }, "keyword": { - "match": "\\b(nil|true|false|and|or|var|if|else|while|for|as|static|func|continue|break|return|namespace|using|struct|impl|enum|interface|new|fn)\\b", + "match": "\\b(nil|true|false|and|or|var|if|else|while|for|as|is|static|func|continue|break|return|namespace|using|struct|impl|enum|interface|new|fn)\\b", "name": "keyword.fb" }, "attribute": {