diff --git a/src/ast_types.rs b/src/ast_types.rs index 34466fec..6bef5d19 100644 --- a/src/ast_types.rs +++ b/src/ast_types.rs @@ -1,7 +1,7 @@ use anyhow::bail; use itertools::Itertools; use std::{fmt::Display, str::FromStr}; -use triton_vm::{triton_asm, triton_instr}; +use triton_vm::{instruction::LabelledInstruction, triton_asm, triton_instr}; use crate::{ast::FnSignature, libraries::LibraryFunction}; @@ -463,6 +463,23 @@ pub enum AbstractArgument { ValueArgument(AbstractValueArg), } +impl Display for AbstractArgument { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + AbstractArgument::ValueArgument(val_arg) => { + format!("{}: {}", val_arg.name, val_arg.data_type) + } + AbstractArgument::FunctionArgument(fun_arg) => { + format!("fn ({}): {}", fun_arg.abstract_name, fun_arg.function_type) + } + } + ) + } +} + #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct AbstractFunctionArg { pub abstract_name: String, @@ -589,6 +606,78 @@ impl EnumType { .concat() } + /// Return the code to put enum-variant data fields on top of stack. + /// Does not consume the enum_expr pointer. + /// BEFORE: _ *enum_expr + /// AFTER: _ *enum_expr [*variant-data-fields] + pub(crate) fn get_variant_data_fields_in_memory( + &self, + variant_name: &str, + ) -> Vec<(Vec, DataType)> { + // TODO: Can we get this code to consume *enum_expr instead? + + // Example: Foo::Bar(Vec) + // Memory layout will be: + // [discriminant, field_size, [u32_list]] + // In that case we want to return code to get *u32_list. + + // You can assume that the stack has a pointer to `discriminant` on + // top. So we want to return the code + // `push 1 add push 1 add` + let data_types = self.variant_data_type(variant_name); + + // Skip discriminant + let mut acc_code = vec![triton_instr!(push 1), triton_instr!(add)]; + let mut ret: Vec> = vec![]; + + // Invariant: _ *enum_expr [*preceding_fields] + for (field_count, dtype) in data_types.clone().into_iter().enumerate() { + match dtype.bfield_codec_length() { + Some(size) => { + // field size is known statically, does not need to be read + // from memory + // stack: _ *enum_expr [*preceding_fields] + ret.push(triton_asm!( + // _ *enum_expr [*preceding_fields] + + dup {field_count} + // _ *enum_expr [*previous_fields] *enum_expr + + {&acc_code} + // _ *enum_expr [*previous_fields] *current_field + )); + acc_code.append(&mut triton_asm!(push {size} add)); + } + None => { + // Current field size must be read from memory + // stack: _ *enum_expr [*preceding_fields] + ret.push(triton_asm!( + // _ *enum_expr [*preceding_fields] + + dup {field_count} + // _ *enum_expr [*previous_fields] *enum_expr + + {&acc_code} + // _ *enum_expr [*previous_fields] *current_field_size + + push 1 + add + // _ *enum_expr [*previous_fields] *current_field + )); + + acc_code.append(&mut triton_asm!( + read_mem + add + push 1 + add + )); + } + } + } + + ret.into_iter().zip_eq(data_types).collect_vec() + } + /// Return the constructor that is called by an expression evaluating to an /// enum type. E.g.: `Foo::A(100u32);` pub(crate) fn variant_constructor(&self, variant_name: &str) -> LibraryFunction { @@ -946,6 +1035,12 @@ pub struct FunctionType { pub output: DataType, } +impl Display for FunctionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} -> {}", self.input_argument, self.output) + } +} + impl From<&FnSignature> for DataType { fn from(value: &FnSignature) -> Self { let mut input_args = vec![]; diff --git a/src/graft.rs b/src/graft.rs index 241777da..df37bfbe 100644 --- a/src/graft.rs +++ b/src/graft.rs @@ -1346,6 +1346,12 @@ impl<'a> Graft<'a> { name: ident.ident.to_string(), }); } + syn::Pat::Wild(_) => { + assert!( + pat.elems.len().is_one(), + "For now, wildcard binding must be only binding" + ); + } other => { panic!("unsupported binding for match-arm: {other:?}") } diff --git a/src/tasm_code_generator.rs b/src/tasm_code_generator.rs index 28e57a2a..1703f91d 100644 --- a/src/tasm_code_generator.rs +++ b/src/tasm_code_generator.rs @@ -57,6 +57,42 @@ impl ValueLocation { } } } + + /// Given a value location, return the code to put the pointer on top of the stack. + fn get_value_pointer( + &self, + state: &mut CompilerState, + field_or_element_type: &ast_types::DataType, + identifier: &ast::Identifier, + ) -> Vec { + match self { + ValueLocation::OpStack(depth) => { + if !self.is_accessible(field_or_element_type) { + let binding_name = identifier.binding_name(); + let value_ident_of_binding = state + .function_state + .var_addr + .get(&binding_name) + .unwrap_or_else(|| { + panic!("Could not locate value identifier for binding {binding_name}") + }) + .to_owned(); + state.mark_as_spilled(&value_ident_of_binding); + triton_asm!(push 0 assert) + } else { + // Type of `l` in `l[]` is known to be list. So we know stack size = 1 + triton_asm!(dup { depth }) + } + } + ValueLocation::StaticMemoryAddress(pointer) => { + // Type of `l` in `l[]` is known to be list. So we know stack size = 1 + // Read the list pointer from memory, then clear the stack, leaving only + // the list pointer on the stack. + triton_asm!(push {pointer} read_mem swap 1 pop) + } + ValueLocation::DynamicMemoryAddress(code) => code.to_owned(), + } + } } /// State that is preserved across the compilation of functions @@ -200,44 +236,6 @@ impl<'a> CompilerState<'a> { &mut self, identifier: &ast::Identifier, ) -> ValueLocation { - /// Return the code to put the pointer on top of the stack - fn get_value_pointer( - state: &mut CompilerState, - lhs_location: &ValueLocation, - field_or_element_type: &ast_types::DataType, - identifier: &ast::Identifier, - ) -> Vec { - match lhs_location { - ValueLocation::OpStack(depth) => { - if !lhs_location.is_accessible(field_or_element_type) { - let binding_name = identifier.binding_name(); - let value_ident_of_binding = state - .function_state - .var_addr - .get(&binding_name) - .unwrap_or_else(|| { - panic!( - "Could not locate value identifier for binding {binding_name}" - ) - }) - .to_owned(); - state.mark_as_spilled(&value_ident_of_binding); - triton_asm!(push 0 assert) - } else { - // Type of `l` in `l[]` is known to be list. So we know stack size = 1 - triton_asm!(dup { depth }) - } - } - ValueLocation::StaticMemoryAddress(pointer) => { - // Type of `l` in `l[]` is known to be list. So we know stack size = 1 - // Read the list pointer from memory, then clear the stack, leaving only - // the list pointer on the stack. - triton_asm!(push {pointer} read_mem swap 1 pop) - } - ValueLocation::DynamicMemoryAddress(code) => code.to_owned(), - } - } - match identifier { ast::Identifier::String(_, _) => { let var_name = identifier.binding_name(); @@ -276,7 +274,7 @@ impl<'a> CompilerState<'a> { let element_type = element_type.get_type(); let lhs_location = state.locate_identifier(ident); let ident_addr_code = - get_value_pointer(state, &lhs_location, &element_type, ident); + lhs_location.get_value_pointer(state, &element_type, ident); // stack: _ *sequence state.new_value_identifier("list_expression", &ident.get_type()); @@ -431,7 +429,7 @@ impl<'a> CompilerState<'a> { match lhs.get_type() { ast_types::DataType::Boxed(inner_type) => { let get_pointer = - get_value_pointer(self, &lhs_location, &known_type.get_type(), lhs); + lhs_location.get_value_pointer(self, &known_type.get_type(), lhs); match *inner_type { ast_types::DataType::Struct(inner_struct) => { let get_field_pointer_from_struct_pointer = @@ -612,7 +610,10 @@ impl<'a> CompilerState<'a> { } } - /// Return a new, guaranteed unique label that can be used anywhere in the code + /// Return a new, guaranteed unique label that can be used in the code. + /// Do *not* use this to declare new vstack elements. Instead use + /// `new_value_identifier` for that, as that function checks if the + /// newly bound value needs to be spilled. pub fn unique_label( &mut self, prefix: &str, @@ -1031,32 +1032,32 @@ pub(crate) fn compile_function( state.compose_code_for_outer_function(compiled_function, &function.signature) } +/// Local function to handle a block statement. Returns code to execute +/// all statements, and to clean up the stack after the exit of a block. +fn compile_block_stmt( + block: &ast::BlockStmt, + state: &mut CompilerState, +) -> Vec { + let vstack_init = state.function_state.vstack.clone(); + let var_addr_init = state.function_state.var_addr.clone(); + let block_body_code = block + .stmts + .iter() + .map(|stmt| compile_stmt(stmt, state)) + .collect_vec() + .concat(); + + let restore_stack_code = state.restore_stack_code(&vstack_init, &var_addr_init); + + [block_body_code, restore_stack_code].concat() +} + /// Produce the code and handle the `vstack` for a statement. `env_fn_signature` is the /// function signature in which the statement is enclosed. fn compile_stmt( stmt: &ast::Stmt, state: &mut CompilerState, ) -> Vec { - /// Local function to handle a block statement. Returns code to execute - /// all statements, and to clean up the stack after the exit of a block. - fn compile_block_stmt( - block: &ast::BlockStmt, - state: &mut CompilerState, - ) -> Vec { - let vstack_init = state.function_state.vstack.clone(); - let var_addr_init = state.function_state.var_addr.clone(); - let block_body_code = block - .stmts - .iter() - .map(|stmt| compile_stmt(stmt, state)) - .collect_vec() - .concat(); - - let restore_stack_code = state.restore_stack_code(&vstack_init, &var_addr_init); - - [block_body_code, restore_stack_code].concat() - } - match stmt { ast::Stmt::Let(ast::LetStmt { var_name, @@ -1314,151 +1315,348 @@ fn compile_stmt( vec![] } - ast::Stmt::Match(ast::MatchStmt { - arms, - match_expression, - }) => { - let vstack_init = state.function_state.vstack.clone(); - let var_addr_init = state.function_state.var_addr.clone(); + ast::Stmt::Match(match_stmt) => match match_stmt.match_expression.get_type() { + ast_types::DataType::Enum(_) => compile_match_stmt_stack_expr(match_stmt, state), + ast_types::DataType::Boxed(inner) => match *inner.to_owned() { + ast_types::DataType::Enum(_) => compile_match_stmt_boxed_expr(match_stmt, state), + _ => unreachable!(), + }, + _ => unreachable!(), + }, + } +} - // Evaluate match expression - let (match_expr_id, match_expr_evaluation) = - compile_expr(match_expression, "match-expr", state); - assert!( +/// Compile a match-statement where the matched-against value lives in memory +fn compile_match_stmt_boxed_expr( + ast::MatchStmt { + arms, + match_expression, + }: &ast::MatchStmt, + state: &mut CompilerState, +) -> Vec { + let vstack_init = state.function_state.vstack.clone(); + let var_addr_init = state.function_state.var_addr.clone(); + + // Evaluate match expression + let (match_expr_id, match_expr_evaluation) = + compile_expr(match_expression, "match-expr", state); + assert!( !state.function_state.spill_required.contains(&match_expr_id), "Cannot handle memory-spill of evaluated match expressions. But {match_expr_id} required memory spilling" ); - let mut match_code = triton_asm!({ &match_expr_evaluation }); - let contains_wildcard = arms - .iter() - .any(|x| matches!(x.match_condition, ast::MatchCondition::CatchAll)); + let mut match_code = triton_asm!({ &match_expr_evaluation }); + let contains_wildcard = arms + .iter() + .any(|x| matches!(x.match_condition, ast::MatchCondition::CatchAll)); + + let match_expr_discriminant = if contains_wildcard { + // Indicate that no arm body has been executed yet. For wildcard arm-conditions. + match_code.push(triton_instr!(push 1)); + triton_asm!( + // *match_expr no_arm_taken_indicator + swap 1 + // no_arm_taken_indicator *match_expr + read_mem + // no_arm_taken_indicator *match_expr discriminant + swap 2 + // discriminant *match_expr no_arm_taken_indicator + swap 1 + // discriminant no_arm_taken_indicator *match_expr - if contains_wildcard { - // Indicate that no arm body has been executed yet. For wildcard arm-conditions. - match_code.push(triton_instr!(push 1)); - } + swap 2 + // *match_expr no_arm_taken_indicator discriminant + ) + } else { + // *match_expr + triton_asm!(read_mem) + // *match_expr discriminant + }; - let match_expression_enum_type = match_expression.get_type().as_enum_type(); - - let outer_vstack = state.function_state.vstack.clone(); - let outer_bindings = state.function_state.var_addr.clone(); - for (arm_counter, arm) in arms.iter().enumerate() { - // At start of each loop-iternation, stack is: - // stack: _ [expression_variant_data] expression_variant_discriminant - - let arm_subroutine_label = format!("{match_expr_id}_body_{arm_counter}"); - - match &arm.match_condition { - ast::MatchCondition::EnumVariant(enum_variant) => { - // We know that variant discriminant is on top - let arm_variant_discriminant = match_expression_enum_type - .variant_discriminant(&enum_variant.variant_name); - match_code.append(&mut triton_asm!( - dup {contains_wildcard as u32} - // _ match_expr match_expr_discriminant - push {arm_variant_discriminant} - // _ match_expr match_expr_discriminant needle_discriminant - - eq - skiz - call {arm_subroutine_label} - )); + // Get enum_type + let match_expression_enum_type = match_expression.get_type().unbox().as_enum_type(); + let outer_vstack = state.function_state.vstack.clone(); + let outer_bindings = state.function_state.var_addr.clone(); - let remove_old_any_arm_taken_indicator = if contains_wildcard { - triton_asm!(pop) - } else { - triton_asm!() - }; - let set_new_no_arm_taken_indicator = if contains_wildcard { - triton_asm!(push 0) - } else { - triton_asm!() - }; - - // Split compiler's view of evaluated expression from - // _ [enum_value] - // into - // _ [enum_data] [padding] discriminant - let new_ids = state.split_value( - &match_expr_id, - match_expression_enum_type - .decompose_variant(&enum_variant.variant_name), - ); - // Insert bindings from pattern-match into stack view for arm-body - enum_variant - .data_bindings - .iter() - .zip(new_ids.iter()) - .for_each(|(binding, new_id)| { - state - .function_state - .var_addr - .insert(binding.name.to_owned(), new_id.clone()); - }); - - let body_code = compile_block_stmt(&arm.body, state); - - // This arm-body changes the `arm_taken` bool but otherwise leaves the stack unchanged - let subroutine_code = triton_asm!( - {arm_subroutine_label}: - {&remove_old_any_arm_taken_indicator} - // stack: _ [expression_variant_data] [padding] expression_variant_discriminant - - {&body_code} - - {&set_new_no_arm_taken_indicator} - return - ); + for (arm_counter, arm) in arms.iter().enumerate() { + // At start of each loop-iternation, stack is: + // stack: _ *match_expression + + let arm_subroutine_label = format!("{match_expr_id}_body_{arm_counter}"); + + match &arm.match_condition { + ast::MatchCondition::EnumVariant(enum_variant) => { + let arm_variant_discriminant = + match_expression_enum_type.variant_discriminant(&enum_variant.variant_name); + + match_code.append(&mut triton_asm!( + // _ match_expr + + {&match_expr_discriminant} + // _ *match_expr match_expr_discriminant + push {arm_variant_discriminant} + // _ match_expr match_expr_discriminant needle_discriminant + + eq + // _ match_expr (match_expr_discriminant == needle_discriminant) + + skiz + call {arm_subroutine_label} + // _ match_expr + )); + + let remove_old_any_arm_taken_indicator = if contains_wildcard { + triton_asm!(pop) + } else { + triton_asm!() + }; + let set_new_no_arm_taken_indicator = if contains_wildcard { + triton_asm!(push 0) + } else { + triton_asm!() + }; + + // Insert bindings from pattern-match into stack view for arm-body + // We need to insert pointers to each data-element contained in the + // variant. So we need a function that returns those addresses + // relative to the value. + let rel_data_position_and_types = match_expression_enum_type + .get_variant_data_fields_in_memory(&enum_variant.variant_name); + + let mut bindings_code = vec![]; + enum_variant + .data_bindings + .iter() + .zip(rel_data_position_and_types.iter()) + .for_each(|(binding, (get_field_pointer, dtype))| { + let dtype = ast_types::DataType::Boxed(Box::new(dtype.to_owned())); + let (new_binding_id, spill_addr) = + state.new_value_identifier("in_memory_split_value", &dtype); + assert!(spill_addr.is_none(), "Cannot handle memory-spilling in match-arm bindings yet. Required spilling of binding '{}'", binding.name); + // push relative address, add to absolute address + // to get *new* absolute address. + // Then insert this boxed type into `var_addr` state - .function_state - .subroutines - .push(subroutine_code.try_into().unwrap()); - } - ast::MatchCondition::CatchAll => { - // CatchAll (`_`) is guaranteed to be the last arm. So we only have to check if any - // previous arm was taken - match_code.append(&mut triton_asm!( - skiz - call {arm_subroutine_label} - push 0 // push 0 to make stack-cleanup code-path independent + .function_state + .var_addr + .insert(binding.name.to_owned(), new_binding_id); + bindings_code.append(&mut triton_asm!( + {&get_field_pointer} )); + }); - let body_code = compile_block_stmt(&arm.body, state); - let subroutine_code = triton_asm!( - {arm_subroutine_label}: - {&body_code} - return - ); + let body_code = compile_block_stmt(&arm.body, state); + + let pop_local_bindings = vec![triton_instr!(pop); enum_variant.data_bindings.len()]; + let subroutine_code = triton_asm!( + {arm_subroutine_label}: + {&remove_old_any_arm_taken_indicator} + {&bindings_code} + {&body_code} + + // We can just pop local binding from top of stack, since a statement cannot return anything + {&pop_local_bindings} + {&set_new_no_arm_taken_indicator} + return + ); + + state + .function_state + .subroutines + .push(subroutine_code.try_into().unwrap()); + } + ast::MatchCondition::CatchAll => { + // CatchAll (`_`) is guaranteed to be the last arm. So we only have to check if any + // previous arm was taken + match_code.append(&mut triton_asm!( + skiz + call {arm_subroutine_label} + push 0 // push 0 to make stack-cleanup code-path independent + )); + + let body_code = compile_block_stmt(&arm.body, state); + let subroutine_code = triton_asm!( + {arm_subroutine_label}: + {&body_code} + return + ); + state + .function_state + .subroutines + .push(subroutine_code.try_into().unwrap()); + } + } + + // Restore stack view and bindings view for next loop-iteration + state + .function_state + .restore_stack_and_bindings(&outer_vstack, &outer_bindings); + } + + // Cleanup stack by removing evaluated expresison and `any_arm_taken_bool` indicator + if contains_wildcard { + match_code.push(triton_instr!(pop)); + } + + // Remove match-expression from stack + let restore_stack_code = state.restore_stack_code(&vstack_init, &var_addr_init); + + triton_asm!( + {&match_code} + {&restore_stack_code} + ) +} + +/// Compile a match-statement where the matched-against value lives on the stack +fn compile_match_stmt_stack_expr( + ast::MatchStmt { + arms, + match_expression, + }: &ast::MatchStmt, + state: &mut CompilerState, +) -> Vec { + let vstack_init = state.function_state.vstack.clone(); + let var_addr_init = state.function_state.var_addr.clone(); + + // Evaluate match expression + let (match_expr_id, match_expr_evaluation) = + compile_expr(match_expression, "match-expr", state); + assert!( + !state.function_state.spill_required.contains(&match_expr_id), + "Cannot handle memory-spill of evaluated match expressions. But {match_expr_id} required memory spilling" + ); + + let mut match_code = triton_asm!({ &match_expr_evaluation }); + let contains_wildcard = arms + .iter() + .any(|x| matches!(x.match_condition, ast::MatchCondition::CatchAll)); + + if contains_wildcard { + // Indicate that no arm body has been executed yet. For wildcard arm-conditions. + match_code.push(triton_instr!(push 1)); + } + + // Match-expression is either an enum type, or a boxed enum type. + let match_expression_enum_type = match_expression.get_type().unbox().as_enum_type(); + + let outer_vstack = state.function_state.vstack.clone(); + let outer_bindings = state.function_state.var_addr.clone(); + let match_expr_discriminant = triton_asm!(dup {contains_wildcard as u32}); + for (arm_counter, arm) in arms.iter().enumerate() { + // At start of each loop-iternation, stack is: + // stack: _ [expression_variant_data] expression_variant_discriminant + + let arm_subroutine_label = format!("{match_expr_id}_body_{arm_counter}"); + + match &arm.match_condition { + ast::MatchCondition::EnumVariant(enum_variant) => { + // We know that variant discriminant is on top + let arm_variant_discriminant = + match_expression_enum_type.variant_discriminant(&enum_variant.variant_name); + match_code.append(&mut triton_asm!( + // dup {contains_wildcard as u32} + {&match_expr_discriminant} + // _ match_expr match_expr_discriminant + push {arm_variant_discriminant} + // _ match_expr match_expr_discriminant needle_discriminant + + eq + skiz + call {arm_subroutine_label} + )); + + let remove_old_any_arm_taken_indicator = if contains_wildcard { + triton_asm!(pop) + } else { + triton_asm!() + }; + let set_new_no_arm_taken_indicator = if contains_wildcard { + triton_asm!(push 0) + } else { + triton_asm!() + }; + + // Split compiler's view of evaluated expression from + // _ [enum_value] + // into + // _ [enum_data] [padding] discriminant + let new_ids = state.split_value( + &match_expr_id, + match_expression_enum_type.decompose_variant(&enum_variant.variant_name), + ); + // Insert bindings from pattern-match into stack view for arm-body + enum_variant + .data_bindings + .iter() + .zip(new_ids.iter()) + .for_each(|(binding, new_id)| { state .function_state - .subroutines - .push(subroutine_code.try_into().unwrap()); - // stack: _ [expression_variant_data] expression_variant_discriminant - } - } + .var_addr + .insert(binding.name.to_owned(), new_id.clone()); + }); + + let body_code = compile_block_stmt(&arm.body, state); + + // This arm-body changes the `arm_taken` bool but otherwise leaves the stack unchanged + let subroutine_code = triton_asm!( + {arm_subroutine_label}: + {&remove_old_any_arm_taken_indicator} + // stack: _ [expression_variant_data] [padding] expression_variant_discriminant + + {&body_code} + + {&set_new_no_arm_taken_indicator} + return + ); - // Restore stack view and bindings view for next loop-iteration state .function_state - .restore_stack_and_bindings(&outer_vstack, &outer_bindings); + .subroutines + .push(subroutine_code.try_into().unwrap()); } - - // Cleanup stack by removing evaluated expresison and `any_arm_taken_bool` indicator - if contains_wildcard { - match_code.push(triton_instr!(pop)); + ast::MatchCondition::CatchAll => { + // CatchAll (`_`) is guaranteed to be the last arm. So we only have to check if any + // previous arm was taken + match_code.append(&mut triton_asm!( + skiz + call {arm_subroutine_label} + push 0 // push 0 to make stack-cleanup code-path independent + )); + + let body_code = compile_block_stmt(&arm.body, state); + let subroutine_code = triton_asm!( + {arm_subroutine_label}: + {&body_code} + return + ); + state + .function_state + .subroutines + .push(subroutine_code.try_into().unwrap()); + // stack: _ [expression_variant_data] expression_variant_discriminant } + } - // Remove match-expression from stack - let restore_stack_code = state.restore_stack_code(&vstack_init, &var_addr_init); + // Restore stack view and bindings view for next loop-iteration + state + .function_state + .restore_stack_and_bindings(&outer_vstack, &outer_bindings); + } - triton_asm!( - {&match_code} - {&restore_stack_code} - ) - } + // Cleanup stack by removing evaluated expresison and `any_arm_taken_bool` indicator + if contains_wildcard { + match_code.push(triton_instr!(pop)); } + + // Remove match-expression from stack + let restore_stack_code = state.restore_stack_code(&vstack_init, &var_addr_init); + + triton_asm!( + {&match_code} + {&restore_stack_code} + ) } fn compile_fn_call( diff --git a/src/tests_and_benchmarks/ozk/programs/enums.rs b/src/tests_and_benchmarks/ozk/programs/enums.rs index 2ae0e754..0406b672 100644 --- a/src/tests_and_benchmarks/ozk/programs/enums.rs +++ b/src/tests_and_benchmarks/ozk/programs/enums.rs @@ -1,5 +1,5 @@ +mod boxed_proof_item_simple; mod custom_struct_in_data; -mod proof_item; mod rust_by_example_enums; mod two_variants_no_data; mod two_variants_one_data; diff --git a/src/tests_and_benchmarks/ozk/programs/enums/boxed_proof_item_simple.rs b/src/tests_and_benchmarks/ozk/programs/enums/boxed_proof_item_simple.rs new file mode 100644 index 00000000..7983d3e0 --- /dev/null +++ b/src/tests_and_benchmarks/ozk/programs/enums/boxed_proof_item_simple.rs @@ -0,0 +1,215 @@ +#![allow(clippy::assertions_on_constants)] + +// Allows the use of input/output on the native architecture +use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; +use itertools::Itertools; +use num::Zero; +use triton_vm::{BFieldElement, Digest}; +use twenty_first::shared_math::{bfield_codec::BFieldCodec, x_field_element::XFieldElement}; + +#[derive(BFieldCodec)] +pub struct FriResponse { + /// The authentication structure of the Merkle tree. + pub auth_structure: Vec, + /// The values of the opened leaves of the Merkle tree. + pub revealed_leaves: Vec, +} + +#[derive(BFieldCodec)] +pub enum ProofItem { + AuthenticationStructure(Vec), + MasterBaseTableRows(Vec>), + MasterExtTableRows(Vec>), + OutOfDomainBaseRow(Vec), + OutOfDomainExtRow(Vec), + OutOfDomainQuotientSegments([XFieldElement; 4]), + MerkleRoot(Digest), + Log2PaddedHeight(u32), + QuotientSegmentsElements(Vec<[XFieldElement; 4]>), + FriCodeword(Vec), + FriResponse(FriResponse), +} + +impl ProofItem { + fn discriminant(&self) -> BFieldElement { + #[allow(unused_assignments)] + let mut discriminant: BFieldElement = BFieldElement::zero(); + match self { + ProofItem::AuthenticationStructure(_) => { + discriminant = BFieldElement::new(0); + } + ProofItem::MasterBaseTableRows(_) => { + discriminant = BFieldElement::new(1); + } + ProofItem::MasterExtTableRows(_) => { + discriminant = BFieldElement::new(2); + } + ProofItem::OutOfDomainBaseRow(_) => { + discriminant = BFieldElement::new(3); + } + ProofItem::OutOfDomainExtRow(_) => { + discriminant = BFieldElement::new(4); + } + ProofItem::OutOfDomainQuotientSegments(_) => { + discriminant = BFieldElement::new(1); + } + ProofItem::MerkleRoot(_) => { + discriminant = BFieldElement::new(5); + } + ProofItem::Log2PaddedHeight(_) => { + discriminant = BFieldElement::new(6); + } + ProofItem::QuotientSegmentsElements(_) => { + discriminant = BFieldElement::new(7); + } + ProofItem::FriCodeword(_) => { + discriminant = BFieldElement::new(8); + } + ProofItem::FriResponse(_) => { + discriminant = BFieldElement::new(9); + } + }; + + return discriminant; + } + + fn get_merkle_root(&self) -> Digest { + let mut root: Digest = Digest::default(); + match self { + ProofItem::MerkleRoot(digest) => { + root = *digest; + } + _ => { + assert!(false); + } + }; + + return root; + } +} + +fn main() { + let boxed_proof_item: Box = + ProofItem::decode(&tasm::load_from_memory(BFieldElement::new(84))).unwrap(); + tasm::tasm_io_write_to_stdout___bfe(boxed_proof_item.discriminant()); + tasm::tasm_io_write_to_stdout___digest(boxed_proof_item.get_merkle_root()); + + // Crash if not `MerkleRoot` + match *boxed_proof_item { + ProofItem::AuthenticationStructure(_) => { + assert!(false); + } + ProofItem::MasterBaseTableRows(_) => { + assert!(false); + } + ProofItem::MasterExtTableRows(_) => { + assert!(false); + } + ProofItem::OutOfDomainBaseRow(_) => { + assert!(false); + } + ProofItem::OutOfDomainExtRow(_) => { + assert!(false); + } + ProofItem::OutOfDomainQuotientSegments(_) => { + assert!(false); + } + ProofItem::MerkleRoot(a) => { + let b: BFieldElement = BFieldElement::new(10009 << 32); + tasm::tasm_io_write_to_stdout___digest(a); + tasm::tasm_io_write_to_stdout___bfe(b); + assert!(true); + } + ProofItem::Log2PaddedHeight(_) => { + assert!(false); + } + ProofItem::QuotientSegmentsElements(_) => { + assert!(false); + } + ProofItem::FriCodeword(_) => { + assert!(false); + } + ProofItem::FriResponse(_) => { + assert!(false); + } + }; + + // With wildcard + let mut d: BFieldElement = BFieldElement::new(100u64); + match *boxed_proof_item { + ProofItem::AuthenticationStructure(_) => { + assert!(false); + } + _ => { + let c: BFieldElement = BFieldElement::new(555u64 << 32); + assert!(true); + tasm::tasm_io_write_to_stdout___bfe(c); + d = BFieldElement::new(200u64); + } + }; + + tasm::tasm_io_write_to_stdout___bfe(d); + + return; +} + +mod tests { + use super::*; + use itertools::Itertools; + use rand::random; + use std::collections::HashMap; + + use crate::tests_and_benchmarks::ozk::{ozk_parsing, rust_shadows}; + use crate::tests_and_benchmarks::test_helpers::shared_test::{ + execute_compiled_with_stack_memory_and_ins_for_test, init_memory_from, + }; + + #[test] + fn proof_item_enum_test() { + let a_0: Digest = random(); + let proof_item = ProofItem::MerkleRoot(a_0); + let non_determinism = init_memory_from(&proof_item, BFieldElement::new(84)); + let expected_output = [ + vec![BFieldElement::new(5)], + a_0.encode(), + a_0.encode(), + vec![BFieldElement::new(10009 << 32)], + vec![BFieldElement::new(555 << 32)], + vec![BFieldElement::new(200)], + // vec![BFieldElement::new(100)], + ] + .concat(); + let stdin = vec![]; + + // Run test on host machine + let native_output = + rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone()); + assert_eq!(native_output, expected_output); + + // Run test on Triton-VM + let test_program = ozk_parsing::compile_for_test( + "enums", + "boxed_proof_item_simple", + "main", + crate::ast_types::ListType::Unsafe, + ); + println!("executing:\n{}", test_program.iter().join("\n")); + let vm_output = execute_compiled_with_stack_memory_and_ins_for_test( + &test_program, + vec![], + &mut HashMap::default(), + stdin, + non_determinism, + 0, + ) + .unwrap(); + if expected_output != vm_output.output { + panic!( + "expected_output:\n {}, got:\n{}. Code was:\n{}", + expected_output.iter().join(", "), + vm_output.output.iter().join(", "), + test_program.iter().join("\n") + ); + } + } +} diff --git a/src/tests_and_benchmarks/ozk/programs/enums/proof_item.rs b/src/tests_and_benchmarks/ozk/programs/enums/proof_item.rs deleted file mode 100644 index 869f72ad..00000000 --- a/src/tests_and_benchmarks/ozk/programs/enums/proof_item.rs +++ /dev/null @@ -1,89 +0,0 @@ -// Allows the use of input/output on the native architecture -use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; -use itertools::Itertools; -use triton_vm::{BFieldElement, Digest}; -use twenty_first::shared_math::{bfield_codec::BFieldCodec, x_field_element::XFieldElement}; - -#[derive(BFieldCodec)] -pub struct FriResponse { - /// The authentication structure of the Merkle tree. - pub auth_structure: Vec, - /// The values of the opened leaves of the Merkle tree. - pub revealed_leaves: Vec, -} - -#[derive(BFieldCodec)] -pub enum ProofItem { - AuthenticationStructure(Vec), - MasterBaseTableRows(Vec>), - MasterExtTableRows(Vec>), - OutOfDomainBaseRow(Vec), - OutOfDomainExtRow(Vec), - OutOfDomainQuotientSegments([XFieldElement; 4]), - MerkleRoot(Digest), - Log2PaddedHeight(u32), - QuotientSegmentsElements(Vec<[XFieldElement; 4]>), - FriCodeword(Vec), - FriResponse(FriResponse), -} - -fn main() { - let digest_from_standard_in: Digest = tasm::tasm_io_read_stdin___digest(); - let _a: ProofItem = ProofItem::MerkleRoot(digest_from_standard_in); - // tasm::tasm_io_write_to_stdout___digest(); - return; -} - -mod tests { - use super::*; - use itertools::Itertools; - use rand::random; - use std::collections::HashMap; - use triton_vm::NonDeterminism; - - use crate::tests_and_benchmarks::ozk::{ozk_parsing, rust_shadows}; - use crate::tests_and_benchmarks::test_helpers::shared_test::execute_compiled_with_stack_memory_and_ins_for_test; - - #[test] - fn proof_item_enum_test() { - // let non_determinism = init_memory_from(&test_struct, BFieldElement::new(300)); - let non_determinism = NonDeterminism::default(); - let expected_output = vec![]; - let a_0: Digest = random(); - let mut a0_encoded_reverse = a_0.encode(); - a0_encoded_reverse.reverse(); - let stdin = [a0_encoded_reverse].concat(); - - // Run test on host machine - let native_output = - rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone()); - assert_eq!(native_output, expected_output); - - // Run test on Triton-VM - let test_program = ozk_parsing::compile_for_test( - "enums", - "proof_item", - "main", - crate::ast_types::ListType::Unsafe, - ); - println!("executing:\n{}", test_program.iter().join("\n")); - let vm_output = execute_compiled_with_stack_memory_and_ins_for_test( - &test_program, - vec![], - &mut HashMap::default(), - stdin, - non_determinism, - 0, - ) - .unwrap(); - // assert_eq!(expected_output, vm_output.output); - if expected_output != vm_output.output { - panic!( - "expected_output:\n {}, got:\n{}. Code was:\n{}", - expected_output.iter().join(", "), - vm_output.output.iter().join(", "), - test_program.iter().join("\n") - ); - } - } -} diff --git a/src/type_checker.rs b/src/type_checker.rs index c78551bf..c9a80011 100644 --- a/src/type_checker.rs +++ b/src/type_checker.rs @@ -530,10 +530,13 @@ fn annotate_stmt( // Verify that match-expression returns an enum-type let match_expression_type = derive_annotate_expr_type(match_expression, None, state, env_fn_signature); - let enum_type = if let ast_types::DataType::Enum(enum_type) = match_expression_type { - enum_type - } else { - panic!("`match` statements are only supported on enum types. For now."); + let (enum_type, is_boxed) = match match_expression_type { + ast_types::DataType::Enum(enum_type) => (enum_type, false), + ast_types::DataType::Boxed(inner) => match *inner.to_owned() { + ast_types::DataType::Enum(enum_type) => (enum_type, true), + other => panic!("`match` statements are only supported on enum types. For now. Got {other}", ) + }, + _ => panic!("`match` statements are only supported on enum types. For now. Got {match_expression_type}", ) }; let mut variants_encountered: HashSet = HashSet::default(); @@ -547,6 +550,8 @@ fn annotate_stmt( "When using wildcard in match statement, wildcard must be used in last match arm. Match expression was for type {}", enum_type.name ); contains_wildcard_arm = true; + + annotate_block_stmt(&mut arm.body, env_fn_signature, state); } ast::MatchCondition::EnumVariant(ast::EnumVariantSelector { enum_name, @@ -565,14 +570,20 @@ fn annotate_stmt( enum_name, "Match conditions on type {} must all be of same type. Got bad type: {enum_name}", enum_type.name); let variant_data_tuple = enum_type.variant_data_type(variant_name); - assert_eq!(variant_data_tuple.element_count(), data_bindings.len(), "Number of bindings must match number of elements in variant data tuple"); + assert!(data_bindings.is_empty() || variant_data_tuple.element_count() == data_bindings.len(), "Number of bindings must match number of elements in variant data tuple"); assert!( has_unique_elements(data_bindings.iter().map(|x| &x.name)), "Name repetition in pattern matching not allowed" ); data_bindings.iter().enumerate().for_each(|(i, x)| { - let new_binding_type = variant_data_tuple.fields[i].to_owned(); + let new_binding_type = if is_boxed { + ast_types::DataType::Boxed(Box::new( + variant_data_tuple.fields[i].to_owned(), + )) + } else { + variant_data_tuple.fields[i].to_owned() + }; state.vtable.insert( x.name.to_owned(), @@ -880,8 +891,19 @@ fn get_method_signature( } } + let declared_method_names = state + .declared_methods + .iter() + .map(|x| { + format!( + "{}: ({})", + &x.signature.name, + x.signature.args.iter().join(",") + ) + }) + .join("\n"); panic!( - "Method call in {} Don't know what type of value '{name}' returns! Receiver type was: {original_receiver_type:?}", env_fn_signature.name + "Method call in {} Don't know what type of value '{name}' returns! Receiver type was: {original_receiver_type:?}\n\nDeclared methods are:\n{}", env_fn_signature.name, declared_method_names ) }