diff --git a/compiler/src/main/cup/io/github/aplcornell/viaduct/syntax/circuit/CircuitParser.cup b/compiler/src/main/cup/io/github/aplcornell/viaduct/syntax/circuit/CircuitParser.cup index 33fbfd3cc8..f9a7fc177e 100755 --- a/compiler/src/main/cup/io/github/aplcornell/viaduct/syntax/circuit/CircuitParser.cup +++ b/compiler/src/main/cup/io/github/aplcornell/viaduct/syntax/circuit/CircuitParser.cup @@ -144,7 +144,7 @@ terminal COLONCOLON; terminal PERIOD; -terminal IF, ELSE, WHILE, LOOP, BREAK; +terminal IF, ELSE, LOOP, BREAK; terminal INPUT, FROM, OUTPUT, TO; terminal REDUCE; diff --git a/compiler/src/main/jflex/io/github/aplcornell/viaduct/syntax/circuit/CircuitLexer.jflex b/compiler/src/main/jflex/io/github/aplcornell/viaduct/syntax/circuit/CircuitLexer.jflex index 8d4d3ce970..7ebe9f9c88 100644 --- a/compiler/src/main/jflex/io/github/aplcornell/viaduct/syntax/circuit/CircuitLexer.jflex +++ b/compiler/src/main/jflex/io/github/aplcornell/viaduct/syntax/circuit/CircuitLexer.jflex @@ -110,6 +110,11 @@ NUM = ((-)?[1-9][0-9]*) | 0 "input" { return symbol(sym.INPUT); } "output" { return symbol(sym.OUTPUT); } + "if" { return symbol(sym.IF); } + "else" { return symbol(sym.ELSE); } + "loop" { return symbol(sym.LOOP); } + "break" { return symbol(sym.BREAK); } + /* Expressions */ "." { return symbol(sym.PERIOD); } diff --git a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/circuitcodegeneration/BackendCodeGenerator.kt b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/circuitcodegeneration/BackendCodeGenerator.kt index 35b66051a5..9c3b1329e6 100644 --- a/compiler/src/main/kotlin/io/github/aplcornell/viaduct/circuitcodegeneration/BackendCodeGenerator.kt +++ b/compiler/src/main/kotlin/io/github/aplcornell/viaduct/circuitcodegeneration/BackendCodeGenerator.kt @@ -20,13 +20,17 @@ import io.github.aplcornell.viaduct.runtime.Out import io.github.aplcornell.viaduct.runtime.ViaductGeneratedProgram import io.github.aplcornell.viaduct.runtime.ViaductRuntime import io.github.aplcornell.viaduct.syntax.Host +import io.github.aplcornell.viaduct.syntax.circuit.BreakNode import io.github.aplcornell.viaduct.syntax.circuit.CircuitCallNode import io.github.aplcornell.viaduct.syntax.circuit.CircuitDeclarationNode import io.github.aplcornell.viaduct.syntax.circuit.FunctionDeclarationNode +import io.github.aplcornell.viaduct.syntax.circuit.IfNode import io.github.aplcornell.viaduct.syntax.circuit.InputNode import io.github.aplcornell.viaduct.syntax.circuit.LetNode +import io.github.aplcornell.viaduct.syntax.circuit.LoopNode import io.github.aplcornell.viaduct.syntax.circuit.OutputNode import io.github.aplcornell.viaduct.syntax.circuit.ProgramNode +import io.github.aplcornell.viaduct.syntax.circuit.StatementNode import io.github.aplcornell.viaduct.syntax.circuit.Variable import io.github.aplcornell.viaduct.syntax.circuit.VariableBindingNode import io.github.aplcornell.viaduct.syntax.types.ValueType @@ -76,159 +80,191 @@ private class BackendCodeGenerator( fun generate(functionDeclaration: FunctionDeclarationNode): FunSpec { val builder = FunSpec.builder(functionDeclaration.name.value.name) for (stmt in functionDeclaration.body.statements) { - when (stmt) { - is LetNode -> { - when (val command = stmt.command) { - is CircuitCallNode -> { - val circuitDecl: CircuitDeclarationNode = nameAnalysis.declaration(command) - val circuitHosts = circuitDecl.protocol.value.hosts - - val importingHosts = command.inputs.map { - (nameAnalysis.declaration(it) as VariableBindingNode).protocol.value.hosts - }.flatten().toSet() + circuitHosts - val exportingHosts = stmt.bindings.map { - it.protocol.value.hosts - }.flatten().toSet() + circuitHosts - - if (context.host !in importingHosts + exportingHosts) continue - - // Forward declare bindings - stmt.bindings.zip(circuitDecl.outputs.map { it.type }).forEach { (binding, type) -> - if (context.host in binding.protocol.value.hosts) { - builder.addStatement( - "val %N: %T", - context.kotlinName(binding.name.value), - kotlinType( - type.shape, - codeGenerator.storageType(binding.protocol.value, type.elementType.value), - ), - ) - } - } - - builder.beginControlFlow("run") + generate(stmt, builder) + } + return builder.build() + } - // Declare size parameters - circuitDecl.sizes.zip(command.bounds) { sizeParam, sizeArg -> + fun generate(statement: StatementNode, builder: FunSpec.Builder) { + when (statement) { + is LetNode -> { + when (val command = statement.command) { + is CircuitCallNode -> { + val circuitDecl: CircuitDeclarationNode = nameAnalysis.declaration(command) + val circuitHosts = circuitDecl.protocol.value.hosts + + val importingHosts = command.inputs.map { + (nameAnalysis.declaration(it) as VariableBindingNode).protocol.value.hosts + }.flatten().toSet() + circuitHosts + val exportingHosts = statement.bindings.map { + it.protocol.value.hosts + }.flatten().toSet() + circuitHosts + + if (context.host !in importingHosts + exportingHosts) return + + // Forward declare bindings + statement.bindings.zip(circuitDecl.outputs.map { it.type }).forEach { (binding, type) -> + if (context.host in binding.protocol.value.hosts) { builder.addStatement( - "val %N = %L", - context.kotlinName(sizeParam.name.value), - indexExpression(sizeArg, context), + "val %N: %T", + context.kotlinName(binding.name.value), + kotlinType( + type.shape, + codeGenerator.storageType(binding.protocol.value, type.elementType.value), + ), ) } + } - val inputs = if (context.host in importingHosts) { - val (importCode, inputs) = codeGenerator.import( - circuitDecl.protocol.value, - command.inputs.zip(circuitDecl.inputs).map { (arg, inParam) -> - Argument( - indexExpression(arg, context), - inParam.type, - (nameAnalysis.declaration(arg) as VariableBindingNode).protocol.value, - arg.sourceLocation, - ) - }, - ) - builder.addCode(importCode) - inputs - } else { - null - } + builder.beginControlFlow("run") - val outTmps = circuitDecl.outputs.map { - val tmp = context.newTemporary(it.name.value.name) - CodeBlock.of("%N", tmp) - } + // Declare size parameters + circuitDecl.sizes.zip(command.bounds) { sizeParam, sizeArg -> + builder.addStatement( + "val %N = %L", + context.kotlinName(sizeParam.name.value), + indexExpression(sizeArg, context), + ) + } - if (context.host in circuitHosts) { - val outNames = circuitDecl.outputs.associateWith { outParam -> - val outName = - context.newTemporary(context.kotlinName(outParam.name.value) + "_boxed") - builder.addStatement( - "val %L = %T()", - outName, - Out::class.asClassName().parameterizedBy( - kotlinType( - outParam.type.shape, - codeGenerator.paramType( - circuitDecl.protocol.value, - outParam.type.elementType.value, - ), - ), - ), + val inputs = if (context.host in importingHosts) { + val (importCode, inputs) = codeGenerator.import( + circuitDecl.protocol.value, + command.inputs.zip(circuitDecl.inputs).map { (arg, inParam) -> + Argument( + indexExpression(arg, context), + inParam.type, + (nameAnalysis.declaration(arg) as VariableBindingNode).protocol.value, + arg.sourceLocation, ) - outName - } - builder.addStatement( - "%N(%L)", - command.name.value.name, - ( - (command.bounds).map { indexExpression(it, context) } + inputs!! + - circuitDecl.outputs.map { CodeBlock.of("%N", outNames[it]) } - ).joinToCode(), - ) - circuitDecl.outputs.forEachIndexed { index, param -> - builder.addStatement("val %L = %N.get()", outTmps[index], outNames[param]!!) - } - } + }, + ) + builder.addCode(importCode) + inputs + } else { + null + } - if (context.host in exportingHosts) { - val (exportCode, outputs) = codeGenerator.export( - circuitDecl.protocol.value, - stmt.bindings.mapIndexed { index, binding -> - Argument( - outTmps[index], - circuitDecl.outputs[index].type, - binding.protocol.value, - binding.sourceLocation, - ) - }, - ) - builder.addCode(exportCode) - // Bind results - stmt.bindings.zip(outputs).forEach { (binding, output) -> - if (context.host in binding.protocol.value.hosts) { - builder.addStatement( - "%N = %L", - context.kotlinName(binding.name.value), - output, - ) - } - } - } - builder.endControlFlow() + val outTmps = circuitDecl.outputs.map { + val tmp = context.newTemporary(it.name.value.name) + CodeBlock.of("%N", tmp) } - is InputNode -> { - if (command.host.value != context.host) continue - val unnamedIndices = command.type.shape.map { CodeBlock.of("_") } + if (context.host in circuitHosts) { + val outNames = circuitDecl.outputs.associateWith { outParam -> + val outName = context.newTemporary(context.kotlinName(outParam.name.value) + "_boxed") + builder.addStatement( + "val %L = %T()", + outName, + Out::class.asClassName().parameterizedBy( + kotlinType( + outParam.type.shape, + codeGenerator.paramType( + circuitDecl.protocol.value, + outParam.type.elementType.value, + ), + ), + ), + ) + outName + } builder.addStatement( - "val %N = %L", - context.kotlinName(stmt.bindings[0].name.value), - command.type.shape.new( - context, - unnamedIndices, - context.input(command.type.elementType.value), - ), + "%N(%L)", + command.name.value.name, + ((command.bounds).map { + indexExpression( + it, + context + ) + } + inputs!! + circuitDecl.outputs.map { + CodeBlock.of( + "%N", + outNames[it] + ) + }).joinToCode(), ) + circuitDecl.outputs.forEachIndexed { index, param -> + builder.addStatement("val %L = %N.get()", outTmps[index], outNames[param]!!) + } } - is OutputNode -> { - if (command.host.value != context.host) continue - builder.addCode( - indexExpression(command.message, context) - .forEachIndexed(command.type.shape, context) { _, value -> - context.output(value, command.type.elementType.value) - }, + if (context.host in exportingHosts) { + val (exportCode, outputs) = codeGenerator.export( + circuitDecl.protocol.value, + statement.bindings.mapIndexed { index, binding -> + Argument( + outTmps[index], + circuitDecl.outputs[index].type, + binding.protocol.value, + binding.sourceLocation, + ) + }, ) + builder.addCode(exportCode) + // Bind results + statement.bindings.zip(outputs).forEach { (binding, output) -> + if (context.host in binding.protocol.value.hosts) { + builder.addStatement( + "%N = %L", + context.kotlinName(binding.name.value), + output, + ) + } + } } + builder.endControlFlow() } - } - else -> throw UnsupportedOperationException("Incorrect statement type in function body") + is InputNode -> { + if (command.host.value != context.host) return + val unnamedIndices = command.type.shape.map { CodeBlock.of("_") } + builder.addStatement( + "val %N = %L", + context.kotlinName(statement.bindings[0].name.value), + command.type.shape.new( + context, + unnamedIndices, + context.input(command.type.elementType.value), + ), + ) + } + + is OutputNode -> { + if (command.host.value != context.host) return + builder.addCode( + indexExpression(command.message, context).forEachIndexed( + command.type.shape, + context + ) { _, value -> + context.output(value, command.type.elementType.value) + }, + ) + } + + is IfNode -> { + builder.beginControlFlow("if (%L)", indexExpression(command.guard, context)) + command.thenBranch.forEach { generate(it, builder) } + builder.endControlFlow() + builder.beginControlFlow("else") + command.elseBranch.forEach { generate(it, builder) } + builder.endControlFlow() + } + + is LoopNode -> { + builder.beginControlFlow("while (true)") + command.body.forEach { generate(it, builder) } + builder.endControlFlow() + } + + is BreakNode -> { + builder.addStatement("break") + } + + } } + + else -> throw UnsupportedOperationException("Incorrect statement type in function body") } - return builder.build() } /** Generates code for the circuit [circuitDeclaration]. */ @@ -279,27 +315,23 @@ private class BackendCodeGenerator( override fun kotlinName(sourceName: Variable): String = varMap.getOrPut(sourceName) { freshNameGenerator.getFreshName(sourceName.name) } - override fun newTemporary(baseName: String): String = - freshNameGenerator.getFreshName(baseName) + override fun newTemporary(baseName: String): String = freshNameGenerator.getFreshName(baseName) - override fun codeOf(host: Host) = - hostDeclarations.reference(host) + override fun codeOf(host: Host) = hostDeclarations.reference(host) - fun input(type: ValueType): CodeBlock = - CodeBlock.of( - "(%N.input(%T) as %T).value", - "runtime", - type::class, - type.valueClass, - ) + fun input(type: ValueType): CodeBlock = CodeBlock.of( + "(%N.input(%T) as %T).value", + "runtime", + type::class, + type.valueClass, + ) - fun output(value: CodeBlock, type: ValueType): CodeBlock = - CodeBlock.of( - "%N.output(%T(%L))", - "runtime", - type.valueClass, - value, - ) + fun output(value: CodeBlock, type: ValueType): CodeBlock = CodeBlock.of( + "%N.output(%T(%L))", + "runtime", + type.valueClass, + value, + ) override fun receive(type: TypeName, sender: Host): CodeBlock { require(sender != context.host) @@ -343,8 +375,7 @@ fun ProgramNode.compileToKotlin( // Mark generated code as automatically generated. fileBuilder.addAnnotation( - AnnotationSpec.builder(Generated::class).addMember("%S", BackendCodeGenerator::class.qualifiedName!!) - .build(), + AnnotationSpec.builder(Generated::class).addMember("%S", BackendCodeGenerator::class.qualifiedName!!).build(), ) // Suppress warnings expected in generated code. diff --git a/compiler/tests/should-pass/circuit/ctrlflow.circuit b/compiler/tests/should-pass/circuit/ctrlflow.circuit index 53f0015760..8f1f7f80a6 100644 --- a/compiler/tests/should-pass/circuit/ctrlflow.circuit +++ b/compiler/tests/should-pass/circuit/ctrlflow.circuit @@ -8,25 +8,23 @@ fun <> main() -> { val = alice.output(x) return } + else { return } + val = if (false) { + val = alice.output(y) + return + } else { - val = if (false) { - val = alice.output(y) - return - } - else { - val = loop { - val stop@Local(host = alice) = alice.input() - val = if (stop) { - val = break - return - } - else { - val = alice.output(z) - return - } + val = loop { + val stop@Local(host = alice) = alice.input() + val = if (stop) { + val = break + return + } + else { + val = alice.output(z) return - } - return + } + return } return }