Skip to content

Commit

Permalink
Fix type inference for one-liner functions with method syntax (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfellis authored Nov 13, 2020
1 parent 6619066 commit 2e95038
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 117 deletions.
51 changes: 51 additions & 0 deletions bdd/spec/027_seq_spec.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,55 @@ error: sequence out-of-bounds"
The output should eq "$RECURSEOUTPUT"
End
End

Describe "recurse no-op one-liner regression test"
# Reported issue -- the root cause was due to how the compiler handled one-liner functions
# differently from multi-line functions. This test is to make sure the fix for this doesn't
# regress
before() {
sourceToAll "
import @std/app
from @std/seq import seq, Self, recurse
fn doNothing(x: int) : int = x
fn doNothingRec(x: int) : int = seq(x).recurse(fn (self: Self, x: int) : Result<int> {
return ok(x)
}, x) || 0
on app.start {
const x = 5
app.print(doNothing(x)) // 5
app.print(doNothingRec(x)) // 5
const xs = [1, 2, 3]
app.print(xs.map(doNothing).map(toString).join(' ')) // 1 2 3
app.print(xs.map(doNothingRec).map(toString).join(' ')) // 1 2 3
emit app.exit 0
}
"
}
BeforeAll before

after() {
cleanTemp
}
AfterAll after

ONELINEROUTPUT="5
5
1 2 3
1 2 3"

It "runs js"
When run node temp.js
The output should eq "$ONELINEROUTPUT"
End

It "runs agc"
When run alan run temp.agc
The output should eq "$ONELINEROUTPUT"
End
End
End
27 changes: 7 additions & 20 deletions compiler/src/lntoamm/Microstatement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -893,11 +893,7 @@ ${withOperatorsAst.getText()}`
}
const len = microstatements.length - args.length
for (const s of fn.statements) {
if (s.statementOrAssignableAst instanceof LnParser.StatementsContext) {
Microstatement.fromStatementsAst(s.statementOrAssignableAst, scope, microstatements)
} else {
Microstatement.fromAssignablesAst(s.statementOrAssignableAst, scope, microstatements)
}
Microstatement.fromStatementsAst(s.statementAst, scope, microstatements)
}
microstatements.splice(idx, args.length)
const newlen = microstatements.length
Expand Down Expand Up @@ -1762,25 +1758,16 @@ ${constdeclarationAst.getText()} on line ${constdeclarationAst.start.line}:${con
const newScope = new Scope(statement.scope)
newScope.secondaryPar = secondaryScope
actualStatement = new Statement(
statement.statementOrAssignableAst,
statement.statementAst,
newScope,
statement.pure,
)
}
if (actualStatement.statementOrAssignableAst instanceof LnParser.StatementsContext) {
Microstatement.fromStatementsAst(
actualStatement.statementOrAssignableAst,
actualStatement.scope,
microstatements
)
} else {
// Otherwise it's a one-liner function
Microstatement.fromAssignablesAst(
actualStatement.statementOrAssignableAst,
actualStatement.scope,
microstatements
)
}
Microstatement.fromStatementsAst(
actualStatement.statementAst,
actualStatement.scope,
microstatements
)
}
}

Expand Down
121 changes: 53 additions & 68 deletions compiler/src/lntoamm/Statement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,22 @@ import { LnParser, } from '../ln'

// Only implements the pieces necessary for the first stage compiler
class Statement {
statementOrAssignableAst: any // TODO: Migrate off ANTLR for better typing here
statementAst: any // TODO: Migrate off ANTLR for better typing here
scope: Scope
pure: boolean

constructor(statementOrAssignableAst: any, scope: Scope, pure: boolean) {
this.statementOrAssignableAst = statementOrAssignableAst,
constructor(statementAst: any, scope: Scope, pure: boolean) {
this.statementAst = statementAst,
this.scope = scope
this.pure = pure
}

isConditionalStatement() {
return this.statementOrAssignableAst instanceof LnParser.StatementsContext &&
this.statementOrAssignableAst.conditionals() !== null
return this.statementAst.conditionals() !== null
}

isReturnStatement() {
return this.statementOrAssignableAst instanceof LnParser.AssignablesContext ||
this.statementOrAssignableAst.exits() !== null
return this.statementAst.exits() !== null
}

static basicAssignableHasObjectLiteral(basicAssignableAst: any) { // TODO: Remove ANTLR
Expand All @@ -45,27 +43,23 @@ class Statement {
}

hasObjectLiteral() {
if (this.statementOrAssignableAst instanceof LnParser.StatementsContext) {
const s = this.statementOrAssignableAst
if (s.declarations()) {
const d = s.declarations().constdeclaration() || s.declarations().letdeclaration()
return Statement.assignablesHasObjectLiteral(d.assignables())
}
if (s.assignments()) return Statement.assignmentsHasObjectLiteral(s.assignments())
if (s.calls() && s.calls().assignables() > 0) s.calls().assignables().some(
(a: any) => Statement.assignablesHasObjectLiteral(a)
)
if (s.exits() && s.exits().assignables()) return Statement.assignablesHasObjectLiteral(
s.exits().assignables()
)
if (s.emits() && s.emits().assignables()) return Statement.assignablesHasObjectLiteral(
s.emits().assignables()
)
// TODO: Cover conditionals
return false
} else {
return Statement.assignablesHasObjectLiteral(this.statementOrAssignableAst)
const s = this.statementAst
if (s.declarations()) {
const d = s.declarations().constdeclaration() || s.declarations().letdeclaration()
return Statement.assignablesHasObjectLiteral(d.assignables())
}
if (s.assignments()) return Statement.assignmentsHasObjectLiteral(s.assignments())
if (s.calls() && s.calls().assignables() > 0) s.calls().assignables().some(
(a: any) => Statement.assignablesHasObjectLiteral(a)
)
if (s.exits() && s.exits().assignables()) return Statement.assignablesHasObjectLiteral(
s.exits().assignables()
)
if (s.emits() && s.emits().assignables()) return Statement.assignablesHasObjectLiteral(
s.emits().assignables()
)
// TODO: Cover conditionals
return false
}

static isCallPure(callAst: any, scope: Scope) { // TODO: Migrate off ANTLR
Expand Down Expand Up @@ -151,59 +145,50 @@ class Statement {
throw new Error("Impossible assignment situation")
}

static create(statementOrAssignableAst: any, scope: Scope) { // TODO: Migrate off ANTLR
if (statementOrAssignableAst instanceof LnParser.AssignablesContext) {
const pure = Statement.isAssignablePure(statementOrAssignableAst, scope)
return new Statement(statementOrAssignableAst, scope, pure)
} else if (statementOrAssignableAst instanceof LnParser.StatementsContext) {
const statementAst = statementOrAssignableAst
let pure = true
if (statementAst.declarations() != null) {
if (statementAst.declarations().constdeclaration() != null) {
static create(statementAst: any, scope: Scope) { // TODO: Migrate off ANTLR
let pure = true
if (statementAst.declarations() != null) {
if (statementAst.declarations().constdeclaration() != null) {
pure = Statement.isAssignablePure(
statementAst.declarations().constdeclaration().assignables(),
scope
)
} else if (statementAst.declarations().letdeclaration() != null) {
if (statementAst.declarations().letdeclaration().assignables() == null) {
pure = true
} else {
pure = Statement.isAssignablePure(
statementAst.declarations().constdeclaration().assignables(),
statementAst.declarations().letdeclaration().assignables(),
scope
)
} else if (statementAst.declarations().letdeclaration() != null) {
if (statementAst.declarations().letdeclaration().assignables() == null) {
pure = true
} else {
pure = Statement.isAssignablePure(
statementAst.declarations().letdeclaration().assignables(),
scope
)
}
} else {
throw new Error("Bad assignment somehow reached")
}
} else {
throw new Error("Bad assignment somehow reached")
}
if (statementAst.assignments() != null) {
if (statementAst.assignments().assignables() != null) {
pure = Statement.isAssignablePure(statementAst.assignments().assignables(), scope)
}
}
if (statementAst.calls() != null) {
pure = Statement.isCallPure(statementAst.calls(), scope)
}
if (statementAst.assignments() != null) {
if (statementAst.assignments().assignables() != null) {
pure = Statement.isAssignablePure(statementAst.assignments().assignables(), scope)
}
if (statementAst.exits() != null) {
if (statementAst.exits().assignables() != null) {
pure = Statement.isAssignablePure(statementAst.exits().assignables(), scope)
}
}
if (statementAst.calls() != null) {
pure = Statement.isCallPure(statementAst.calls(), scope)
}
if (statementAst.exits() != null) {
if (statementAst.exits().assignables() != null) {
pure = Statement.isAssignablePure(statementAst.exits().assignables(), scope)
}
if (statementAst.emits() != null) {
if (statementAst.emits().assignables() != null) {
pure = Statement.isAssignablePure(statementAst.emits().assignables(), scope)
}
}
if (statementAst.emits() != null) {
if (statementAst.emits().assignables() != null) {
pure = Statement.isAssignablePure(statementAst.emits().assignables(), scope)
}
return new Statement(statementAst, scope, pure)
} else {
// What?
throw new Error("This should not be possible")
}
return new Statement(statementAst, scope, pure)
}

toString() {
return this.statementOrAssignableAst.getText()
return this.statementAst.getText()
}
}

Expand Down
45 changes: 16 additions & 29 deletions compiler/src/lntoamm/UserFunction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class UserFunction implements Fn {
if (statements[i].isReturnStatement()) {
// There are unreachable statements after this line, abort
throw new Error(`Unreachable code in function '${name}' after:
${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[i].statementOrAssignableAst.start.line}:${statements[i].statementOrAssignableAst.start.column}`)
${statements[i].statementAst.getText().trim()} on line ${statements[i].statementAst.start.line}:${statements[i].statementAst.start.column}`)
}
}
this.statements = statements
Expand Down Expand Up @@ -148,7 +148,8 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[
}
} else {
const assignablesAst = functionAst.fullfunctionbody().assignables()
let statement = Statement.create(assignablesAst, scope)
const statementAst = Ast.statementAstFromString(`return ${assignablesAst.getText()}\n`)
const statement = Statement.create(statementAst, scope)
if (!statement.pure) pure = false
statements.push(statement)
// TODO: Infer the return type for anything other than calls or object literals
Expand Down Expand Up @@ -253,17 +254,9 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[
}

toFnStr() {
if (
this.statements.length === 1 &&
this.statements[0].statementOrAssignableAst instanceof LnParser.AssignablesContext
) {
return `
fn ${this.name || ''} (${Object.keys(this.args).map(argName => `${argName}: ${this.args[argName].typename}`).join(', ')}): ${this.returnType.typename} = ${(this.statements[0].statementOrAssignableAst as any).getText()}
`.trim()
}
return `
fn ${this.name || ''} (${Object.keys(this.args).map(argName => `${argName}: ${this.args[argName].typename}`).join(', ')}): ${this.returnType.typename} {
${this.statements.map(s => s.statementOrAssignableAst.getText()).join('\n')}
${this.statements.map(s => s.statementAst.getText()).join('\n')}
}
`.trim()
}
Expand Down Expand Up @@ -336,7 +329,7 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[
const block = args.assignables(1).basicassignables().functions()
const blockFn = UserFunction.fromAst(block, scope)
if (blockFn.statements[blockFn.statements.length - 1].isReturnStatement()) {
const innerStatements = blockFn.statements.map(s => s.statementOrAssignableAst)
const innerStatements = blockFn.statements.map(s => s.statementAst)
const newBlockStatements = UserFunction.earlyReturnRewrite(
retVal, retNotSet, innerStatements, scope
)
Expand Down Expand Up @@ -394,13 +387,13 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[
let hasConditionalReturn = false // Flag for potential second pass
for (let i = 0; i < this.statements.length; i++) {
let s = new Statement(
this.statements[i].statementOrAssignableAst,
this.statements[i].statementAst,
this.statements[i].scope,
this.statements[i].pure,
)
// Potentially rewrite the type for the object literal to match the interface type used by
// a specific call
const str = s.statementOrAssignableAst.getText()
const str = s.statementAst.getText()
const corrected = str.replace(/new ([^<]+)<([^{\[]+)> *([{\[])/g, (
_: any,
basetypestr: string,
Expand Down Expand Up @@ -445,29 +438,23 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[
const replacementType = originalType.realize(interfaceMap, this.scope)
return `: ${replacementType.typename}${openstr}`
})
if (s.statementOrAssignableAst instanceof LnParser.AssignablesContext) {
const correctedAst = Ast.statementAstFromString(`return ${secondCorrection}\n`)
s.statementOrAssignableAst = correctedAst
// statementAsts.push(correctedAst)
} else {
const correctedAst = Ast.statementAstFromString(secondCorrection)
s.statementOrAssignableAst = correctedAst
// statementAsts.push(correctedAst)
}
const correctedAst = Ast.statementAstFromString(secondCorrection)
s.statementAst = correctedAst
// statementAsts.push(correctedAst)
if (s.isConditionalStatement()) {
const cond = s.statementOrAssignableAst.conditionals()
const cond = s.statementAst.conditionals()
const res = UserFunction.conditionalToCond(cond, this.scope)
const newStatements = res[0] as Array<any>
if (res[1]) hasConditionalReturn = true
statementAsts.push(...newStatements)
} else if (s.statementOrAssignableAst instanceof LnParser.AssignmentsContext) {
const a = s.statementOrAssignableAst
} else if (s.statementAst instanceof LnParser.AssignmentsContext) {
const a = s.statementAst
const wrappedAst = Ast.statementAstFromString(`
${a.varn().getText()} = ref(${a.assignables().getText()})
`.trim() + '\n')
statementAsts.push(wrappedAst)
} else if (s.statementOrAssignableAst instanceof LnParser.LetdeclarationContext) {
const l = s.statementOrAssignableAst
} else if (s.statementAst instanceof LnParser.LetdeclarationContext) {
const l = s.statementAst
const name = l.VARNAME().getText()
const type = l.othertype() ? l.othertype().getText() : undefined
const v = l.assignables().getText()
Expand All @@ -476,7 +463,7 @@ ${statements[i].statementOrAssignableAst.getText().trim()} on line ${statements[
`.trim() + '\n')
statementAsts.push(wrappedAst)
} else {
statementAsts.push(s.statementOrAssignableAst)
statementAsts.push(s.statementAst)
}
}
// Second pass, there was a conditional return, mutate everything *again* so the return is
Expand Down

0 comments on commit 2e95038

Please sign in to comment.