From 35e6e56cc56c1d53152c8147a2af0fade2e50e0a Mon Sep 17 00:00:00 2001 From: Theodoros Kasampalis Date: Mon, 9 Dec 2024 13:17:57 -0600 Subject: [PATCH] adding a tail_call_info event that tells us whether an apply_rule function exited with a tail call --- bindings/python/ast.cpp | 10 ++++ docs/proof-trace.md | 3 + include/kllvm/binary/ProofTraceParser.h | 50 ++++++++++++++++ include/kllvm/codegen/ProofEvent.h | 15 +++++ include/runtime/header.h | 2 + include/runtime/proof_trace_writer.h | 7 +++ lib/binary/ProofTraceParser.cpp | 8 +++ lib/codegen/CreateTerm.cpp | 11 ++++ lib/codegen/Decision.cpp | 1 + lib/codegen/ProofEvent.cpp | 74 ++++++++++++++++++++++++ runtime/util/ConfigurationSerializer.cpp | 6 ++ 11 files changed, 187 insertions(+) diff --git a/bindings/python/ast.cpp b/bindings/python/ast.cpp index f53d8a867..2ddf3e6db 100644 --- a/bindings/python/ast.cpp +++ b/bindings/python/ast.cpp @@ -426,6 +426,16 @@ void bind_proof_trace(py::module_ &m) { "function_name", &llvm_pattern_matching_failure_event::get_function_name); + py::class_< + llvm_tail_call_info_event, + std::shared_ptr>( + proof_trace, "llvm_tail_call_info_event", step_event) + .def_property_readonly( + "callern_name", + &llvm_tail_call_info_event::get_caller_name) + .def_property_readonly( + "is_tail", &llvm_tail_call_info_event::is_tail); + py::class_>( proof_trace, "llvm_function_event", step_event) .def_property_readonly("name", &llvm_function_event::get_name) diff --git a/docs/proof-trace.md b/docs/proof-trace.md index 1ef3c1e19..f67a29d21 100644 --- a/docs/proof-trace.md +++ b/docs/proof-trace.md @@ -38,6 +38,7 @@ event ::= hook | side_cond_exit | config | pattern_matching_failure + | tail_call_info arg ::= kore_term @@ -60,6 +61,8 @@ rule ::= WORD(0x22) ordinal arity variable* side_cond_entry ::= WORD(0xEE) ordinal arity variable* side_cond_exit ::= WORD(0x33) ordinal boolean_result +tail_call_info ::= WORD(0x55) function_name boolean_result + config ::= WORD(0xFF) kore_term string ::= diff --git a/include/kllvm/binary/ProofTraceParser.h b/include/kllvm/binary/ProofTraceParser.h index 93912c4b9..36b3c4178 100644 --- a/include/kllvm/binary/ProofTraceParser.h +++ b/include/kllvm/binary/ProofTraceParser.h @@ -35,6 +35,7 @@ constexpr uint64_t rule_event_sentinel = detail::word(0x22); constexpr uint64_t side_condition_event_sentinel = detail::word(0xEE); constexpr uint64_t side_condition_end_sentinel = detail::word(0x33); constexpr uint64_t pattern_matching_failure_sentinel = detail::word(0x44); +constexpr uint64_t tail_call_info_sentinel = detail::word(0x55); class llvm_step_event : public std::enable_shared_from_this { public: @@ -172,6 +173,31 @@ class llvm_pattern_matching_failure_event : public llvm_step_event { const override; }; +class llvm_tail_call_info_event : public llvm_step_event { +private: + std::string caller_name_; + bool is_tail_; + + llvm_tail_call_info_event(std::string caller_name, bool is_tail) + : caller_name_(std::move(caller_name)) + , is_tail_(is_tail) { } + +public: + static sptr + create(std::string caller_name, bool is_tail) { + return sptr( + new llvm_tail_call_info_event(std::move(caller_name), is_tail)); + } + + [[nodiscard]] std::string const &get_caller_name() const { + return caller_name_; + } + [[nodiscard]] bool is_tail() const { return is_tail_; } + + void print(std::ostream &out, bool expand_terms, unsigned indent = 0U) + const override; +}; + class llvm_event; class llvm_function_event : public llvm_step_event { @@ -599,6 +625,27 @@ class proof_trace_parser { return event; } + sptr static parse_tail_call_info( + proof_trace_buffer &buffer) { + if (!buffer.check_word(tail_call_info_sentinel)) { + return nullptr; + } + + std::string caller_name; + if (!buffer.read_string(caller_name)) { + return nullptr; + } + + bool is_tail = false; + if (!buffer.read_bool(is_tail)) { + return nullptr; + } + + auto event = llvm_tail_call_info_event::create(caller_name, is_tail); + + return event; + } + bool parse_argument(proof_trace_buffer &buffer, llvm_event &event) { if (buffer.eof() || buffer.peek() != '\x7F') { return false; @@ -634,6 +681,9 @@ class proof_trace_parser { case pattern_matching_failure_sentinel: return parse_pattern_matching_failure(buffer); + case tail_call_info_sentinel: + return parse_tail_call_info(buffer); + default: return nullptr; } } diff --git a/include/kllvm/codegen/ProofEvent.h b/include/kllvm/codegen/ProofEvent.h index cf66b4cf4..1a5bcc7a8 100644 --- a/include/kllvm/codegen/ProofEvent.h +++ b/include/kllvm/codegen/ProofEvent.h @@ -31,6 +31,8 @@ class proof_event { */ std::pair proof_branch(std::string const &label, llvm::BasicBlock *insert_at_end); + std::pair + proof_branch(std::string const &label, llvm::Instruction *insert_before); /* * Set up a standard event prelude by creating a pair of basic blocks for the @@ -42,6 +44,8 @@ class proof_event { */ std::tuple event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end); + std::tuple + event_prelude(std::string const &label, llvm::Instruction *insert_before); /* * Set up a check of whether a new proof hint chunk should be started. The @@ -172,6 +176,13 @@ class proof_event { llvm::Value *proof_writer, std::string const &function_name, llvm::BasicBlock *insert_at_end); + /* + * Emit a call to the `tail_call_info` API of the specified `proof_writer`. + */ + llvm::CallInst *emit_write_tail_call_info( + llvm::Value *proof_writer, std::string const &caller_name, + bool is_tail, llvm::BasicBlock *insert_at_end); + /* * Emit a call to the `start_new_chunk` API of the specified `proof_writer`. */ @@ -228,6 +239,10 @@ class proof_event { [[nodiscard]] llvm::BasicBlock *pattern_matching_failure( kore_composite_pattern const &pattern, llvm::BasicBlock *current_block); + [[nodiscard]] llvm::BasicBlock *tail_call_info( + std::string const &caller_name, bool is_tail, + llvm::Instruction *insert_before, llvm::BasicBlock *current_block); + proof_event(kore_definition *definition, llvm::Module *module) : definition_(definition) , module_(module) diff --git a/include/runtime/header.h b/include/runtime/header.h index d82cae004..d6174a517 100644 --- a/include/runtime/header.h +++ b/include/runtime/header.h @@ -228,6 +228,8 @@ void write_side_condition_event_post_to_proof_trace( void *proof_writer, uint64_t ordinal, bool side_cond_result); void write_pattern_matching_failure_to_proof_trace( void *proof_writer, char const *function_name); +void write_tail_call_info_to_proof_trace( + void *proof_writer, char const *caller_name, bool is_tail); void write_configuration_to_proof_trace( void *proof_writer, block *config, bool is_initial); void start_new_chunk_in_proof_trace(void *proof_writer); diff --git a/include/runtime/proof_trace_writer.h b/include/runtime/proof_trace_writer.h index 55c07c589..51ae069f0 100644 --- a/include/runtime/proof_trace_writer.h +++ b/include/runtime/proof_trace_writer.h @@ -33,6 +33,7 @@ class proof_trace_writer { side_condition_event_post(uint64_t ordinal, bool side_cond_result) = 0; virtual void pattern_matching_failure(char const *function_name) = 0; + virtual void tail_call_info(char const *caller_name, bool is_tail) = 0; virtual void configuration(block *config, bool is_initial) = 0; virtual void start_new_chunk() = 0; virtual void end_of_trace() = 0; @@ -163,6 +164,12 @@ class proof_trace_file_writer : public proof_trace_writer { write_null_terminated_string(function_name); } + void tail_call_info(char const *caller_name, bool is_tail) override { + write_uint64(kllvm::tail_call_info_sentinel); + write_null_terminated_string(caller_name); + write_bool(is_tail); + } + void configuration(block *config, bool is_initial) override { write_uint64(kllvm::config_sentinel); serialize_configuration_to_proof_trace(file_, config, 0); diff --git a/lib/binary/ProofTraceParser.cpp b/lib/binary/ProofTraceParser.cpp index 7e92a0219..3d502ae09 100644 --- a/lib/binary/ProofTraceParser.cpp +++ b/lib/binary/ProofTraceParser.cpp @@ -84,6 +84,14 @@ void llvm_pattern_matching_failure_event::print( "{}pattern matching failure: {}\n", indent, function_name_); } +void llvm_tail_call_info_event::print( + std::ostream &out, bool expand_terms, unsigned ind) const { + std::string indent(ind * indent_size, ' '); + out << fmt::format( + "{}tail_call_info: {} {}\n", indent, caller_name_, + (is_tail_ ? "tail" : "notail")); +} + void llvm_function_event::print( std::ostream &out, bool expand_terms, unsigned ind) const { std::string indent(ind * indent_size, ' '); diff --git a/lib/codegen/CreateTerm.cpp b/lib/codegen/CreateTerm.cpp index 3bb71c900..dcf86bc5b 100644 --- a/lib/codegen/CreateTerm.cpp +++ b/lib/codegen/CreateTerm.cpp @@ -1282,7 +1282,18 @@ bool make_function( if (call->getCallingConv() == llvm::CallingConv::Tail && can_tail_call(call->getType())) { call->setTailCallKind(llvm::CallInst::TCK_MustTail); + current_block = + proof_event(definition, module) + .tail_call_info(name, true, call, current_block); + } else { + current_block = + proof_event(definition, module) + .tail_call_info(name, false, nullptr, current_block); } + } else { + current_block = + proof_event(definition, module) + .tail_call_info(name, false, nullptr, current_block); } } auto *ret diff --git a/lib/codegen/Decision.cpp b/lib/codegen/Decision.cpp index 6ac04460a..95ff07146 100644 --- a/lib/codegen/Decision.cpp +++ b/lib/codegen/Decision.cpp @@ -603,6 +603,7 @@ void leaf_node::codegen(decision *d) { d->current_block_ = proof_event(d->definition_, d->module_) .rewrite_event_pre(axiom, arity, vars, subst, d->current_block_); + // maybe report here as part of the rule event whether a tail call happened if (d->profile_matching_) { llvm::CallInst::Create( diff --git a/lib/codegen/ProofEvent.cpp b/lib/codegen/ProofEvent.cpp index 232d91cf9..7a34686dc 100644 --- a/lib/codegen/ProofEvent.cpp +++ b/lib/codegen/ProofEvent.cpp @@ -304,6 +304,27 @@ llvm::CallInst *proof_event::emit_write_pattern_matching_failure( return b.CreateCall(func, {proof_writer, var_function_name}); } +llvm::CallInst *proof_event::emit_write_tail_call_info( + llvm::Value *proof_writer, std::string const &caller_name, + bool is_tail, llvm::BasicBlock *insert_at_end) { + auto b = llvm::IRBuilder(insert_at_end); + + auto *void_ty = llvm::Type::getVoidTy(ctx_); + auto *i8_ptr_ty = llvm::PointerType::getUnqual(ctx_); + auto *i8_ty = llvm::Type::getInt64Ty(ctx_); + + auto *func_ty + = llvm::FunctionType::get(void_ty, {i8_ptr_ty, i8_ptr_ty, i8_ty}, false); + + auto *func = get_or_insert_function( + module_, "write_tail_call_info_to_proof_trace", func_ty); + + auto *var_caller_name + = b.CreateGlobalStringPtr(caller_name, "", 0, module_); + auto *var_is_tail = llvm::ConstantInt::get(i8_ty, is_tail); + return b.CreateCall(func, {proof_writer, var_caller_name, var_is_tail}); +} + llvm::CallInst *proof_event::emit_start_new_chunk( llvm::Value *proof_writer, llvm::BasicBlock *insert_at_end) { auto b = llvm::IRBuilder(insert_at_end); @@ -372,6 +393,27 @@ std::pair proof_event::proof_branch( return {true_block, merge_block}; } +std::pair proof_event::proof_branch( + std::string const &label, llvm::Instruction *insert_before) { + auto *i1_ty = llvm::Type::getInt1Ty(ctx_); + + auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty); + auto *proof_output = new llvm::LoadInst( + i1_ty, proof_output_flag, "proof_output", insert_before); + + auto *f = insert_before->getParent()->getParent(); + auto *true_block + = llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f); + auto *merge_block + = llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f); + + emit_no_op(merge_block); + + llvm::BranchInst::Create( + true_block, merge_block, proof_output, insert_before); + return {true_block, merge_block}; +} + std::tuple proof_event::event_prelude( std::string const &label, llvm::BasicBlock *insert_at_end) { @@ -379,6 +421,13 @@ proof_event::event_prelude( return {true_block, merge_block, emit_get_proof_trace_writer(true_block)}; } +std::tuple +proof_event::event_prelude( + std::string const &label, llvm::Instruction *insert_before) { + auto [true_block, merge_block] = proof_branch(label, insert_before); + return {true_block, merge_block, emit_get_proof_trace_writer(true_block)}; +} + llvm::BasicBlock *proof_event::check_for_emit_new_chunk( llvm::BasicBlock *insert_at_end, llvm::BasicBlock *merge_block) { auto *f = insert_at_end->getParent(); @@ -695,4 +744,29 @@ llvm::BasicBlock *proof_event::pattern_matching_failure( return merge_block; } +llvm::BasicBlock *proof_event::tail_call_info( + std::string const &caller_name, bool is_tail, + llvm::Instruction *insert_before, llvm::BasicBlock *current_block) { + + if (!proof_hint_instrumentation) { + return current_block; + } + + std::tuple prelude; + if (is_tail) { + assert(insert_before); + prelude = event_prelude("tail_call_info", insert_before); + } else { + prelude = event_prelude("tail_call_info", current_block); + } + + auto [true_block, merge_block, proof_writer] = prelude; + + emit_write_tail_call_info(proof_writer, caller_name, is_tail, true_block); + + llvm::BranchInst::Create(merge_block, true_block); + + return merge_block; +} + } // namespace kllvm diff --git a/runtime/util/ConfigurationSerializer.cpp b/runtime/util/ConfigurationSerializer.cpp index ac5af3e23..1072eb353 100644 --- a/runtime/util/ConfigurationSerializer.cpp +++ b/runtime/util/ConfigurationSerializer.cpp @@ -717,6 +717,12 @@ void write_pattern_matching_failure_to_proof_trace( ->pattern_matching_failure(function_name); } +void write_tail_call_info_to_proof_trace( + void *proof_writer, char const *caller_name, bool is_tail) { + static_cast(proof_writer) + ->tail_call_info(caller_name, is_tail); +} + void write_configuration_to_proof_trace( void *proof_writer, block *config, bool is_initial) { static_cast(proof_writer)