From 584aa3526d1cdc269ab43e78a954df4d25817616 Mon Sep 17 00:00:00 2001 From: Yaron Sheffer Date: Thu, 2 May 2024 11:39:02 +0300 Subject: [PATCH] Expressions: coverage (#133) * Repro issue 84 * Improve coverage for expressions, specifically Print * Clean up warnings --------- Co-authored-by: ysheffer --- datalog/expressions.go | 72 ++++++++++++++++++++++++------------- datalog/expressions_test.go | 49 +++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 24 deletions(-) diff --git a/datalog/expressions.go b/datalog/expressions.go index d84e3e1..ed229c7 100644 --- a/datalog/expressions.go +++ b/datalog/expressions.go @@ -33,8 +33,12 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) ( return nil, fmt.Errorf("datalog: expressions: unknown variable %d", id.(Variable)) } id = *idptr + default: // do nothing + } + err := s.Push(id) + if err != nil { + return nil, fmt.Errorf("datalog: expressions: stack overflow") } - s.Push(id) case OpTypeUnary: v, err := s.Pop() if err != nil { @@ -45,7 +49,10 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) ( if err != nil { return nil, fmt.Errorf("datalog: expressions: unary eval failed: %w", err) } - s.Push(res) + err = s.Push(res) + if err != nil { + return nil, fmt.Errorf("datalog: expressions: stack overflow") + } case OpTypeBinary: right, err := s.Pop() if err != nil { @@ -60,7 +67,10 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) ( if err != nil { return nil, fmt.Errorf("datalog: expressions: binary eval failed: %w", err) } - s.Push(res) + err = s.Push(res) + if err != nil { + return nil, fmt.Errorf("datalog: expressions: stack overflow") + } default: return nil, fmt.Errorf("datalog: expressions: unsupported Op: %v", op.Type()) } @@ -83,11 +93,20 @@ func (e *Expression) Print(symbols *SymbolTable) string { id := op.(Value).ID switch id.Type() { case TermTypeString: - s.Push(fmt.Sprintf("\"%s\"", symbols.Str(id.(String)))) + err := s.Push(fmt.Sprintf("\"%s\"", symbols.Str(id.(String)))) + if err != nil { + return "" + } case TermTypeVariable: - s.Push(fmt.Sprintf("$%s", symbols.Var(id.(Variable)))) + err := s.Push(fmt.Sprintf("$%s", symbols.Var(id.(Variable)))) + if err != nil { + return "" + } default: - s.Push(id.String()) + err := s.Push(id.String()) + if err != nil { + return "" + } } case OpTypeUnary: v, err := s.Pop() @@ -95,10 +114,10 @@ func (e *Expression) Print(symbols *SymbolTable) string { return "" } res := op.(UnaryOp).Print(v) + err = s.Push(res) if err != nil { - return "" + return "" } - s.Push(res) case OpTypeBinary: right, err := s.Pop() if err != nil { @@ -109,7 +128,10 @@ func (e *Expression) Print(symbols *SymbolTable) string { return "" } res := op.(BinaryOp).Print(left, right) - s.Push(res) + err = s.Push(res) + if err != nil { + return "" + } default: return fmt.Sprintf("", op.Type()) } @@ -160,6 +182,8 @@ func (op UnaryOp) Print(value string) string { out = fmt.Sprintf("!%s", value) case UnaryParens: out = fmt.Sprintf("(%s)", value) + case UnaryLength: + out = fmt.Sprintf("%s.length()", value) default: out = fmt.Sprintf("unknown(%s)", value) } @@ -186,7 +210,7 @@ type Negate struct{} func (Negate) Type() UnaryOpType { return UnaryNegate } -func (Negate) Eval(value Term, symbols *SymbolTable) (Term, error) { +func (Negate) Eval(value Term, _ *SymbolTable) (Term, error) { var out Term switch value.Type() { case TermTypeBool: @@ -206,7 +230,7 @@ type Parens struct{} func (Parens) Type() UnaryOpType { return UnaryParens } -func (Parens) Eval(value Term, symbols *SymbolTable) (Term, error) { +func (Parens) Eval(value Term, _ *SymbolTable) (Term, error) { return value, nil } @@ -228,7 +252,7 @@ func (Length) Eval(value Term, symbols *SymbolTable) (Term, error) { case TermTypeSet: out = Integer(len(value.(Set))) default: - return nil, fmt.Errorf("datalog: unexpected Negate value type: %d", value.Type()) + return nil, fmt.Errorf("datalog: unexpected Length value type: %d", value.Type()) } return out, nil } @@ -318,7 +342,7 @@ type LessThan struct{} func (LessThan) Type() BinaryOpType { return BinaryLessThan } -func (LessThan) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (LessThan) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { if g, w := left.Type(), right.Type(); g != w { return nil, fmt.Errorf("datalog: LessThan type mismatch: %d != %d", g, w) } @@ -344,7 +368,7 @@ type LessOrEqual struct{} func (LessOrEqual) Type() BinaryOpType { return BinaryLessOrEqual } -func (LessOrEqual) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (LessOrEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { if g, w := left.Type(), right.Type(); g != w { return nil, fmt.Errorf("datalog: LessOrEqual type mismatch: %d != %d", g, w) } @@ -370,7 +394,7 @@ type GreaterThan struct{} func (GreaterThan) Type() BinaryOpType { return BinaryGreaterThan } -func (GreaterThan) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (GreaterThan) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { if g, w := left.Type(), right.Type(); g != w { return nil, fmt.Errorf("datalog: GreaterThan type mismatch: %d != %d", g, w) } @@ -396,7 +420,7 @@ type GreaterOrEqual struct{} func (GreaterOrEqual) Type() BinaryOpType { return BinaryGreaterOrEqual } -func (GreaterOrEqual) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (GreaterOrEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { if g, w := left.Type(), right.Type(); g != w { return nil, fmt.Errorf("datalog: GreaterOrEqual type mismatch: %d != %d", g, w) } @@ -422,7 +446,7 @@ type Equal struct{} func (Equal) Type() BinaryOpType { return BinaryEqual } -func (Equal) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Equal) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { if g, w := left.Type(), right.Type(); g != w { return nil, fmt.Errorf("datalog: Equal type mismatch: %d != %d", g, w) } @@ -510,7 +534,7 @@ type Intersection struct{} func (Intersection) Type() BinaryOpType { return BinaryIntersection } -func (Intersection) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Intersection) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { set, ok := left.(Set) if !ok { return nil, errors.New("datalog: Intersection left value must be a Set") @@ -530,7 +554,7 @@ type Union struct{} func (Union) Type() BinaryOpType { return BinaryUnion } -func (Union) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Union) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { set, ok := left.(Set) if !ok { return nil, errors.New("datalog: Union left value must be a Set") @@ -654,7 +678,7 @@ type Sub struct{} func (Sub) Type() BinaryOpType { return BinarySub } -func (Sub) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Sub) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { ileft, ok := left.(Integer) if !ok { return nil, fmt.Errorf("datalog: Sub requires left value to be an Integer, got %T", left) @@ -682,7 +706,7 @@ type Mul struct{} func (Mul) Type() BinaryOpType { return BinaryMul } -func (Mul) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Mul) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { ileft, ok := left.(Integer) if !ok { return nil, fmt.Errorf("datalog: Mul requires left value to be an Integer, got %T", left) @@ -711,7 +735,7 @@ type Div struct{} func (Div) Type() BinaryOpType { return BinaryDiv } -func (Div) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Div) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { ileft, ok := left.(Integer) if !ok { return nil, fmt.Errorf("datalog: Div requires left value to be an Integer, got %T", left) @@ -735,7 +759,7 @@ type And struct{} func (And) Type() BinaryOpType { return BinaryAnd } -func (And) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (And) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { bleft, ok := left.(Bool) if !ok { return nil, fmt.Errorf("datalog: And requires left value to be a Bool, got %T", left) @@ -755,7 +779,7 @@ type Or struct{} func (Or) Type() BinaryOpType { return BinaryOr } -func (Or) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) { +func (Or) Eval(left Term, right Term, _ *SymbolTable) (Term, error) { bleft, ok := left.(Bool) if !ok { return nil, fmt.Errorf("datalog: Or requires left value to be a Bool, got %T", left) diff --git a/datalog/expressions_test.go b/datalog/expressions_test.go index c41ee76..9531e14 100644 --- a/datalog/expressions_test.go +++ b/datalog/expressions_test.go @@ -1188,3 +1188,52 @@ func TestBinaryOr(t *testing.T) { }) } } + +func TestPrint(t *testing.T) { + syms := SymbolTable{} + syms.Insert("abc") + testCases := []struct { + desc string + expr Expression + res string + }{ + { + desc: "number", + expr: Expression{Value{Integer(9)}}, + res: "9", + }, + { + desc: "string", + expr: Expression{Value{syms.Sym("abc")}}, + res: "\"abc\"", + }, + { + desc: "unary", + expr: Expression{Value{syms.Sym("abc")}, UnaryOp{Length{}}}, + res: "\"abc\".length()", + }, + { + desc: "binary", + expr: Expression{Value{Integer(9)}, Value{Integer(4)}, BinaryOp{Mul{}}}, + res: "9 * 4", + }, + { + desc: "parens", + expr: Expression{ + Value{Integer(9)}, + Value{Integer(3)}, + BinaryOp{Add{}}, + UnaryOp{Parens{}}, + Value{Integer(4)}, + BinaryOp{Div{}}, + }, + res: "(9 + 3) / 4", + }, + } + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + p := tc.expr.Print(&syms) + require.Equal(t, tc.res, p) + }) + } +}