diff --git a/CMakeLists.txt b/CMakeLists.txt index 3aa60779..9a07e9b7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,8 +60,11 @@ if (${CMAKE_BUILD_TYPE} MATCHES Generic) endif () if (${CMAKE_BUILD_TYPE} MATCHES Debug) - set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O -g -fsanitize=address") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O -g -fsanitize=address") + # Enable debug symbols and ASan with no optimizations + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g -fsanitize=address -fno-omit-frame-pointer") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fsanitize=address -fno-omit-frame-pointer") + # Ensure that ASan is linked explicitly + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address") else() set (CMAKE_C_FLAGS "${OpenMP_C_FLAGS} ${PIC_FLAG} ${EXTRA_FLAGS}") set (CMAKE_CXX_FLAGS "${OpenMP_CXX_FLAGS} ${PIC_FLAG} ${EXTRA_FLAGS}") diff --git a/src/common/progress.hpp b/src/common/progress.hpp index 7e5a0bca..a88762ef 100644 --- a/src/common/progress.hpp +++ b/src/common/progress.hpp @@ -79,7 +79,7 @@ class ProgressMeter { //std::cerr << input_seconds << " seconds is " << days << " days, " << hours << " hours, " << minutes << " minutes, and " << seconds << " seconds." << std::endl; } void increment(const uint64_t& incr) { - completed += incr; + completed.fetch_add(incr, std::memory_order_relaxed); } }; diff --git a/src/common/seqiter.hpp b/src/common/seqiter.hpp index 9aa38c09..5166ee48 100644 --- a/src/common/seqiter.hpp +++ b/src/common/seqiter.hpp @@ -106,7 +106,7 @@ void for_each_seq_in_file( } } -void for_each_seq_in_file( +void for_each_seq_in_faidx_t( faidx_t* fai, const std::vector& seq_names, const std::function& func) { @@ -119,6 +119,15 @@ void for_each_seq_in_file( } } } + +void for_each_seq_in_file( + const std::string& filename, + const std::vector& seq_names, + const std::function& func) { + faidx_t* fai = fai_load(filename.c_str()); + for_each_seq_in_faidx_t(fai, seq_names, func); + fai_destroy(fai); +} void for_each_seq_in_file_filtered( const std::string& filename, @@ -135,23 +144,22 @@ void for_each_seq_in_file_filtered( int num_seqs = faidx_nseq(fai); for (int i = 0; i < num_seqs; i++) { const char* seq_name = faidx_iseq(fai, i); - bool prefix_skip = true; - for (const auto& prefix : query_prefix) { - if (strncmp(seq_name, prefix.c_str(), prefix.size()) == 0) { - prefix_skip = false; - break; - } - } - if (!query_prefix.empty() && prefix_skip) { - continue; + bool keep = false; + for (const auto& prefix : query_prefix) { + if (strncmp(seq_name, prefix.c_str(), prefix.size()) == 0) { + keep = true; + break; + } + } + if (query_list.empty() || query_list.count(seq_name)) { + keep = true; } - if (!query_list.empty() && query_list.count(seq_name) == 0) { - continue; + if (keep) { + query_seq_names.push_back(seq_name); } - query_seq_names.push_back(seq_name); } - for_each_seq_in_file( + for_each_seq_in_faidx_t( fai, query_seq_names, func); diff --git a/src/interface/main.cpp b/src/interface/main.cpp index 8950f308..342de6a9 100644 --- a/src/interface/main.cpp +++ b/src/interface/main.cpp @@ -67,22 +67,10 @@ int main(int argc, char** argv) { std::cerr << "[wfmash::map] Spaced seed sensitivity " << sps.sensitivity << std::endl; } - //Build the sketch for reference - skch::Sketch referSketch(map_parameters); - - std::chrono::duration timeRefSketch = skch::Time::now() - t0; - std::cerr << "[wfmash::map] time spent computing the reference index: " << timeRefSketch.count() << " sec" << std::endl; - - if (referSketch.minmerIndex.size() == 0) - { - std::cerr << "[wfmash::map] ERROR, reference sketch is empty. Reference sequences shorter than the segment length are not indexed" << std::endl; - return 1; - } - //Map the sequences in query file t0 = skch::Time::now(); - skch::Map mapper = skch::Map(map_parameters, referSketch); + skch::Map mapper = skch::Map(map_parameters); std::chrono::duration timeMapQuery = skch::Time::now() - t0; std::cerr << "[wfmash::map] time spent mapping the query: " << timeMapQuery.count() << " sec" << std::endl; diff --git a/src/interface/parse_args.hpp b/src/interface/parse_args.hpp index b5816c61..ca74044e 100644 --- a/src/interface/parse_args.hpp +++ b/src/interface/parse_args.hpp @@ -104,6 +104,7 @@ void parse_args(int argc, args::ValueFlag mashmap_index(mapping_opts, "FILE", "Use MashMap index in FILE, create if it doesn't exist", {"mm-index"}); args::Flag create_mashmap_index_only(mapping_opts, "create-index-only", "Create only the index file without performing mapping", {"create-index-only"}); args::Flag overwrite_mashmap_index(mapping_opts, "overwrite-mm-index", "Overwrite MashMap index if it exists", {"overwrite-mm-index"}); + args::ValueFlag index_by(mapping_opts, "SIZE", "Set the target total size of sequences for each index subset", {"index-by"}); args::Group alignment_opts(parser, "[ Alignment Options ]"); args::ValueFlag align_input_paf(alignment_opts, "FILE", "derive precise alignments for this input PAF", {'i', "input-paf"}); @@ -649,6 +650,17 @@ void parse_args(int argc, map_parameters.overwrite_index = overwrite_mashmap_index; map_parameters.create_index_only = create_mashmap_index_only; + if (index_by) { + const int64_t index_size = wfmash::handy_parameter(args::get(index_by)); + if (index_size <= 0) { + std::cerr << "[wfmash] ERROR, skch::parseandSave, index-by size must be a positive integer." << std::endl; + exit(1); + } + map_parameters.index_by_size = index_size; + } else { + map_parameters.index_by_size = std::numeric_limits::max(); // Default to indexing all sequences + } + if (approx_mapping) { map_parameters.outFileName = "/dev/stdout"; yeet_parameters.approx_mapping = true; diff --git a/src/map/include/base_types.hpp b/src/map/include/base_types.hpp index 605b7560..0a3bab44 100644 --- a/src/map/include/base_types.hpp +++ b/src/map/include/base_types.hpp @@ -93,6 +93,7 @@ namespace skch { std::string name; //Name of the sequence offset_t len; //Length of the sequence + int groupId; //Group ID for the sequence }; //Label tags for strand information @@ -208,13 +209,16 @@ namespace skch typedef std::vector MappingResultsVector_t; + //Vector type for storing MinmerInfo + typedef std::vector MinVec_Type; + //Container to save copy of kseq object struct InputSeqContainer { - seqno_t seqCounter; //sequence counter + seqno_t seqId; //sequence id offset_t len; //sequence length std::string seq; //sequence string - std::string seqName; //sequence id + std::string name; //sequence name /* @@ -223,11 +227,11 @@ namespace skch * @param[in] kseq_id sequence id name * @param[in] len length of sequence */ - InputSeqContainer(const std::string& s, const std::string& id, seqno_t seqcount) - : seqCounter(seqcount) + InputSeqContainer(const std::string& s, const std::string& name, seqno_t id) + : seqId(id) , len(s.length()) , seq(s) - , seqName(id) { } + , name(name) { } }; struct InputSeqProgContainer : InputSeqContainer @@ -242,8 +246,8 @@ namespace skch * @param[in] kseq_id sequence id name * @param[in] len length of sequence */ - InputSeqProgContainer(const std::string& s, const std::string& id, seqno_t seqcount, progress_meter::ProgressMeter& pm) - : InputSeqContainer(s, id, seqcount) + InputSeqProgContainer(const std::string& s, const std::string& name, seqno_t id, progress_meter::ProgressMeter& pm) + : InputSeqContainer(s, name, id) , progress(pm) { } }; @@ -267,7 +271,7 @@ namespace skch struct QueryMetaData { char *seq; //query sequence pointer - seqno_t seqCounter; //query sequence counter + seqno_t seqId; //query sequence id offset_t len; //length of this query sequence offset_t fullLen; //length of the full sequence it derives from int sketchSize; //sketch size diff --git a/src/map/include/computeMap.hpp b/src/map/include/computeMap.hpp index 88a5ed1c..287bb231 100644 --- a/src/map/include/computeMap.hpp +++ b/src/map/include/computeMap.hpp @@ -22,6 +22,11 @@ namespace fs = std::filesystem; #include #include #include +#include +#include +#include +#include +#include "common/atomic_queue/atomic_queue.h" //Own includes #include "map/include/base_types.hpp" @@ -31,7 +36,6 @@ namespace fs = std::filesystem; #include "map/include/map_stats.hpp" #include "map/include/slidingMap.hpp" #include "map/include/MIIteratorL2.hpp" -#include "map/include/ThreadPool.hpp" #include "map/include/filter.hpp" //External includes @@ -48,14 +52,32 @@ namespace fs = std::filesystem; namespace skch { + struct QueryMappingOutput { + std::string queryName; + std::vector results; + std::mutex mutex; + progress_meter::ProgressMeter& progress; + }; + + struct FragmentData { + const char* seq; + int len; + int fullLen; + seqno_t seqId; + std::string seqName; + int refGroup; + int fragmentIndex; + QueryMappingOutput* output; + std::atomic* fragments_processed; + }; + /** * @class skch::Map * @brief L1 and L2 mapping stages */ class Map { - public: - + private: //Type for Stage L1's predicted candidate location struct L1_candidateLocus_t { @@ -88,13 +110,17 @@ namespace skch private: //algorithm parameters - const skch::Parameters ¶m; + skch::Parameters param; //reference sketch - const skch::Sketch &refSketch; + skch::Sketch* refSketch; + + // Sequence ID manager + std::unique_ptr idManager; - //Container type for saving read sketches during L1 and L2 both - typedef Sketch::MI_Type MinVec_Type; + // Vectors to store query and target sequences + std::vector querySequenceNames; + std::vector targetSequenceNames; typedef Sketch::MIIter_t MIIter_t; @@ -110,10 +136,48 @@ namespace skch //for an L1 candidate if the best intersection size is i; std::vector sketchCutoffs; - //Vector for obtaining group from refId - //if refIdGroup[i] == refIdGroup[j], then sequence i and j have the same prefix; - std::vector refIdGroup; + // Sequence ID manager + // Atomic queues for input and output + typedef atomic_queue::AtomicQueue input_atomic_queue_t; + typedef atomic_queue::AtomicQueue merged_mappings_queue_t; + typedef atomic_queue::AtomicQueue output_atomic_queue_t; + + void processFragment(FragmentData* fragment, + std::vector& intervalPoints, + std::vector& l1Mappings, + MappingResultsVector_t& l2Mappings, + QueryMetaData& Q) { + intervalPoints.clear(); + l1Mappings.clear(); + l2Mappings.clear(); + + Q.seq = const_cast(fragment->seq); + Q.len = fragment->len; + Q.fullLen = fragment->fullLen; + Q.seqId = fragment->seqId; + Q.seqName = fragment->seqName; + Q.refGroup = fragment->refGroup; + + mapSingleQueryFrag(Q, intervalPoints, l1Mappings, l2Mappings); + + std::for_each(l2Mappings.begin(), l2Mappings.end(), [&](MappingResult &e){ + e.queryLen = fragment->fullLen; + e.queryStartPos = fragment->fragmentIndex * param.segLength; + e.queryEndPos = e.queryStartPos + fragment->len; + }); + + { + std::lock_guard lock(fragment->output->mutex); + fragment->output->results.insert(fragment->output->results.end(), l2Mappings.begin(), l2Mappings.end()); + } + + // Update progress after processing the fragment + fragment->output->progress.increment(fragment->len); + fragment->fragments_processed->fetch_add(1, std::memory_order_relaxed); + delete fragment; + } + public: /** @@ -122,62 +186,70 @@ namespace skch * @param[in] refSketch reference sketch * @param[in] f optional user defined custom function to post process the reported mapping results */ - Map(const skch::Parameters &p, const skch::Sketch &refsketch, + Map(skch::Parameters p, PostProcessResultsFn_t f = nullptr) : param(p), - refSketch(refsketch), processMappingResults(f), sketchCutoffs(std::min(p.sketchSize, skch::fixed::ss_table_max) + 1, 1), - refIdGroup(refsketch.metadata.size()) - { - if (p.stage1_topANI_filter) { - this->setProbs(); - } - if (p.skip_prefix) - { - this->setRefGroups(); - } - this->mapQuery(); - } - - private: - - // Sets the groups of reference contigs based on prefix - void setRefGroups() - { - int group = 0; - int start_idx = 0; - int idx = 0; - while (start_idx < this->refSketch.metadata.size()) - { - const auto currPrefix = prefix(this->refSketch.metadata[start_idx].name, param.prefix_delim); - idx = start_idx; - while (idx < this->refSketch.metadata.size() - && currPrefix == prefix(this->refSketch.metadata[idx].name, param.prefix_delim)) + idManager(std::make_unique( + p.querySequences, + p.refSequences, + std::vector{p.query_prefix}, + std::vector{p.target_prefix}, + std::string(1, p.prefix_delim), + p.query_list, + p.target_list)) { - this->refIdGroup[idx++] = group; + std::cerr << "Initializing Map with parameters:" << std::endl; + std::cerr << "Query sequences: " << p.querySequences.size() << std::endl; + std::cerr << "Reference sequences: " << p.refSequences.size() << std::endl; + std::cerr << "Query prefix: " << (p.query_prefix.empty() ? "None" : p.query_prefix[0]) << std::endl; + std::cerr << "Target prefix: " << (p.target_prefix.empty() ? "None" : p.target_prefix) << std::endl; + std::cerr << "Prefix delimiter: '" << p.prefix_delim << "'" << std::endl; + std::cerr << "Query list: " << (p.query_list.empty() ? "None" : p.query_list) << std::endl; + std::cerr << "Target list: " << (p.target_list.empty() ? "None" : p.target_list) << std::endl; + + idManager->dumpState(); + + if (p.stage1_topANI_filter) { + this->setProbs(); + } + this->mapQuery(); } - group++; - start_idx = idx; - } - } - // Gets the ref group of a query based on the prefix - int getRefGroup(const std::string& seqName) - { - const auto queryPrefix = prefix(seqName, param.prefix_delim); - for (int i = 0; i < this->refSketch.metadata.size(); i++) - { - const auto currPrefix = prefix(this->refSketch.metadata[i].name, param.prefix_delim); - if (queryPrefix == currPrefix) - { - return this->refIdGroup[i]; + // Removed populateIdManager() function + + ~Map() = default; + + private: + void buildMetadataFromIndex() { + for (const auto& fileName : param.refSequences) { + faidx_t* fai = fai_load(fileName.c_str()); + if (fai == nullptr) { + std::cerr << "Error: Failed to load FASTA index for file " << fileName << std::endl; + exit(1); + } + + int nseq = faidx_nseq(fai); + for (int i = 0; i < nseq; ++i) { + const char* seq_name = faidx_iseq(fai, i); + int seq_len = faidx_seq_len(fai, seq_name); + if (seq_len == -1) { + std::cerr << "Error: Failed to get length for sequence " << seq_name << std::endl; + continue; + } + // Metadata is now handled by idManager, no need to push_back here + } + + fai_destroy(fai); } - } - // Doesn't belong to any ref group - return -1; } - void setProbs() + + public: + + private: + + void setProbs() { float deltaANI = param.ANIDiff; @@ -262,186 +334,274 @@ namespace skch /** * @brief parse over sequences in query file and map each on the reference */ + void reader_thread(input_atomic_queue_t& input_queue, + std::atomic& reader_done, + progress_meter::ProgressMeter& progress, + SequenceIdManager& idManager) { + // Define allowed_query_names here + std::unordered_set allowed_query_names; + if (!param.query_list.empty()) { + std::ifstream filter_list(param.query_list); + std::string name; + while (getline(filter_list, name)) { + allowed_query_names.insert(name); + } + } + + if (!param.querySequences.empty()) { + const auto& fileName = param.querySequences[0]; // Assume single query input file + seqiter::for_each_seq_in_file( + fileName, + querySequenceNames, + [&](const std::string& seq_name, const std::string& seq) { + seqno_t seqId = idManager.getSequenceId(seq_name); + auto input = new InputSeqProgContainer(seq, seq_name, seqId, progress); + while (!input_queue.try_push(input)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + }); + } + reader_done.store(true); + } + + typedef atomic_queue::AtomicQueue query_output_atomic_queue_t; + typedef atomic_queue::AtomicQueue fragment_atomic_queue_t; + + void worker_thread(input_atomic_queue_t& input_queue, + fragment_atomic_queue_t& fragment_queue, + merged_mappings_queue_t& merged_queue, + progress_meter::ProgressMeter& progress, + std::atomic& reader_done, + std::atomic& workers_done) { + while (true) { + InputSeqProgContainer* input = nullptr; + if (input_queue.try_pop(input)) { + auto output = mapModule(input, fragment_queue); + //progress.increment(input->len / 4); + while (!merged_queue.try_push(output)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + delete input; + } else if (reader_done.load() && input_queue.was_empty()) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + } + + void writer_thread(query_output_atomic_queue_t& output_queue, + std::atomic& workers_done, + seqno_t& totalReadsMapped, + std::ofstream& outstrm, + progress_meter::ProgressMeter& progress, + MappingResultsVector_t& allReadMappings) { + int wait_count = 0; + while (true) { + QueryMappingOutput* output = nullptr; + if (output_queue.try_pop(output)) { + wait_count = 0; + if(output->results.size() > 0) + totalReadsMapped++; + if (param.filterMode == filter::ONETOONE) { + allReadMappings.insert(allReadMappings.end(), output->results.begin(), output->results.end()); + } else { + reportReadMappings(output->results, output->queryName, outstrm); + } + delete output; + } else if (workers_done.load() && output_queue.was_empty()) { + ++wait_count; + if (wait_count < 5) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } else { + break; + } + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + } + + std::vector> createTargetSubsets(const std::vector& targetSequenceNames) { + std::vector> target_subsets; + uint64_t current_subset_size = 0; + std::vector current_subset; + + for (const auto& seqName : targetSequenceNames) { + seqno_t seqId = idManager->getSequenceId(seqName); + offset_t seqLen = idManager->getSequenceLength(seqId); + current_subset.push_back(seqName); + current_subset_size += seqLen; + + if (current_subset_size >= param.index_by_size || &seqName == &targetSequenceNames.back()) { + if (!current_subset.empty()) { + target_subsets.push_back(current_subset); + } + current_subset.clear(); + current_subset_size = 0; + } + } + if (!current_subset.empty()) { + target_subsets.push_back(current_subset); + } + return target_subsets; + } + void mapQuery() { //Count of reads mapped by us //Some reads are dropped because of short length seqno_t totalReadsPickedForMapping = 0; seqno_t totalReadsMapped = 0; - seqno_t seqCounter = 0; std::ofstream outstrm(param.outFileName); - MappingResultsVector_t allReadMappings; //Aggregate mapping results for the complete run - - //Create the thread pool - ThreadPool threadPool( [this](InputSeqProgContainer* e){return mapModule(e);}, param.threads); - - // allowed set of queries - std::unordered_set allowed_query_names; - if (!param.query_list.empty()) { - std::ifstream filter_list(param.query_list); - std::string name; - while (getline(filter_list, name)) { - allowed_query_names.insert(name); - } - } - - // Count the total number of sequences and sequence length - uint64_t total_seqs = 0; - uint64_t total_seq_length = 0; - for (const auto& fileName : param.querySequences) { - // Check if there is a .fai file - std::string fai_name = fileName + ".fai"; - if (fs::exists(fai_name)) { - std::string line; - std::ifstream in(fai_name.c_str()); - while (std::getline(in, line)) { - auto line_split = CommonFunc::split(line, '\t'); - // if we have a param.target_prefix and the sequence name matches, skip it - auto seq_name = line_split[0]; - bool prefix_skip = true; - for (const auto& prefix : param.query_prefix) { - if (seq_name.substr(0, prefix.size()) == prefix) { - prefix_skip = false; - break; - } - } - if (!allowed_query_names.empty() && allowed_query_names.find(seq_name) != allowed_query_names.end() - || !param.query_prefix.empty() && !prefix_skip) { - total_seqs++; - total_seq_length += std::stoul(line_split[1]); - } - } - } else { - // If .fai file doesn't exist, warn and use the for_each_seq_in_file_filtered function - std::cerr << "[mashmap::skch::Map::mapQuery] WARNING, no .fai index found for " << fileName << ", reading the file to filter query sequences (slow)" << std::endl; - seqiter::for_each_seq_in_file_filtered( - fileName, - param.query_prefix, - allowed_query_names, - [&](const std::string& seq_name, const std::string& seq) { - ++total_seqs; - total_seq_length += seq.size(); - }); - } - } - - progress_meter::ProgressMeter progress(total_seq_length, "[mashmap::skch::Map::mapQuery] mapped"); - - for(const auto &fileName : param.querySequences) - { - -#ifdef DEBUG - std::cerr << "[mashmap::skch::Map::mapQuery] mapping reads in " << fileName << std::endl; -#endif - - seqiter::for_each_seq_in_file_filtered( - fileName, - param.query_prefix, - allowed_query_names, - [&](const std::string& seq_name, - const std::string& seq) { - // todo: offset_t is an 32-bit integer, which could cause problems - offset_t len = seq.length(); - if (param.skip_self - && param.target_prefix != "" - && seq_name.substr(0, param.target_prefix.size()) == param.target_prefix) { - // skip - } else { - if (param.filterMode == filter::ONETOONE) - qmetadata.push_back( ContigInfo{seq_name, len} ); - //Is the read too short? - if(len < param.kmerSize) - { -//#ifdef DEBUG - // TODO Should we somehow revert to < windowSize? - std::cerr << std::endl - << "WARNING, skch::Map::mapQuery, read " - << seq_name << " of " << len << "bp " - << " is not long enough for mapping at segment length " - << param.segLength << std::endl; -//#endif - } - else - { - totalReadsPickedForMapping++; - //Dispatch input to thread - threadPool.runWhenThreadAvailable(new InputSeqProgContainer(seq, seq_name, seqCounter, progress)); - - //Collect output if available - while ( threadPool.outputAvailable() ) { - mapModuleHandleOutput(threadPool.popOutputWhenAvailable(), allReadMappings, totalReadsMapped, outstrm, progress); - } - } - //progress.increment(seq.size()/2); - seqCounter++; - } - }); //Finish reading query input file + // Initialize atomic queues and flags + input_atomic_queue_t input_queue; + merged_mappings_queue_t merged_queue; + fragment_atomic_queue_t fragment_queue; + std::atomic reader_done(false); + std::atomic workers_done(false); + std::atomic fragments_done(false); + + this->querySequenceNames = idManager->getQuerySequenceNames(); + this->targetSequenceNames = idManager->getTargetSequenceNames(); + + // Count the total number of sequences and sequence length + uint64_t total_seqs = querySequenceNames.size(); + uint64_t total_seq_length = 0; + for (const auto& seqName : querySequenceNames) { + total_seq_length += idManager->getSequenceLength(idManager->getSequenceId(seqName)); } - //Collect remaining output objects - while ( threadPool.running() ) - mapModuleHandleOutput(threadPool.popOutputWhenAvailable(), allReadMappings, totalReadsMapped, outstrm, progress); + std::vector> target_subsets = createTargetSubsets(targetSequenceNames); - //Filter over reference axis and report the mappings - if (param.filterMode == filter::ONETOONE) - { - // how many secondary mappings to keep - int n_mappings = param.numMappingsForSegment - 1; - - // Group sequences by query prefix, then pass to ref filter - auto subrange_begin = allReadMappings.begin(); - auto subrange_end = allReadMappings.begin(); - MappingResultsVector_t tmpMappings; - MappingResultsVector_t filteredMappings; + std::unordered_map combinedMappings; - while (subrange_end != allReadMappings.end()) - { - if (param.skip_prefix) - { - int currGroup = this->getRefGroup(qmetadata[subrange_begin->querySeqId].name); - subrange_end = std::find_if_not(subrange_begin, allReadMappings.end(), [this, currGroup] (const auto& allReadMappings_candidate) { - return currGroup == this->getRefGroup(this->qmetadata[allReadMappings_candidate.querySeqId].name); - }); + // For each subset of target sequences + uint64_t subset_count = 0; + for (const auto& target_subset : target_subsets) { + std::cerr << "processing subset " << subset_count << " of " << target_subsets.size() << std::endl; + std::cerr << "entries: "; + for (const auto& seqName : target_subset) { + std::cerr << seqName << " "; } - else - { - subrange_end = allReadMappings.end(); + std::cerr << std::endl; + if (target_subset.empty()) { + continue; // Skip empty subsets } - tmpMappings.insert( - tmpMappings.end(), - std::make_move_iterator(subrange_begin), - std::make_move_iterator(subrange_end)); - // tmpMappings now contains mappings from one group of query sequences to all reference groups - // we now run filterByGroup, which filters based on the reference group. - filterByGroup(tmpMappings, filteredMappings, n_mappings, true); - tmpMappings.clear(); - subrange_begin = subrange_end; - } - allReadMappings = std::move(filteredMappings); + // Build index for the current subset + if (!param.indexFilename.empty() && !param.create_index_only) { + // load index from file + std::string indexFilename = param.indexFilename.string() + "." + std::to_string(subset_count); + refSketch = new skch::Sketch(param, *idManager, target_subset, indexFilename); + } else { + refSketch = new skch::Sketch(param, *idManager, target_subset); + } - //Re-sort mappings by input order of query sequences - //This order may be needed for any post analysis of output - std::sort( - allReadMappings.begin(), allReadMappings.end(), - [](const MappingResult &a, const MappingResult &b) { - return std::tie(a.querySeqId, a.queryStartPos, a.refSeqId, a.refStartPos) - < std::tie(b.querySeqId, b.queryStartPos, b.refSeqId, b.refStartPos); - }); + if (param.create_index_only) { + // Save the index to a file + std::string indexFilename = param.indexFilename.string() + "." + std::to_string(subset_count); + refSketch->writeIndex(indexFilename); + std::cerr << "[mashmap::skch::Map::mapQuery] Index created for subset " << subset_count + << " and saved to " << indexFilename << std::endl; + } else { + processSubset(subset_count, target_subsets.size(), total_seq_length, input_queue, merged_queue, + fragment_queue, reader_done, workers_done, fragments_done, combinedMappings); + } + + // Clean up the current refSketch + delete refSketch; + refSketch = nullptr; + ++subset_count; + } - reportReadMappings(allReadMappings, "", outstrm); + if (param.create_index_only) { + std::cerr << "[mashmap::skch::Map::mapQuery] All indices created successfully. Exiting." << std::endl; + exit(0); } - progress.finish(); + // Process combined mappings + for (auto& [querySeqId, mappings] : combinedMappings) { + // Sort mappings by query position, then reference sequence id, then reference position + std::sort( + mappings.begin(), mappings.end(), + [](const MappingResult &a, const MappingResult &b) { + return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos, a.strand) + < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos, b.strand); + } + ); + + std::string queryName = idManager->getSequenceName(querySeqId); + processAggregatedMappings(queryName, mappings, outstrm); + totalReadsMapped += !mappings.empty(); + } std::cerr << "[mashmap::skch::Map::mapQuery] " << "count of mapped reads = " << totalReadsMapped << ", reads qualified for mapping = " << totalReadsPickedForMapping - << ", total input reads = " << seqCounter + << ", total input reads = " << idManager->size() << ", total input bp = " << total_seq_length << std::endl; + } + + void processSubset(uint64_t subset_count, size_t total_subsets, uint64_t total_seq_length, + input_atomic_queue_t& input_queue, merged_mappings_queue_t& merged_queue, + fragment_atomic_queue_t& fragment_queue, std::atomic& reader_done, + std::atomic& workers_done, std::atomic& fragments_done, + std::unordered_map& combinedMappings) + { + progress_meter::ProgressMeter progress( + total_seq_length, + "[mashmap::skch::Map::mapQuery] mapped (" + + std::to_string(subset_count + 1) + "/" + std::to_string(total_subsets) + ")"); + + // Launch reader thread + std::thread reader([&]() { + reader_thread(input_queue, reader_done, progress, *idManager); + }); + + std::vector fragment_workers; + for (int i = 0; i < param.threads; ++i) { + fragment_workers.emplace_back([&]() { + fragment_thread(fragment_queue, fragments_done); + }); + } + + // Launch worker threads + std::vector workers; + for (int i = 0; i < param.threads; ++i) { + workers.emplace_back([&]() { + worker_thread(input_queue, fragment_queue, merged_queue, progress, reader_done, workers_done); + }); + } + + // Launch aggregator thread + std::thread aggregator([&]() { + aggregator_thread(merged_queue, workers_done, combinedMappings); + }); + + // Wait for all threads to complete + reader.join(); + + for (auto& worker : workers) { + worker.join(); + } + workers_done.store(true); + fragments_done.store(true); + + for (auto& worker : fragment_workers) { + worker.join(); + } + aggregator.join(); + + // Reset flags and clear aggregatedMappings for next iteration + reader_done.store(false); + workers_done.store(false); + fragments_done.store(false); + + progress.finish(); } /** @@ -535,7 +695,8 @@ namespace skch MappingResultsVector_t &unfilteredMappings, MappingResultsVector_t &filteredMappings, int n_mappings, - bool filter_ref) + bool filter_ref, + const SequenceIdManager& idManager) { filteredMappings.reserve(unfilteredMappings.size()); @@ -550,9 +711,9 @@ namespace skch { if (param.skip_prefix) { - int currGroup = this->refIdGroup[subrange_begin->refSeqId]; - subrange_end = std::find_if_not(subrange_begin, unfilteredMappings.end(), [this, currGroup] (const auto& unfilteredMappings_candidate) { - return currGroup == this->refIdGroup[unfilteredMappings_candidate.refSeqId]; + int currGroup = idManager.getRefGroup(subrange_begin->refSeqId); + subrange_end = std::find_if_not(subrange_begin, unfilteredMappings.end(), [this, currGroup, &idManager] (const auto& unfilteredMappings_candidate) { + return currGroup == idManager.getRefGroup(unfilteredMappings_candidate.refSeqId); }); } else @@ -567,7 +728,7 @@ namespace skch { return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos) < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos); }); if (filter_ref) { - skch::Filter::ref::filterMappings(tmpMappings, this->refSketch, n_mappings, param.dropRand, param.overlap_threshold); + skch::Filter::ref::filterMappings(tmpMappings, idManager, n_mappings, param.dropRand, param.overlap_threshold); } else { @@ -585,10 +746,8 @@ namespace skch std::sort( filteredMappings.begin(), filteredMappings.end(), [](const MappingResult &a, const MappingResult &b) { - return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos) - < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos); - //return std::tie(a.refSeqId, a.refStartPos, a.queryStartPos) - //< std::tie(b.refSeqId, b.refStartPos, b.queryStartPos); + return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos, a.strand) + < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos, b.strand); }); } @@ -599,143 +758,141 @@ namespace skch * @param[in] input input read details * @return output object containing the mappings */ - MapModuleOutput* mapModule (InputSeqProgContainer* input) - { - MapModuleOutput* output = new MapModuleOutput(); + QueryMappingOutput* mapModule(InputSeqProgContainer* input, + fragment_atomic_queue_t& fragment_queue) { - //save query sequence name and length - output->qseqName = input->seqName; - output->qseqLen = input->len; + QueryMappingOutput* output = new QueryMappingOutput{input->name, {}, {}, input->progress}; + std::atomic fragments_processed{0}; bool split_mapping = true; - std::vector intervalPoints; - // Reserve the "expected" number of interval points - intervalPoints.reserve( - 2 * param.sketchSize * refSketch.minmerIndex.size() / refSketch.minmerPosLookupIndex.size()); - std::vector l1Mappings; - MappingResultsVector_t l2Mappings; - MappingResultsVector_t unfilteredMappings; - int refGroup = this->getRefGroup(input->seqName); - - if (!param.split || input->len <= param.segLength) - { - QueryMetaData Q; - Q.seq = &(input->seq)[0u]; - Q.len = input->len; - Q.fullLen = input->len; - Q.seqCounter = input->seqCounter; - Q.seqName = input->seqName; - Q.refGroup = refGroup; - - // Map this sequence - mapSingleQueryFrag(Q, intervalPoints, l1Mappings, l2Mappings); - unfilteredMappings.insert(unfilteredMappings.end(), l2Mappings.begin(), l2Mappings.end()); - - // Apply non-merged filtering - filterNonMergedMappings(unfilteredMappings, param); - - split_mapping = false; - input->progress.increment(input->len); + int refGroup = this->idManager->getRefGroup(input->seqId); + + std::vector fragments; + int noOverlapFragmentCount = input->len / param.segLength; + + for (int i = 0; i < noOverlapFragmentCount; i++) { + auto fragment = new FragmentData{ + &(input->seq)[0u] + i * param.segLength, + static_cast(param.segLength), + static_cast(input->len), + input->seqId, + input->name, + refGroup, + i, + output, + &fragments_processed + }; + fragments.push_back(fragment); } - else // Split read mapping - { - int noOverlapFragmentCount = input->len / param.segLength; - - // Map individual non-overlapping fragments in the read - for (int i = 0; i < noOverlapFragmentCount; i++) - { - QueryMetaData Q; - Q.seq = &(input->seq)[0u] + i * param.segLength; - Q.len = param.segLength; - Q.fullLen = input->len; - Q.seqCounter = input->seqCounter; - Q.seqName = input->seqName; - Q.refGroup = refGroup; - - intervalPoints.clear(); - l1Mappings.clear(); - l2Mappings.clear(); - - mapSingleQueryFrag(Q, intervalPoints, l1Mappings, l2Mappings); - - std::for_each(l2Mappings.begin(), l2Mappings.end(), [&](MappingResult &e){ - e.queryLen = input->len; - e.queryStartPos = i * param.segLength; - e.queryEndPos = i * param.segLength + Q.len; - }); - - unfilteredMappings.insert(unfilteredMappings.end(), l2Mappings.begin(), l2Mappings.end()); - input->progress.increment(param.segLength); - } - // Map last overlapping fragment to cover the whole read - if (noOverlapFragmentCount >= 1 && input->len % param.segLength != 0) - { - QueryMetaData Q; - Q.seq = &(input->seq)[0u] + input->len - param.segLength; - Q.len = param.segLength; - Q.seqCounter = input->seqCounter; - Q.seqName = input->seqName; - Q.refGroup = refGroup; - - intervalPoints.clear(); - l1Mappings.clear(); - l2Mappings.clear(); - - mapSingleQueryFrag(Q, intervalPoints, l1Mappings, l2Mappings); - - std::for_each(l2Mappings.begin(), l2Mappings.end(), [&](MappingResult &e){ - e.queryLen = input->len; - e.queryStartPos = input->len - param.segLength; - e.queryEndPos = input->len; - }); - - unfilteredMappings.insert(unfilteredMappings.end(), l2Mappings.begin(), l2Mappings.end()); - input->progress.increment(input->len % param.segLength); - } + if (noOverlapFragmentCount >= 1 && input->len % param.segLength != 0) { + auto fragment = new FragmentData{ + &(input->seq)[0u] + input->len - param.segLength, + static_cast(param.segLength), + static_cast(input->len), + input->seqId, + input->name, + refGroup, + noOverlapFragmentCount, + output, + &fragments_processed + }; + fragments.push_back(fragment); + noOverlapFragmentCount++; + } - if (param.mergeMappings) - { - // the maximally merged mappings are top-level chains - // while the unfiltered mappings now contain splits at max_mapping_length - auto maximallyMergedMappings = - mergeMappingsInRange(unfilteredMappings, param.chain_gap); - // we filter on the top level chains - filterMaximallyMerged(maximallyMergedMappings, param); - // collect splitMappingIds in the maximally merged mappings - robin_hood::unordered_set kept_chains; - for (auto &mapping : maximallyMergedMappings) { - kept_chains.insert(mapping.splitMappingId); - } - // and use them to filter mappings to discard - unfilteredMappings.erase( - std::remove_if(unfilteredMappings.begin(), unfilteredMappings.end(), - [&kept_chains](const MappingResult &mapping) { - return !kept_chains.count(mapping.splitMappingId); - }), - unfilteredMappings.end() - ); - } - else - { - filterNonMergedMappings(unfilteredMappings, param); + for (auto& fragment : fragments) { + while (!fragment_queue.try_push(fragment)) { + //std::this_thread::yield(); // too fast + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } } - // Common post-processing for both merged and non-merged mappings - mappingBoundarySanityCheck(input, unfilteredMappings); - - if (param.filterLengthMismatches) - { - filterFalseHighIdentity(unfilteredMappings); + // Wait for all fragments to be processed + while (fragments_processed.load(std::memory_order_relaxed) < noOverlapFragmentCount) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - - sparsifyMappings(unfilteredMappings); - output->readMappings = std::move(unfilteredMappings); + mappingBoundarySanityCheck(input, output->results); return output; } + void fragment_thread(fragment_atomic_queue_t& fragment_queue, + std::atomic& fragments_done) { + std::vector intervalPoints; + std::vector l1Mappings; + MappingResultsVector_t l2Mappings; + QueryMetaData Q; + + while (!fragments_done.load()) { + FragmentData* fragment = nullptr; + if (fragment_queue.try_pop(fragment)) { + if (fragment) { + processFragment(fragment, intervalPoints, l1Mappings, l2Mappings, Q); + } + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + } + + void processAggregatedMappings(const std::string& queryName, MappingResultsVector_t& mappings, std::ofstream& outstrm) { + + // XXX we should fix this combined condition + if (param.mergeMappings && param.split) { + auto maximallyMergedMappings = mergeMappingsInRange(mappings, param.chain_gap); + filterMaximallyMerged(maximallyMergedMappings, param); + robin_hood::unordered_set kept_chains; + for (auto &mapping : maximallyMergedMappings) { + kept_chains.insert(mapping.splitMappingId); + } + mappings.erase( + std::remove_if(mappings.begin(), mappings.end(), + [&kept_chains](const MappingResult &mapping) { + return !kept_chains.count(mapping.splitMappingId); + }), + mappings.end()); + } else { + filterNonMergedMappings(mappings, param); + } + + if (param.filterLengthMismatches) { + filterFalseHighIdentity(mappings); + } + + sparsifyMappings(mappings); + + // Apply group filtering aggregated across all targets + if (param.filterMode == filter::MAP || param.filterMode == filter::ONETOONE) { + MappingResultsVector_t filteredMappings; + filterByGroup(mappings, filteredMappings, param.numMappingsForSegment - 1, param.filterMode == filter::ONETOONE, *idManager); + mappings = std::move(filteredMappings); + } + + reportReadMappings(mappings, queryName, outstrm); + } + + void aggregator_thread(merged_mappings_queue_t& merged_queue, + std::atomic& workers_done, + std::unordered_map& combinedMappings) { + while (true) { + QueryMappingOutput* output = nullptr; + if (merged_queue.try_pop(output)) { + seqno_t querySeqId = idManager->getSequenceId(output->queryName); + combinedMappings[querySeqId].insert( + combinedMappings[querySeqId].end(), + output->results.begin(), + output->results.end() + ); + delete output; + } else if (workers_done.load() && merged_queue.was_empty()) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + } + /** * @brief routine to handle mapModule's output of mappings * @param[in] output mapping output object @@ -747,9 +904,8 @@ namespace skch void mapModuleHandleOutput(MapModuleOutput* output, Vec &allReadMappings, seqno_t &totalReadsMapped, - std::ofstream &outstrm, - progress_meter::ProgressMeter& progress) - { + std::ofstream &outstrm) { + if(output->readMappings.size() > 0) totalReadsMapped++; @@ -764,8 +920,6 @@ namespace skch reportReadMappings(output->readMappings, output->qseqName, outstrm); } - //progress.increment(output->qseqLen/2 + (output->qseqLen % 2 != 0)); - delete output; } @@ -778,7 +932,7 @@ namespace skch { if (param.filterMode == filter::MAP || param.filterMode == filter::ONETOONE) { MappingResultsVector_t filteredMappings; - filterByGroup(readMappings, filteredMappings, param.numMappingsForSegment - 1, false); + filterByGroup(readMappings, filteredMappings, param.numMappingsForSegment - 1, false, *idManager); readMappings = std::move(filteredMappings); } } @@ -812,9 +966,9 @@ namespace skch { if (param.skip_prefix) { - int currGroup = this->refIdGroup[l1_begin->seqId]; + int currGroup = this->idManager->getRefGroup(l1_begin->seqId); l1_end = std::find_if_not(l1_begin, l1Mappings.end(), [this, currGroup] (const auto& candidate) { - return currGroup == this->refIdGroup[candidate.seqId]; + return currGroup == this->idManager->getRefGroup(candidate.seqId); }); } else @@ -842,7 +996,7 @@ namespace skch std::chrono::duration timeSpentL2 = skch::Time::now() - t1; std::chrono::duration timeSpentMappingFragment = skch::Time::now() - t0; - std::cerr << Q.seqCounter << " " << Q.len + std::cerr << Q.seqId << " " << Q.len << " " << timeSpentL1.count() << " " << timeSpentL2.count() << " " << timeSpentMappingFragment.count() @@ -855,7 +1009,7 @@ namespace skch void getSeedHits(Q_Info &Q) { Q.minmerTableQuery.reserve(param.sketchSize + 1); - CommonFunc::sketchSequence(Q.minmerTableQuery, Q.seq, Q.len, param.kmerSize, param.alphabetSize, param.sketchSize, Q.seqCounter); + CommonFunc::sketchSequence(Q.minmerTableQuery, Q.seq, Q.len, param.kmerSize, param.alphabetSize, param.sketchSize, Q.seqId); if(Q.minmerTableQuery.size() == 0) { Q.sketchSize = 0; return; @@ -869,13 +1023,13 @@ namespace skch // TODO remove them from the original sketch instead of removing for each read auto new_end = std::remove_if(Q.minmerTableQuery.begin(), Q.minmerTableQuery.end(), [&](auto& mi) { - return refSketch.isFreqSeed(mi.hash); + return refSketch->isFreqSeed(mi.hash); }); Q.minmerTableQuery.erase(new_end, Q.minmerTableQuery.end()); Q.sketchSize = Q.minmerTableQuery.size(); #ifdef DEBUG - std::cerr << "INFO, skch::Map::getSeedHits, read id " << Q.seqCounter << ", minmer count = " << Q.minmerTableQuery.size() << ", bad minmers = " << orig_len - Q.sketchSize << "\n"; + std::cerr << "INFO, skch::Map::getSeedHits, read id " << Q.seqId << ", minmer count = " << Q.minmerTableQuery.size() << ", bad minmers = " << orig_len - Q.sketchSize << "\n"; #endif } @@ -895,7 +1049,7 @@ namespace skch { #ifdef DEBUG - std::cerr<< "INFO, skch::Map::getSeedHits, read id " << Q.seqCounter << ", minmer count = " << Q.minmerTableQuery.size() << " " << Q.len << "\n"; + std::cerr<< "INFO, skch::Map::getSeedHits, read id " << Q.seqId << ", minmer count = " << Q.minmerTableQuery.size() << " " << Q.len << "\n"; #endif //For invalid query (example : just NNNs), we may be left with 0 sketch size @@ -912,9 +1066,9 @@ namespace skch for(auto it = Q.minmerTableQuery.begin(); it != Q.minmerTableQuery.end(); it++) { //Check if hash value exists in the reference lookup index - const auto seedFind = refSketch.minmerPosLookupIndex.find(it->hash); + const auto seedFind = refSketch->minmerPosLookupIndex.find(it->hash); - if(seedFind != refSketch.minmerPosLookupIndex.end()) + if(seedFind != refSketch->minmerPosLookupIndex.end()) { pq.emplace_back(boundPtr {seedFind->second.cbegin(), seedFind->second.cend()}); } @@ -924,11 +1078,16 @@ namespace skch while(!pq.empty()) { const IP_const_iterator ip_it = pq.front().it; - const auto& ref = this->refSketch.metadata[ip_it->seqId]; + //const auto& ref = this->sketch_metadata[ip_it->seqId]; + const auto& ref_name = this->idManager->getSequenceName(ip_it->seqId); + //const auto& ref_len = this->idManager.getSeqLen(ip_it->seqId); bool skip_mapping = false; - if (param.skip_self && Q.seqName == ref.name) skip_mapping = true; - if (param.skip_prefix && this->refIdGroup[ip_it->seqId] == Q.refGroup) skip_mapping = true; - if (param.lower_triangular && Q.seqCounter <= ip_it->seqId) skip_mapping = true; + int queryGroup = idManager->getRefGroup(Q.seqId); + int targetGroup = idManager->getRefGroup(ip_it->seqId); + + if (param.skip_self && queryGroup == targetGroup) skip_mapping = true; + if (param.skip_prefix && queryGroup == targetGroup) skip_mapping = true; + if (param.lower_triangular && Q.seqId <= ip_it->seqId) skip_mapping = true; if (!skip_mapping) { intervalPoints.push_back(*ip_it); @@ -946,7 +1105,7 @@ namespace skch } #ifdef DEBUG - std::cerr << "INFO, skch::Map:getSeedHits, read id " << Q.seqCounter << ", Count of seed hits in the reference = " << intervalPoints.size() / 2 << "\n"; + std::cerr << "INFO, skch::Map:getSeedHits, read id " << Q.seqId << ", Count of seed hits in the reference = " << intervalPoints.size() / 2 << "\n"; #endif } @@ -960,7 +1119,7 @@ namespace skch Vec2 &l1Mappings) { #ifdef DEBUG - std::cerr << "INFO, skch::Map:computeL1CandidateRegions, read id " << Q.seqCounter << std::endl; + std::cerr << "INFO, skch::Map:computeL1CandidateRegions, read id " << Q.seqId << std::endl; #endif int overlapCount = 0; @@ -1189,9 +1348,9 @@ namespace skch { if (param.skip_prefix) { - int currGroup = this->refIdGroup[ip_begin->seqId]; + int currGroup = this->idManager->getRefGroup(ip_begin->seqId); ip_end = std::find_if_not(ip_begin, intervalPoints.end(), [this, currGroup] (const auto& ip) { - return currGroup == this->refIdGroup[ip.seqId]; + return currGroup == this->idManager->getRefGroup(ip.seqId); }); } else @@ -1205,10 +1364,59 @@ namespace skch } - // helper to get the prefix of a string - const std::string prefix(const std::string& s, const char c) { - //std::cerr << "prefix of " << s << " by " << c << " is " << s.substr(0, s.find_last_of(c)) << std::endl; - return s.substr(0, s.find_last_of(c)); + /** + * @brief Build metadata and reference groups for sequences + * @details Read FAI files, sort sequences, and assign groups + */ + void buildRefGroups() { + std::vector> seqInfoWithIndex; + size_t totalSeqs = 0; + + for (const auto& fileName : param.refSequences) { + std::string faiName = fileName + ".fai"; + std::ifstream faiFile(faiName); + + if (!faiFile.is_open()) { + std::cerr << "Error: Unable to open FAI file: " << faiName << std::endl; + exit(1); + } + + std::string line; + while (std::getline(faiFile, line)) { + std::istringstream iss(line); + std::string seqName; + offset_t seqLength; + iss >> seqName >> seqLength; + + seqInfoWithIndex.emplace_back(seqName, totalSeqs++, seqLength); + } + } + + std::sort(seqInfoWithIndex.begin(), seqInfoWithIndex.end()); + + std::vector refGroups(totalSeqs); + // Removed as sketch_metadata is no longer used + int currentGroup = 0; + std::string prevPrefix; + + for (const auto& [seqName, originalIndex, seqLength] : seqInfoWithIndex) { + std::string currPrefix = seqName.substr(0, seqName.find_last_of(param.prefix_delim)); + + if (currPrefix != prevPrefix) { + currentGroup++; + prevPrefix = currPrefix; + } + + refGroups[originalIndex] = currentGroup; + // Metadata is now handled by idManager, no need to push_back here + } + + // Removed refIdGroup swap as it's no longer needed + + if (totalSeqs == 0) { + std::cerr << "[mashmap::skch::Map::buildRefGroups] ERROR: No sequences indexed!" << std::endl; + exit(1); + } } /** @@ -1256,7 +1464,7 @@ namespace skch //Report the alignment if it passes our identity threshold and, // if we are in all-vs-all mode, it isn't a self-mapping, // and if we are self-mapping, the query is shorter than the target - const auto& ref = this->refSketch.metadata[l2.seqId]; + const auto& ref = this->idManager->getContigInfo(l2.seqId); if((param.keep_low_pct_id && nucIdentityUpperBound >= param.percentageIdentity) || nucIdentity >= param.percentageIdentity) { @@ -1273,7 +1481,7 @@ namespace skch res.queryStartPos = 0; res.queryEndPos = Q.len; res.refSeqId = l2.seqId; - res.querySeqId = Q.seqCounter; + res.querySeqId = Q.seqId; res.nucIdentity = nucIdentity; res.nucIdentityUpperBound = nucIdentityUpperBound; res.sketchSize = Q.sketchSize; @@ -1320,7 +1528,7 @@ namespace skch //std::cerr << "INFO, skch::Map:computeL2MappedRegions, read id " << Q.seqName << "_" << Q.startPos << std::endl; #endif - auto& minmerIndex = refSketch.minmerIndex; + auto& minmerIndex = refSketch->minmerIndex; //candidateLocus.rangeStartPos -= param.segLength; //candidateLocus.rangeEndPos += param.segLength; @@ -1596,10 +1804,9 @@ namespace skch // Apply group filtering if necessary if (param.filterMode == filter::MAP || param.filterMode == filter::ONETOONE) { MappingResultsVector_t groupFilteredMappings; - filterByGroup(readMappings, groupFilteredMappings, param.numMappingsForSegment - 1, false); + filterByGroup(readMappings, groupFilteredMappings, param.numMappingsForSegment - 1, false, *idManager); readMappings = std::move(groupFilteredMappings); } - } /** @@ -1609,17 +1816,15 @@ namespace skch */ template VecIn mergeMappingsInRange(VecIn &readMappings, - int max_dist) { - assert(param.split == true); - - if(readMappings.size() < 2) return readMappings; + int max_dist) { + if (!param.split || readMappings.size() < 2) return readMappings; //Sort the mappings by query position, then reference sequence id, then reference position std::sort( readMappings.begin(), readMappings.end(), [](const MappingResult &a, const MappingResult &b) { - return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos) - < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos); + return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos, a.strand) + < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos, b.strand); }); //First assign a unique id to each split mapping in the sorted order @@ -1696,8 +1901,8 @@ namespace skch readMappings.begin(), readMappings.end(), [](const MappingResult &a, const MappingResult &b) { - return std::tie(a.splitMappingId, a.queryStartPos, a.refSeqId, a.refStartPos) - < std::tie(b.splitMappingId, b.queryStartPos, b.refSeqId, b.refStartPos); + return std::tie(a.splitMappingId, a.queryStartPos, a.refSeqId, a.refStartPos, a.strand) + < std::tie(b.splitMappingId, b.queryStartPos, b.refSeqId, b.refStartPos, b.strand); }); // Create maximallyMergedMappings @@ -1858,16 +2063,16 @@ namespace skch { if(e.refStartPos < 0) e.refStartPos = 0; - if(e.refStartPos >= this->refSketch.metadata[e.refSeqId].len) - e.refStartPos = this->refSketch.metadata[e.refSeqId].len - 1; + if(e.refStartPos >= this->idManager->getSequenceLength(e.refSeqId)) + e.refStartPos = this->idManager->getSequenceLength(e.refSeqId) - 1; } //reference end pos { if(e.refEndPos < e.refStartPos) e.refEndPos = e.refStartPos; - if(e.refEndPos >= this->refSketch.metadata[e.refSeqId].len) - e.refEndPos = this->refSketch.metadata[e.refSeqId].len - 1; + if(e.refEndPos >= this->idManager->getSequenceLength(e.refSeqId)) + e.refEndPos = this->idManager->getSequenceLength(e.refSeqId) - 1; } //query start pos @@ -1900,18 +2105,16 @@ namespace skch //Print the results for(auto &e : readMappings) { - assert(e.refSeqId < this->refSketch.metadata.size()); - float fakeMapQ = e.nucIdentity == 1 ? 255 : std::round(-10.0 * std::log10(1-(e.nucIdentity))); std::string sep = param.legacy_output ? " " : "\t"; - outstrm << (param.filterMode == filter::ONETOONE ? qmetadata[e.querySeqId].name : queryName) + outstrm << (param.filterMode == filter::ONETOONE ? idManager->getSequenceName(e.querySeqId) : queryName) << sep << e.queryLen << sep << e.queryStartPos << sep << e.queryEndPos - (param.legacy_output ? 1 : 0) << sep << (e.strand == strnd::FWD ? "+" : "-") - << sep << this->refSketch.metadata[e.refSeqId].name - << sep << this->refSketch.metadata[e.refSeqId].len + << sep << idManager->getSequenceName(e.refSeqId) + << sep << this->idManager->getSequenceLength(e.refSeqId) << sep << e.refStartPos << sep << e.refEndPos - (param.legacy_output ? 1 : 0); diff --git a/src/map/include/filter.hpp b/src/map/include/filter.hpp index a6fe6ad3..c05e2ee5 100644 --- a/src/map/include/filter.hpp +++ b/src/map/include/filter.hpp @@ -442,13 +442,13 @@ namespace skch * @param[in/out] eventRecord event record containing end point of segment * @param[in] refsketch reference index class object */ - void refPosDoPlusOne(eventRecord_t &eventRecord, const skch::Sketch &refsketch) + void refPosDoPlusOne(eventRecord_t &eventRecord, const skch::SequenceIdManager &idManager) { seqno_t currentSeqId = std::get<0>(eventRecord); offset_t currentSeqOffSet = std::get<1>(eventRecord); //if offset is at the end of reference sequence, shift to next - if(currentSeqOffSet == refsketch.metadata[currentSeqId].len - 1) + if(currentSeqOffSet == idManager.getSequenceLength(currentSeqId) - 1) { std::get<0>(eventRecord) += 1; //shift id by 1 std::get<1>(eventRecord) = 0; @@ -464,7 +464,7 @@ namespace skch * @param[in] refsketch reference index class object, used to determine ref sequence lengths */ template - void filterMappings(VecIn &readMappings, const skch::Sketch &refsketch, uint16_t secondaryToKeep, bool dropRand, double overlapThreshold) + void filterMappings(VecIn &readMappings, const skch::SequenceIdManager &idManager, uint16_t secondaryToKeep, bool dropRand, double overlapThreshold) { if(readMappings.size() <= 1) return; @@ -490,7 +490,7 @@ namespace skch eventRecord_t endEvent = std::make_tuple(readMappings[i].refSeqId, readMappings[i].refEndPos, event::END, i); //add one to above coordinate - obj.refPosDoPlusOne(endEvent, refsketch); + obj.refPosDoPlusOne(endEvent, idManager); eventSchedule.push_back (endEvent); } diff --git a/src/map/include/map_parameters.hpp b/src/map/include/map_parameters.hpp index 8ee61eab..9a014c31 100644 --- a/src/map/include/map_parameters.hpp +++ b/src/map/include/map_parameters.hpp @@ -69,8 +69,8 @@ struct Parameters bool filterLengthMismatches; //true if filtering out length mismatches float kmerComplexityThreshold; //minimum kmer complexity to consider (default 0) - std::string query_list; // file containing list of query sequence names - std::vector query_prefix; // prefix for query sequences to use + std::string query_list; // file containing list of query sequence names + std::vector query_prefix; // prefix for query sequences to use int sketchSize; bool use_spaced_seeds; // @@ -83,6 +83,7 @@ struct Parameters bool legacy_output; //std::unordered_set high_freq_kmers; // + int64_t index_by_size = std::numeric_limits::max(); // Target total size of sequences for each index subset }; diff --git a/src/map/include/sequenceIds.hpp b/src/map/include/sequenceIds.hpp new file mode 100644 index 00000000..05a6b836 --- /dev/null +++ b/src/map/include/sequenceIds.hpp @@ -0,0 +1,246 @@ +#ifndef SEQUENCE_ID_MANAGER_HPP +#define SEQUENCE_ID_MANAGER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include "base_types.hpp" + +namespace skch { + +class SequenceIdManager { +private: + std::unordered_map sequenceNameToId; + std::vector metadata; + std::vector querySequenceNames; + std::vector targetSequenceNames; + std::vector allPrefixes; + std::string prefixDelim; + +public: + SequenceIdManager(const std::vector& queryFiles, + const std::vector& targetFiles, + const std::vector& queryPrefixes, + const std::vector& targetPrefixes, + const std::string& prefixDelim, + const std::string& queryList = "", + const std::string& targetList = "") + : prefixDelim(prefixDelim) { + allPrefixes = queryPrefixes; + allPrefixes.insert(allPrefixes.end(), targetPrefixes.begin(), targetPrefixes.end()); + populateFromFiles(queryFiles, targetFiles, queryPrefixes, targetPrefixes, prefixDelim, queryList, targetList); + buildRefGroups(); + dumpState(); // Add this line to dump the state after initialization + } + + // Add this method to dump the state of SequenceIdManager + void dumpState() const { + std::cerr << "SequenceIdManager State:" << std::endl; + std::cerr << "Total sequences: " << metadata.size() << std::endl; + std::cerr << "Query sequences: " << querySequenceNames.size() << std::endl; + std::cerr << "Target sequences: " << targetSequenceNames.size() << std::endl; + std::cerr << "\nSequence details:" << std::endl; + for (size_t i = 0; i < metadata.size(); ++i) { + std::cerr << "ID: " << i + << ", Name: " << metadata[i].name + << ", Length: " << metadata[i].len + << ", Group: " << metadata[i].groupId + << ", Type: " << (std::find(querySequenceNames.begin(), querySequenceNames.end(), metadata[i].name) != querySequenceNames.end() ? "Query" : "Target") + << std::endl; + } + std::cerr << "\nGroup details:" << std::endl; + std::unordered_map> groupToSequences; + for (const auto& info : metadata) { + groupToSequences[info.groupId].push_back(info.name); + } + for (const auto& [groupId, sequences] : groupToSequences) { + std::cerr << "Group " << groupId << ": " << sequences.size() << " sequences" << std::endl; + for (const auto& seq : sequences) { + std::cerr << " " << seq << std::endl; + } + } + } + + seqno_t getSequenceId(const std::string& sequenceName) const { + auto it = sequenceNameToId.find(sequenceName); + if (it != sequenceNameToId.end()) { + return it->second; + } + throw std::runtime_error("Sequence name not found: " + sequenceName); + } + + const ContigInfo& getContigInfo(seqno_t id) const { + if (id < static_cast(metadata.size())) { + return metadata[id]; + } + throw std::runtime_error("Invalid sequence ID: " + std::to_string(id)); + } + + const std::string& getSequenceName(seqno_t id) const { + return getContigInfo(id).name; + } + + const offset_t& getSequenceLength(seqno_t id) const { + return getContigInfo(id).len; + } + + size_t size() const { + return metadata.size(); + } + + const std::vector& getMetadata() const { + return metadata; + } + + const std::vector& getQuerySequenceNames() const { return querySequenceNames; } + const std::vector& getTargetSequenceNames() const { return targetSequenceNames; } + + int getRefGroup(seqno_t seqId) const { + if (seqId < metadata.size()) { + return metadata[seqId].groupId; + } + throw std::runtime_error("Invalid sequence ID: " + std::to_string(seqId)); + } + +private: + void buildRefGroups() { + std::vector> seqInfoWithIndex; + size_t totalSeqs = metadata.size(); + + for (size_t i = 0; i < totalSeqs; ++i) { + seqInfoWithIndex.emplace_back(metadata[i].name, i); + } + + std::sort(seqInfoWithIndex.begin(), seqInfoWithIndex.end()); + + int currentGroup = 0; + std::unordered_map groupMap; + + for (const auto& [seqName, originalIndex] : seqInfoWithIndex) { + std::string groupKey; + + if (!allPrefixes.empty()) { + // Check if the sequence matches any of the specified prefixes + auto it = std::find_if(allPrefixes.begin(), allPrefixes.end(), + [&seqName](const std::string& prefix) { return seqName.compare(0, prefix.length(), prefix) == 0; }); + + if (it != allPrefixes.end()) { + groupKey = *it; + } + } + + if (groupKey.empty() && !prefixDelim.empty()) { + // Use prefix before delimiter as group key + size_t pos = seqName.find(prefixDelim); + if (pos != std::string::npos) { + groupKey = seqName.substr(0, pos); + } + } + + if (groupKey.empty()) { + // If no group key found, use the sequence name itself + groupKey = seqName; + } + + if (groupMap.find(groupKey) == groupMap.end()) { + groupMap[groupKey] = ++currentGroup; + } + metadata[originalIndex].groupId = groupMap[groupKey]; + } + + if (totalSeqs == 0) { + std::cerr << "[SequenceIdManager::buildRefGroups] ERROR: No sequences indexed!" << std::endl; + exit(1); + } + } + + std::string getPrefix(const std::string& s) const { + if (!prefixDelim.empty()) { + size_t pos = s.find(prefixDelim); + return (pos != std::string::npos) ? s.substr(0, pos) : s; + } + return s; + } + + void populateFromFiles(const std::vector& queryFiles, + const std::vector& targetFiles, + const std::vector& queryPrefixes, + const std::vector& targetPrefixes, + const std::string& prefixDelim, + const std::string& queryList, + const std::string& targetList) { + std::unordered_set allowedQueryNames; + std::unordered_set allowedTargetNames; + + if (!queryList.empty()) readAllowedNames(queryList, allowedQueryNames); + if (!targetList.empty()) readAllowedNames(targetList, allowedTargetNames); + + for (const auto& file : queryFiles) { + readFAI(file, queryPrefixes, prefixDelim, allowedQueryNames, true); + } + for (const auto& file : targetFiles) { + readFAI(file, targetPrefixes, prefixDelim, allowedTargetNames, false); + } + } + + void readAllowedNames(const std::string& listFile, std::unordered_set& allowedNames) { + std::ifstream file(listFile); + std::string name; + while (std::getline(file, name)) { + allowedNames.insert(name); + } + } + + void readFAI(const std::string& fileName, + const std::vector& prefixes, + const std::string& prefixDelim, + const std::unordered_set& allowedNames, + bool isQuery) { + std::string faiName = fileName + ".fai"; + std::ifstream faiFile(faiName); + if (!faiFile.is_open()) { + std::cerr << "Error: Unable to open FAI file: " << faiName << std::endl; + exit(1); + } + + std::string line; + while (std::getline(faiFile, line)) { + std::istringstream iss(line); + std::string seqName; + offset_t seqLength; + iss >> seqName >> seqLength; + + bool prefixMatch = prefixes.empty() || std::any_of(prefixes.begin(), prefixes.end(), + [&](const std::string& prefix) { return seqName.compare(0, prefix.size(), prefix) == 0; }); + + if (prefixMatch && (allowedNames.empty() || allowedNames.find(seqName) != allowedNames.end())) { + seqno_t seqId = addSequence(seqName, seqLength); + if (isQuery) { + querySequenceNames.push_back(seqName); + } else { + targetSequenceNames.push_back(seqName); + } + } + } + } + + seqno_t addSequence(const std::string& sequenceName, offset_t length) { + auto it = sequenceNameToId.find(sequenceName); + if (it != sequenceNameToId.end()) { + return it->second; + } + seqno_t newId = metadata.size(); + sequenceNameToId[sequenceName] = newId; + metadata.push_back(ContigInfo{sequenceName, length}); + return newId; + } +}; + +} // namespace skch + +#endif // SEQUENCE_ID_MANAGER_HPP diff --git a/src/map/include/winSketch.hpp b/src/map/include/winSketch.hpp index 1777d58c..3dd0302f 100644 --- a/src/map/include/winSketch.hpp +++ b/src/map/include/winSketch.hpp @@ -37,6 +37,12 @@ namespace fs = std::filesystem; #include "common/ankerl/unordered_dense.hpp" #include "common/seqiter.hpp" +#include "common/atomic_queue/atomic_queue.h" +#include "sequenceIds.hpp" +#include "common/atomic_queue/atomic_queue.h" +#include "common/progress.hpp" +#include +#include //#include "assert.hpp" @@ -59,7 +65,7 @@ namespace skch //private members //algorithm parameters - const skch::Parameters ¶m; + skch::Parameters param; //Minmers that occur this or more times will be ignored (computed based on percentageThreshold) uint64_t freqThreshold = std::numeric_limits::max(); @@ -67,11 +73,15 @@ namespace skch //Set of frequent seeds to be ignored ankerl::unordered_dense::set frequentSeeds; - //Make the default constructor private, non-accessible - Sketch(); + //Make the default constructor protected, non-accessible + protected: + Sketch(SequenceIdManager& idMgr) : idManager(idMgr) {} public: + //Flag to indicate if the Sketch is fully initialized + bool isInitialized = false; + using MI_Type = std::vector< MinmerInfo >; using MIIter_t = MI_Type::const_iterator; using HF_Map_t = ankerl::unordered_dense::map; @@ -79,17 +89,8 @@ namespace skch // Frequency of each hash HF_Map_t hashFreq; - //Keep sequence length, name that appear in the sequence (for printing the mappings later) - std::vector< ContigInfo > metadata; - - /* - * Keep the information of what sequences come from what file# - * Example [a, b, c] implies - * file 0 contains 0 .. a-1 sequences - * file 1 contains a .. b-1 - * file 2 contains b .. c-1 - */ - std::vector< int > sequencesByFileInfo; + public: + uint64_t total_seq_length = 0; //Index for fast seed lookup (unordered_map) /* @@ -105,6 +106,12 @@ namespace skch MI_Map_t minmerPosLookupIndex; MI_Type minmerIndex; + // Atomic queues for input and output + using input_queue_t = atomic_queue::AtomicQueue; + using output_queue_t = atomic_queue::AtomicQueue*, 1024>; + input_queue_t input_queue; + output_queue_t output_queue; + private: /** @@ -117,42 +124,158 @@ namespace skch //[... ,x -> y, ...] implies y number of minmers occur x times std::map minmerFreqHistogram; + // Instance of the SequenceIdManager + SequenceIdManager& idManager; + public: /** * @brief constructor * also builds, indexes the minmer table */ - Sketch(const skch::Parameters &p) - : - param(p) { - if (param.indexFilename.empty() - || !stdfs::exists(param.indexFilename) - || param.overwrite_index) - { - this->build(true); - this->computeFreqHist(); - this->computeFreqSeedSet(); - this->dropFreqSeedSet(); - this->hashFreq.clear(); - if (!param.indexFilename.empty()) - { - this->writeIndex(); + Sketch(skch::Parameters p, + SequenceIdManager& idMgr, + const std::vector& targets = {}, + const std::string& indexFilename = "") + : param(std::move(p)), + idManager(idMgr) + { + if (!indexFilename.empty()) { + loadIndex(indexFilename); + } else { + initialize(targets); + } + } + + void loadIndex(const std::string& indexFilename) { + std::ifstream inStream(indexFilename, std::ios::binary); + if (!inStream) { + std::cerr << "Error: Unable to open index file: " << indexFilename << std::endl; + exit(1); + } + readParameters(inStream); + readSketchBinary(inStream); + readPosListBinary(inStream); + readFreqKmersBinary(inStream); + inStream.close(); + isInitialized = true; + std::cerr << "[mashmap::skch::Sketch] Sketch loaded from index file: " << indexFilename << std::endl; + } + + public: + void initialize(const std::vector& targets = {}) { + std::cerr << "[mashmap::skch::Sketch] Initializing Sketch..." << std::endl; + + // Calculate total sequence length + /* + for (const auto& fileName : param.refSequences) { + std::cerr << "targets are " << targets.size() << " "; + for (const auto& target : targets) { + std::cerr << target << " "; + } + std::cerr << std::endl; + seqiter::for_each_seq_in_file( + fileName, + targets, + [&](const std::string& seq_name, const std::string& seq) { + total_seq_length += seq.length(); + }); + } + */ + + this->build(true, targets); + this->computeFreqHist(); + this->computeFreqSeedSet(); + this->dropFreqSeedSet(); + this->hashFreq.clear(); + if (!param.indexFilename.empty()) + { + this->writeIndex(); + } + + std::cerr << "[mashmap::skch::Sketch] Unique minmer hashes after pruning = " << (minmerPosLookupIndex.size() - this->frequentSeeds.size()) << std::endl; + std::cerr << "[mashmap::skch::Sketch] Total minmer windows after pruning = " << minmerIndex.size() << std::endl; + std::cerr << "[mashmap::skch::Sketch] Number of sequences = " << idManager.size() << std::endl; + isInitialized = true; + std::cerr << "[mashmap::skch::Sketch] Sketch initialization complete." << std::endl; + } + + private: + void reader_thread(const std::vector& targets, std::atomic& reader_done) { + for (const auto& fileName : param.refSequences) { + seqiter::for_each_seq_in_file( + fileName, + targets, + [&](const std::string& seq_name, const std::string& seq) { + if (seq.length() >= param.segLength) { + seqno_t seqId = idManager.getSequenceId(seq_name); + auto record = new InputSeqContainer(seq, seq_name, seqId); + input_queue.push(record); + } + // We don't update progress here anymore + }); + } + reader_done.store(true); + } + + void worker_thread(std::atomic& reader_done, progress_meter::ProgressMeter& progress) { + while (true) { + InputSeqContainer* record = nullptr; + if (input_queue.try_pop(record)) { + auto minmers = new MI_Type(); + CommonFunc::addMinmers(*minmers, &(record->seq[0]), record->len, + param.kmerSize, param.segLength, param.alphabetSize, + param.sketchSize, record->seqId); + auto output_pair = new std::pair(record->len, minmers); + output_queue.push(output_pair); + delete record; + } else if (reader_done.load() && input_queue.was_empty()) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - if (param.create_index_only) - { - std::cerr << "[mashmap::skch::Sketch] Index created successfully. Exiting." << std::endl; - exit(0); + } + } + + void writer_thread(std::atomic& workers_done, progress_meter::ProgressMeter& progress) { + while (true) { + std::pair* output = nullptr; + if (output_queue.try_pop(output)) { + uint64_t seq_length = output->first; + MI_Type* minmers = output->second; + for (const auto& mi : *minmers) { + this->hashFreq[mi.hash]++; + if (minmerPosLookupIndex[mi.hash].size() == 0 + || minmerPosLookupIndex[mi.hash].back().hash != mi.hash + || minmerPosLookupIndex[mi.hash].back().pos != mi.wpos) + { + minmerPosLookupIndex[mi.hash].push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN}); + minmerPosLookupIndex[mi.hash].push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE}); + } else { + minmerPosLookupIndex[mi.hash].back().pos = mi.wpos_end; + } + } + //this->minmerIndex.insert(this->minmerIndex.end(), minmers->begin(), minmers->end()); + this->minmerIndex.insert( + this->minmerIndex.end(), + std::make_move_iterator(minmers->begin()), + std::make_move_iterator(minmers->end())); + + // Update progress meter + progress.increment(seq_length); + + delete output->second; + delete output; + } else if (workers_done.load() && output_queue.was_empty()) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - } else { - this->build(false); - this->readIndex(); - } - std::cerr << "[mashmap::skch::Sketch] Unique minmer hashes after pruning = " << (minmerPosLookupIndex.size() - this->frequentSeeds.size()) << std::endl; - std::cerr << "[mashmap::skch::Sketch] Total minmer windows after pruning = " << minmerIndex.size() << std::endl; } - private: + // Finalize progress meter + progress.finish(); + } /** * @brief Get sequence metadata and optionally build the sketch table @@ -160,84 +283,77 @@ namespace skch * @details Iterate through ref sequences to get metadata and * optionally compute and save minmers from the reference sequence(s) * assuming a fixed window size + * @param compute_seeds Whether to compute seeds or just collect metadata + * @param target_ids Set of target sequence IDs to sketch over */ - void build(bool compute_seeds) + void build(bool compute_seeds, const std::vector& target_names = {}) { + std::chrono::time_point t0 = skch::Time::now(); - // allowed set of targets - std::unordered_set allowed_target_names; - if (!param.target_list.empty()) { - std::ifstream filter_list(param.target_list); - std::string name; - while (getline(filter_list, name)) { - allowed_target_names.insert(name); - } - } - - - //sequence counter while parsing file - seqno_t seqCounter = 0; - - //Create the thread pool - ThreadPool threadPool( [this](InputSeqContainer* e) {return buildHelper(e);}, param.threads); - - for(const auto &fileName : param.refSequences) - { - -#ifdef DEBUG - std::cerr << "[mashmap::skch::Sketch::build] building minmer index for " << fileName << std::endl; -#endif - - seqiter::for_each_seq_in_file( - fileName, - allowed_target_names, - param.target_prefix, - [&](const std::string& seq_name, - const std::string& seq) { - // todo: offset_t is an 32-bit integer, which could cause problems - offset_t len = seq.length(); - - //Save the sequence name - metadata.push_back( ContigInfo{seq_name, len} ); - - //Is the sequence too short? - if(len < param.kmerSize) - { -#ifdef DEBUG - std::cerr << "WARNING, skch::Sketch::build, found an unusually short sequence relative to kmer" << std::endl; -#endif - } - else - { - if (compute_seeds) { - threadPool.runWhenThreadAvailable(new InputSeqContainer(seq, seq_name, seqCounter)); + if (compute_seeds) { + std::cerr << "creating seeds" << std::endl; + + //Create the thread pool + ThreadPool threadPool([this](InputSeqContainer* e) { return buildHelper(e); }, param.threads); + + size_t totalSeqProcessed = 0; + size_t totalSeqSkipped = 0; + size_t shortestSeqLength = std::numeric_limits::max(); + for (const auto& fileName : param.refSequences) { + std::cerr << "[mashmap::skch::Sketch::build] Processing file: " << fileName << std::endl; + + seqiter::for_each_seq_in_file( + fileName, + target_names, + [&](const std::string& seq_name, const std::string& seq) { + std::cerr << "on sequence " << seq_name << std::endl; + if (seq.length() >= param.segLength) { + seqno_t seqId = idManager.getSequenceId(seq_name); + threadPool.runWhenThreadAvailable(new InputSeqContainer(seq, seq_name, seqId)); + totalSeqProcessed++; + shortestSeqLength = std::min(shortestSeqLength, seq.length()); + std::cerr << "DEBUG: Processing sequence: " << seq_name << " (length: " << seq.length() << ")" << std::endl; + + //Collect output if available + while (threadPool.outputAvailable()) { + this->buildHandleThreadOutput(threadPool.popOutputWhenAvailable()); + } - //Collect output if available - while ( threadPool.outputAvailable() ) - this->buildHandleThreadOutput(threadPool.popOutputWhenAvailable()); + // Update metadata + // Metadata is now handled by idManager, no need to push_back here + } else { + totalSeqSkipped++; + std::cerr << "WARNING, skch::Sketch::build, skipping short sequence: " << seq_name << " (length: " << seq.length() << ")" << std::endl; } - } - seqCounter++; - }); - - sequencesByFileInfo.push_back(seqCounter); - } - - if (seqCounter == 0) - { - std::cerr << "[mashmap::skch::Sketch::build] ERROR: No sequences indexed!" << std::endl; - exit(1); - } + }); + } + + // Update sequencesByFileInfo + // Removed as sequencesByFileInfo is no longer used + std::cerr << "[mashmap::skch::Sketch::build] Shortest sequence length: " << shortestSeqLength << std::endl; - if (compute_seeds) { //Collect remaining output objects - while ( threadPool.running() ) + while (threadPool.running()) this->buildHandleThreadOutput(threadPool.popOutputWhenAvailable()); + + std::cerr << "[mashmap::skch::Sketch::build] Total sequences processed: " << totalSeqProcessed << std::endl; + std::cerr << "[mashmap::skch::Sketch::build] Total sequences skipped: " << totalSeqSkipped << std::endl; std::cerr << "[mashmap::skch::Sketch::build] Unique minmer hashes before pruning = " << minmerPosLookupIndex.size() << std::endl; std::cerr << "[mashmap::skch::Sketch::build] Total minmer windows before pruning = " << minmerIndex.size() << std::endl; } + + std::chrono::duration timeRefSketch = skch::Time::now() - t0; + std::cerr << "[mashmap::skch::Sketch::build] time spent computing the reference index: " << timeRefSketch.count() << " sec" << std::endl; + + if (this->minmerIndex.size() == 0) + { + std::cerr << "[mashmap::skch::Sketch::build] ERROR, reference sketch is empty. Reference sequences shorter than the kmer size are not indexed" << std::endl; + exit(1); + } } + public: + /** * @brief function to compute minmers given input sequence object * @details this function is run in parallel by multiple threads @@ -257,7 +373,7 @@ namespace skch param.segLength, param.alphabetSize, param.sketchSize, - input->seqCounter); + input->seqId); return thread_output; } @@ -366,9 +482,9 @@ namespace skch /** * @brief Write all index data structures to disk */ - void writeIndex() + void writeIndex(const std::string& filename = "") { - fs::path freqListFilename = fs::path(param.indexFilename); + fs::path freqListFilename = filename.empty() ? fs::path(param.indexFilename) : fs::path(filename); std::ofstream outStream; outStream.open(freqListFilename, std::ios::binary); @@ -594,6 +710,16 @@ namespace skch return frequentSeeds.find(h) != frequentSeeds.end(); } + void clear() + { + hashFreq.clear(); + minmerPosLookupIndex.clear(); + minmerIndex.clear(); + minmerFreqHistogram.clear(); + frequentSeeds.clear(); + freqThreshold = std::numeric_limits::max(); + } + }; //End of class Sketch } //End of namespace skch