diff --git a/bdd/spec/027_seq_spec.sh b/bdd/spec/027_seq_spec.sh index aaf3798ea..d50d2a885 100644 --- a/bdd/spec/027_seq_spec.sh +++ b/bdd/spec/027_seq_spec.sh @@ -195,4 +195,55 @@ error: sequence out-of-bounds" The output should eq "$RECURSEOUTPUT" End End + + Describe "recurse no-op one-liner regression test" + # Reported issue -- the root cause was due to how the compiler handled one-liner functions + # differently from multi-line functions. This test is to make sure the fix for this doesn't + # regress + before() { + sourceToAll " + import @std/app + from @std/seq import seq, Self, recurse + + fn doNothing(x: int) : int = x + + fn doNothingRec(x: int) : int = seq(x).recurse(fn (self: Self, x: int) : Result { + return ok(x) + }, x) || 0 + + on app.start { + const x = 5 + app.print(doNothing(x)) // 5 + app.print(doNothingRec(x)) // 5 + + const xs = [1, 2, 3] + app.print(xs.map(doNothing).map(toString).join(' ')) // 1 2 3 + app.print(xs.map(doNothingRec).map(toString).join(' ')) // 1 2 3 + + emit app.exit 0 + } + " + } + BeforeAll before + + after() { + cleanTemp + } + AfterAll after + + ONELINEROUTPUT="5 +5 +1 2 3 +1 2 3" + + It "runs js" + When run node temp.js + The output should eq "$ONELINEROUTPUT" + End + + It "runs agc" + When run alan run temp.agc + The output should eq "$ONELINEROUTPUT" + End + End End \ No newline at end of file diff --git a/compiler/src/lntoamm/Microstatement.ts b/compiler/src/lntoamm/Microstatement.ts index 6f9113b3d..032727273 100644 --- a/compiler/src/lntoamm/Microstatement.ts +++ b/compiler/src/lntoamm/Microstatement.ts @@ -893,11 +893,7 @@ ${withOperatorsAst.getText()}` } const len = microstatements.length - args.length for (const s of fn.statements) { - if (s.statementOrAssignableAst instanceof LnParser.StatementsContext) { - Microstatement.fromStatementsAst(s.statementOrAssignableAst, scope, microstatements) - } else { - Microstatement.fromAssignablesAst(s.statementOrAssignableAst, scope, microstatements) - } + Microstatement.fromStatementsAst(s.statementAst, scope, microstatements) } microstatements.splice(idx, args.length) const newlen = microstatements.length @@ -1762,25 +1758,16 @@ ${constdeclarationAst.getText()} on line ${constdeclarationAst.start.line}:${con const newScope = new Scope(statement.scope) newScope.secondaryPar = secondaryScope actualStatement = new Statement( - statement.statementOrAssignableAst, + statement.statementAst, newScope, statement.pure, ) } - if (actualStatement.statementOrAssignableAst instanceof LnParser.StatementsContext) { - Microstatement.fromStatementsAst( - actualStatement.statementOrAssignableAst, - actualStatement.scope, - microstatements - ) - } else { - // Otherwise it's a one-liner function - Microstatement.fromAssignablesAst( - actualStatement.statementOrAssignableAst, - actualStatement.scope, - microstatements - ) - } + Microstatement.fromStatementsAst( + actualStatement.statementAst, + actualStatement.scope, + microstatements + ) } } diff --git a/compiler/src/lntoamm/Statement.ts b/compiler/src/lntoamm/Statement.ts index eedcd8657..912dec4b4 100644 --- a/compiler/src/lntoamm/Statement.ts +++ b/compiler/src/lntoamm/Statement.ts @@ -5,24 +5,22 @@ import { LnParser, } from '../ln' // Only implements the pieces necessary for the first stage compiler class Statement { - statementOrAssignableAst: any // TODO: Migrate off ANTLR for better typing here + statementAst: any // TODO: Migrate off ANTLR for better typing here scope: Scope pure: boolean - constructor(statementOrAssignableAst: any, scope: Scope, pure: boolean) { - this.statementOrAssignableAst = statementOrAssignableAst, + constructor(statementAst: any, scope: Scope, pure: boolean) { + this.statementAst = statementAst, this.scope = scope this.pure = pure } isConditionalStatement() { - return this.statementOrAssignableAst instanceof LnParser.StatementsContext && - this.statementOrAssignableAst.conditionals() !== null + return this.statementAst.conditionals() !== null } isReturnStatement() { - return this.statementOrAssignableAst instanceof LnParser.AssignablesContext || - this.statementOrAssignableAst.exits() !== null + return this.statementAst.exits() !== null } static basicAssignableHasObjectLiteral(basicAssignableAst: any) { // TODO: Remove ANTLR @@ -45,27 +43,23 @@ class Statement { } hasObjectLiteral() { - if (this.statementOrAssignableAst instanceof LnParser.StatementsContext) { - const s = this.statementOrAssignableAst - if (s.declarations()) { - const d = s.declarations().constdeclaration() || s.declarations().letdeclaration() - return Statement.assignablesHasObjectLiteral(d.assignables()) - } - if (s.assignments()) return Statement.assignmentsHasObjectLiteral(s.assignments()) - if (s.calls() && s.calls().assignables() > 0) s.calls().assignables().some( - (a: any) => Statement.assignablesHasObjectLiteral(a) - ) - if (s.exits() && s.exits().assignables()) return Statement.assignablesHasObjectLiteral( - s.exits().assignables() - ) - if (s.emits() && s.emits().assignables()) return Statement.assignablesHasObjectLiteral( - s.emits().assignables() - ) - // TODO: Cover conditionals - return false - } else { - return Statement.assignablesHasObjectLiteral(this.statementOrAssignableAst) + const s = this.statementAst + if (s.declarations()) { + const d = s.declarations().constdeclaration() || s.declarations().letdeclaration() + return Statement.assignablesHasObjectLiteral(d.assignables()) } + if (s.assignments()) return Statement.assignmentsHasObjectLiteral(s.assignments()) + if (s.calls() && s.calls().assignables() > 0) s.calls().assignables().some( + (a: any) => Statement.assignablesHasObjectLiteral(a) + ) + if (s.exits() && s.exits().assignables()) return Statement.assignablesHasObjectLiteral( + s.exits().assignables() + ) + if (s.emits() && s.emits().assignables()) return Statement.assignablesHasObjectLiteral( + s.emits().assignables() + ) + // TODO: Cover conditionals + return false } static isCallPure(callAst: any, scope: Scope) { // TODO: Migrate off ANTLR @@ -151,59 +145,50 @@ class Statement { throw new Error("Impossible assignment situation") } - static create(statementOrAssignableAst: any, scope: Scope) { // TODO: Migrate off ANTLR - if (statementOrAssignableAst instanceof LnParser.AssignablesContext) { - const pure = Statement.isAssignablePure(statementOrAssignableAst, scope) - return new Statement(statementOrAssignableAst, scope, pure) - } else if (statementOrAssignableAst instanceof LnParser.StatementsContext) { - const statementAst = statementOrAssignableAst - let pure = true - if (statementAst.declarations() != null) { - if (statementAst.declarations().constdeclaration() != null) { + static create(statementAst: any, scope: Scope) { // TODO: Migrate off ANTLR + let pure = true + if (statementAst.declarations() != null) { + if (statementAst.declarations().constdeclaration() != null) { + pure = Statement.isAssignablePure( + statementAst.declarations().constdeclaration().assignables(), + scope + ) + } else if (statementAst.declarations().letdeclaration() != null) { + if (statementAst.declarations().letdeclaration().assignables() == null) { + pure = true + } else { pure = Statement.isAssignablePure( - statementAst.declarations().constdeclaration().assignables(), + statementAst.declarations().letdeclaration().assignables(), scope ) - } else if (statementAst.declarations().letdeclaration() != null) { - if (statementAst.declarations().letdeclaration().assignables() == null) { - pure = true - } else { - pure = Statement.isAssignablePure( - statementAst.declarations().letdeclaration().assignables(), - scope - ) - } - } else { - throw new Error("Bad assignment somehow reached") } + } else { + throw new Error("Bad assignment somehow reached") } - if (statementAst.assignments() != null) { - if (statementAst.assignments().assignables() != null) { - pure = Statement.isAssignablePure(statementAst.assignments().assignables(), scope) - } - } - if (statementAst.calls() != null) { - pure = Statement.isCallPure(statementAst.calls(), scope) + } + if (statementAst.assignments() != null) { + if (statementAst.assignments().assignables() != null) { + pure = Statement.isAssignablePure(statementAst.assignments().assignables(), scope) } - if (statementAst.exits() != null) { - if (statementAst.exits().assignables() != null) { - pure = Statement.isAssignablePure(statementAst.exits().assignables(), scope) - } + } + if (statementAst.calls() != null) { + pure = Statement.isCallPure(statementAst.calls(), scope) + } + if (statementAst.exits() != null) { + if (statementAst.exits().assignables() != null) { + pure = Statement.isAssignablePure(statementAst.exits().assignables(), scope) } - if (statementAst.emits() != null) { - if (statementAst.emits().assignables() != null) { - pure = Statement.isAssignablePure(statementAst.emits().assignables(), scope) - } + } + if (statementAst.emits() != null) { + if (statementAst.emits().assignables() != null) { + pure = Statement.isAssignablePure(statementAst.emits().assignables(), scope) } - return new Statement(statementAst, scope, pure) - } else { - // What? - throw new Error("This should not be possible") } + return new Statement(statementAst, scope, pure) } toString() { - return this.statementOrAssignableAst.getText() + return this.statementAst.getText() } } diff --git a/compiler/src/lntoamm/UserFunction.ts b/compiler/src/lntoamm/UserFunction.ts index bf6b5376d..4ffd00922 100644 --- a/compiler/src/lntoamm/UserFunction.ts +++ b/compiler/src/lntoamm/UserFunction.ts @@ -33,7 +33,7 @@ class UserFunction implements Fn { if (statements[i].isReturnStatement()) { // There are unreachable statements after this line, abort throw new Error(`Unreachable code in function '${name}' after: -${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[i].statementOrAssignableAst.start.line}:${statements[i].statementOrAssignableAst.start.column}`) +${statements[i].statementAst.getText().trim()} on line ${statements[i].statementAst.start.line}:${statements[i].statementAst.start.column}`) } } this.statements = statements @@ -148,7 +148,8 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[ } } else { const assignablesAst = functionAst.fullfunctionbody().assignables() - let statement = Statement.create(assignablesAst, scope) + const statementAst = Ast.statementAstFromString(`return ${assignablesAst.getText()}\n`) + const statement = Statement.create(statementAst, scope) if (!statement.pure) pure = false statements.push(statement) // TODO: Infer the return type for anything other than calls or object literals @@ -253,17 +254,9 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[ } toFnStr() { - if ( - this.statements.length === 1 && - this.statements[0].statementOrAssignableAst instanceof LnParser.AssignablesContext - ) { - return ` - fn ${this.name || ''} (${Object.keys(this.args).map(argName => `${argName}: ${this.args[argName].typename}`).join(', ')}): ${this.returnType.typename} = ${(this.statements[0].statementOrAssignableAst as any).getText()} - `.trim() - } return ` fn ${this.name || ''} (${Object.keys(this.args).map(argName => `${argName}: ${this.args[argName].typename}`).join(', ')}): ${this.returnType.typename} { - ${this.statements.map(s => s.statementOrAssignableAst.getText()).join('\n')} + ${this.statements.map(s => s.statementAst.getText()).join('\n')} } `.trim() } @@ -336,7 +329,7 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[ const block = args.assignables(1).basicassignables().functions() const blockFn = UserFunction.fromAst(block, scope) if (blockFn.statements[blockFn.statements.length - 1].isReturnStatement()) { - const innerStatements = blockFn.statements.map(s => s.statementOrAssignableAst) + const innerStatements = blockFn.statements.map(s => s.statementAst) const newBlockStatements = UserFunction.earlyReturnRewrite( retVal, retNotSet, innerStatements, scope ) @@ -394,13 +387,13 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[ let hasConditionalReturn = false // Flag for potential second pass for (let i = 0; i < this.statements.length; i++) { let s = new Statement( - this.statements[i].statementOrAssignableAst, + this.statements[i].statementAst, this.statements[i].scope, this.statements[i].pure, ) // Potentially rewrite the type for the object literal to match the interface type used by // a specific call - const str = s.statementOrAssignableAst.getText() + const str = s.statementAst.getText() const corrected = str.replace(/new ([^<]+)<([^{\[]+)> *([{\[])/g, ( _: any, basetypestr: string, @@ -445,29 +438,23 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[ const replacementType = originalType.realize(interfaceMap, this.scope) return `: ${replacementType.typename}${openstr}` }) - if (s.statementOrAssignableAst instanceof LnParser.AssignablesContext) { - const correctedAst = Ast.statementAstFromString(`return ${secondCorrection}\n`) - s.statementOrAssignableAst = correctedAst - // statementAsts.push(correctedAst) - } else { - const correctedAst = Ast.statementAstFromString(secondCorrection) - s.statementOrAssignableAst = correctedAst - // statementAsts.push(correctedAst) - } + const correctedAst = Ast.statementAstFromString(secondCorrection) + s.statementAst = correctedAst + // statementAsts.push(correctedAst) if (s.isConditionalStatement()) { - const cond = s.statementOrAssignableAst.conditionals() + const cond = s.statementAst.conditionals() const res = UserFunction.conditionalToCond(cond, this.scope) const newStatements = res[0] as Array if (res[1]) hasConditionalReturn = true statementAsts.push(...newStatements) - } else if (s.statementOrAssignableAst instanceof LnParser.AssignmentsContext) { - const a = s.statementOrAssignableAst + } else if (s.statementAst instanceof LnParser.AssignmentsContext) { + const a = s.statementAst const wrappedAst = Ast.statementAstFromString(` ${a.varn().getText()} = ref(${a.assignables().getText()}) `.trim() + '\n') statementAsts.push(wrappedAst) - } else if (s.statementOrAssignableAst instanceof LnParser.LetdeclarationContext) { - const l = s.statementOrAssignableAst + } else if (s.statementAst instanceof LnParser.LetdeclarationContext) { + const l = s.statementAst const name = l.VARNAME().getText() const type = l.othertype() ? l.othertype().getText() : undefined const v = l.assignables().getText() @@ -476,7 +463,7 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[ `.trim() + '\n') statementAsts.push(wrappedAst) } else { - statementAsts.push(s.statementOrAssignableAst) + statementAsts.push(s.statementAst) } } // Second pass, there was a conditional return, mutate everything *again* so the return is