Skip to content

Commit

Permalink
Test flow up to loop unrolling working
Browse files Browse the repository at this point in the history
  • Loading branch information
d0cd committed Nov 27, 2024
1 parent c6e49ad commit ff4e46f
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 18 deletions.
2 changes: 1 addition & 1 deletion compiler/ast/src/passes/reconstructor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ pub trait ProgramReconstructor: StatementReconstructor {
.into_iter()
.map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
.collect(),
tests: input.tests.into_iter().map(|(id, test)| (id, self.reconstruct_test(test))).collect(),
tests: input.tests.into_iter().map(|test| self.reconstruct_test(test)).collect(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/passes/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ pub trait ProgramVisitor<'a>: StatementVisitor<'a> {
input.imports.values().for_each(|import| self.visit_import(&import.0));
input.stubs.values().for_each(|stub| self.visit_stub(stub));
input.program_scopes.values().for_each(|scope| self.visit_program_scope(scope));
input.tests.values().for_each(|test| self.visit_test(test));
input.tests.iter().for_each(|test| self.visit_test(test));
}

fn visit_program_scope(&mut self, input: &'a ProgramScope) {
Expand Down
9 changes: 2 additions & 7 deletions compiler/ast/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub struct Program {
pub program_scopes: IndexMap<Symbol, ProgramScope>,
/// A map from test file names to test defintions.
// TODO: This is a temporary way to store tests in the AST, without requiring an overhaul of the compiler.
pub tests: IndexMap<Symbol, Test>,
pub tests: Vec<Test>,
}

impl fmt::Display for Program {
Expand All @@ -63,11 +63,6 @@ impl fmt::Display for Program {
impl Default for Program {
/// Constructs an empty program node.
fn default() -> Self {
Self {
imports: IndexMap::new(),
stubs: IndexMap::new(),
program_scopes: IndexMap::new(),
tests: IndexMap::new(),
}
Self { imports: IndexMap::new(), stubs: IndexMap::new(), program_scopes: IndexMap::new(), tests: Vec::new() }
}
}
193 changes: 190 additions & 3 deletions compiler/compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,31 @@ impl<'a, N: Network> Compiler<'a, N> {
Ok(())
}

/// Parses and stores the test source , constructs the AST, and optionally outputs it.
pub fn parse_test(&mut self) -> Result<()> {
// Initialize the AST.
let mut ast = Ast::default();
// Parse the sources.
for (name, program_string) in &self.sources {
// Register the source (`program_string`) in the source map.
let prg_sf = with_session_globals(|s| s.source_map.new_source(program_string, name.clone()));
// Use the parser to construct the abstract syntax tree (ast).
ast.combine(leo_parser::parse_test_ast::<N>(
self.handler,
&self.node_builder,
&prg_sf.src,
prg_sf.start_pos,
)?);
}
// Store the AST.
self.ast = ast;
// Write the AST to a JSON file.
if self.compiler_options.output.initial_ast {
self.write_ast_to_json("initial_ast.json")?;
}
Ok(())
}

/// Runs the symbol table pass.
pub fn symbol_table_pass(&self) -> Result<SymbolTable> {
let symbol_table = SymbolTableCreator::do_pass((&self.ast, self.handler))?;
Expand Down Expand Up @@ -293,6 +318,168 @@ impl<'a, N: Network> Compiler<'a, N> {
Ok((st, struct_graph, call_graph))
}

/// Runs the test symbol table pass.
pub fn test_symbol_table_pass(&self) -> Result<SymbolTable> {
let symbol_table = SymbolTableCreator::do_pass((&self.ast, self.handler))?;
if self.compiler_options.output.initial_symbol_table {
self.write_symbol_table_to_json("initial_symbol_table.json", &symbol_table)?;
}
Ok(symbol_table)
}

/// Runs the test type checker pass.
pub fn test_type_checker_pass(
&'a self,
symbol_table: SymbolTable,
) -> Result<(SymbolTable, StructGraph, CallGraph)> {
let (symbol_table, struct_graph, call_graph) = TypeChecker::<N>::do_pass((
&self.ast,
self.handler,
symbol_table,
&self.type_table,
self.compiler_options.build.conditional_block_max_depth,
self.compiler_options.build.disable_conditional_branch_type_checking,
self.compiler_options.output.build_tests,
))?;
if self.compiler_options.output.type_checked_symbol_table {
self.write_symbol_table_to_json("type_checked_symbol_table.json", &symbol_table)?;
}
Ok((symbol_table, struct_graph, call_graph))
}

/// Runs the test loop unrolling pass.
pub fn test_loop_unrolling_pass(&mut self, symbol_table: SymbolTable) -> Result<SymbolTable> {
let (ast, symbol_table) = Unroller::do_pass((
std::mem::take(&mut self.ast),
self.handler,
&self.node_builder,
symbol_table,
&self.type_table,
))?;
self.ast = ast;

if self.compiler_options.output.unrolled_ast {
self.write_ast_to_json("unrolled_ast.json")?;
}

if self.compiler_options.output.unrolled_symbol_table {
self.write_symbol_table_to_json("unrolled_symbol_table.json", &symbol_table)?;
}

Ok(symbol_table)
}

/// Runs the test static single assignment pass.
pub fn test_static_single_assignment_pass(&mut self, symbol_table: &SymbolTable) -> Result<()> {
self.ast = StaticSingleAssigner::do_pass((
std::mem::take(&mut self.ast),
&self.node_builder,
&self.assigner,
symbol_table,
&self.type_table,
))?;

if self.compiler_options.output.ssa_ast {
self.write_ast_to_json("ssa_ast.json")?;
}

Ok(())
}

/// Runs the test flattening pass.
pub fn test_flattening_pass(&mut self, symbol_table: &SymbolTable) -> Result<()> {
self.ast = Flattener::do_pass((
std::mem::take(&mut self.ast),
symbol_table,
&self.type_table,
&self.node_builder,
&self.assigner,
))?;

if self.compiler_options.output.flattened_ast {
self.write_ast_to_json("flattened_ast.json")?;
}

Ok(())
}

/// Runs the test destructuring pass.
pub fn test_destructuring_pass(&mut self) -> Result<()> {
self.ast = Destructurer::do_pass((
std::mem::take(&mut self.ast),
&self.type_table,
&self.node_builder,
&self.assigner,
))?;

if self.compiler_options.output.destructured_ast {
self.write_ast_to_json("destructured_ast.json")?;
}

Ok(())
}

/// Runs the test function inlining pass.
pub fn test_function_inlining_pass(&mut self, call_graph: &CallGraph) -> Result<()> {
let ast = FunctionInliner::do_pass((
std::mem::take(&mut self.ast),
&self.node_builder,
call_graph,
&self.assigner,
&self.type_table,
))?;
self.ast = ast;

if self.compiler_options.output.inlined_ast {
self.write_ast_to_json("inlined_ast.json")?;
}

Ok(())
}

/// Runs the test dead code elimination pass.
pub fn test_dead_code_elimination_pass(&mut self) -> Result<()> {
if self.compiler_options.build.dce_enabled {
self.ast = DeadCodeEliminator::do_pass((std::mem::take(&mut self.ast), &self.node_builder))?;
}

if self.compiler_options.output.dce_ast {
self.write_ast_to_json("dce_ast.json")?;
}

Ok(())
}

/// Runs the test code generation pass.
pub fn test_code_generation_pass(
&mut self,
symbol_table: &SymbolTable,
struct_graph: &StructGraph,
call_graph: &CallGraph,
) -> Result<String> {
CodeGenerator::do_pass((&self.ast, symbol_table, &self.type_table, struct_graph, call_graph, &self.ast.ast))
}

/// Runs the test compiler stages.
pub fn test_compiler_stages(&mut self) -> Result<(SymbolTable, StructGraph, CallGraph)> {
let st = self.test_symbol_table_pass()?;
let (st, struct_graph, call_graph) = self.test_type_checker_pass(st)?;

let st = self.test_loop_unrolling_pass(st)?;

self.test_static_single_assignment_pass(&st)?;

self.test_flattening_pass(&st)?;

self.test_destructuring_pass()?;

self.test_function_inlining_pass(&call_graph)?;

self.test_dead_code_elimination_pass()?;

Ok((st, struct_graph, call_graph))
}

/// Returns a compiled Leo program.
pub fn compile(&mut self) -> Result<String> {
// Parse the program.
Expand All @@ -309,13 +496,13 @@ impl<'a, N: Network> Compiler<'a, N> {
/// Returns the compiled Leo tests.
pub fn compile_tests(&mut self) -> Result<String> {
// Parse the program.
self.parse()?;
self.parse_test()?;
// Copy the dependencies specified in `program.json` into the AST.
self.add_import_stubs()?;
// Run the intermediate compiler stages.
let (symbol_table, struct_graph, call_graph) = self.compiler_stages()?;
let (symbol_table, struct_graph, call_graph) = self.test_compiler_stages()?;
// Run code generation.
let bytecode = self.code_generation_pass(&symbol_table, &struct_graph, &call_graph)?;
let bytecode = self.test_code_generation_pass(&symbol_table, &struct_graph, &call_graph)?;
Ok(bytecode)
}

Expand Down
15 changes: 14 additions & 1 deletion compiler/parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub(crate) use tokenizer::*;
pub mod parser;
pub use parser::*;

use leo_ast::{Ast, NodeBuilder};
use leo_ast::{Ast, NodeBuilder, Test};

Check warning on line 34 in compiler/parser/src/lib.rs

View workflow job for this annotation

GitHub Actions / Code Coverage

unused import: `Test`
use leo_errors::{Result, emitter::Handler};

use snarkvm::prelude::Network;
Expand All @@ -48,3 +48,16 @@ pub fn parse_ast<N: Network>(
) -> Result<Ast> {
Ok(Ast::new(parse::<N>(handler, node_builder, source, start_pos)?))
}

/// Creates a new test AST from a given file path and source code text.
pub fn parse_test_ast<N: Network>(
handler: &Handler,
node_builder: &NodeBuilder,
source: &str,
start_pos: BytePos,
) -> Result<Ast> {
let test = parse_test::<N>(handler, node_builder, source, start_pos)?;
let mut program = leo_ast::Program::default();
program.tests.push(test);
Ok(Ast::new(program))
}
2 changes: 1 addition & 1 deletion compiler/parser/src/parser/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<N: Network> ParserContext<'_, N> {
return Err(ParserError::missing_program_scope(self.token.span).into());
}

Ok(Program { imports, stubs: IndexMap::new(), program_scopes, tests: IndexMap::new() })
Ok(Program { imports, stubs: IndexMap::new(), program_scopes, tests: Vec::new() })
}

pub(super) fn unexpected_item(token: &SpannedToken, expected: &[Token]) -> ParserError {
Expand Down
12 changes: 12 additions & 0 deletions compiler/parser/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,15 @@ pub fn parse<N: Network>(

tokens.parse_program()
}

/// Creates a new test from a given file path and source code text.
pub fn parse_test<N: Network>(
handler: &Handler,
node_builder: &NodeBuilder,
source: &str,
start_pos: BytePos,
) -> Result<Test> {
let mut tokens = ParserContext::<N>::new(handler, node_builder, crate::tokenize(source, start_pos)?);

tokens.parse_test()
}
2 changes: 1 addition & 1 deletion compiler/parser/src/parser/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use leo_span::Symbol;

impl<N: Network> ParserContext<'_, N> {
/// Parses a test file.
fn parse_test(&mut self) -> Result<Test> {
pub fn parse_test(&mut self) -> Result<Test> {
// Initialize storage for the components of the test file
let mut consts: Vec<(Symbol, ConstDeclaration)> = Vec::new();
let mut functions = Vec::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl ProgramConsumer for StaticSingleAssigner<'_> {
.into_iter()
.map(|(name, scope)| (name, self.consume_program_scope(scope)))
.collect(),
tests: input.tests.into_iter().map(|(name, test)| (name, self.consume_test(test))).collect(),
tests: input.tests.into_iter().map(|test| self.consume_test(test)).collect(),
}
}
}
4 changes: 2 additions & 2 deletions leo/package/src/tst/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ impl DefaultTestFile {
@native_test
@interpreted_test
transition test_helloworld() {{
let result: u32 = helloworld.aleo/main(1u32, 2u32)
assert_eq!(result, 3u32)
let result: u32 = helloworld.aleo/main(1u32, 2u32);
assert_eq(result, 3u32);
}}
"#
.to_string()
Expand Down

0 comments on commit ff4e46f

Please sign in to comment.