From c8fe57d48eba5c0315302e97427ba3c3bf16efea Mon Sep 17 00:00:00 2001
From: Ethan Steinberg
Date: Sun, 11 Aug 2024 21:44:33 -0700
Subject: [PATCH] Sorting working

---
 build_helper.sh             |    2 +-
 native/binary_version.hh    |    2 +-
 native/create_database.cc   | 1488 ++++++++++++++++-------------------
 src/meds_reader/__init__.py |   27 +-
 tests/api_test.py           |    3 -
 5 files changed, 666 insertions(+), 856 deletions(-)

diff --git a/build_helper.sh b/build_helper.sh
index 89c051d..80b7ea1 100755
--- a/build_helper.sh
+++ b/build_helper.sh
@@ -1,7 +1,7 @@
 #!/bin/sh
 
 cd native
-bazel build -c opt _meds_reader.so meds_reader_convert
+bazel build -c dbg _meds_reader.so meds_reader_convert
 cd ..
 
 rm -f src/meds_reader/_meds_reader* src/meds_reader/meds_reader_convert*

diff --git a/native/binary_version.hh b/native/binary_version.hh
index 56a3392..f4d9fcc 100644
--- a/native/binary_version.hh
+++ b/native/binary_version.hh
@@ -1,3 +1,3 @@
 #pragma once
 
-const int CURRENT_BINARY_VERSION = 1;
\ No newline at end of file
+const int CURRENT_BINARY_VERSION = 2;
\ No newline at end of file

diff --git a/native/create_database.cc b/native/create_database.cc
index 55753ff..7d8f30d 100644
--- a/native/create_database.cc
+++ b/native/create_database.cc
@@ -8,9 +8,11 @@
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -212,6 +214,214 @@ class ZstdRowReader {
     size_t uncompressed_size;
 };
 
+template <typename T>
+struct CappedQueue {
+    CappedQueue(int num_threads)
+        : queues(num_threads), semaphore(QUEUE_SIZE * num_threads) {}
+
+    std::vector<moodycamel::BlockingConcurrentQueue<std::optional<T>>> queues;
+    moodycamel::LightweightSemaphore semaphore;
+};
+
+template <typename T>
+struct CappedQueueSender {
+    CappedQueueSender(CappedQueue<T>& q)
+        : queue(q), num_threads(q.queues.size()) {
+        slots_to_write = queue.semaphore.waitMany(SEMAPHORE_BLOCK_SIZE);
+    }
+
+    void send_item(int target_thread_id, T&& item) {
+        if (slots_to_write == 0) {
+            slots_to_write = queue.semaphore.waitMany(SEMAPHORE_BLOCK_SIZE);
+        }
+        slots_to_write--;
+        queue.queues[target_thread_id].enqueue({std::move(item)});
+    }
+
+    ~CappedQueueSender() {
+        for (auto& queue : queue.queues) {
+            queue.enqueue(std::nullopt);
+        }
+        queue.semaphore.signal(slots_to_write);
+    }
+
+    CappedQueue<T>& queue;
+    ssize_t slots_to_write;
+    int num_threads;
+};
+
+template <typename T>
+struct CappedQueueReceiver {
+    CappedQueueReceiver(CappedQueue<T>& q, int tid)
+        : queue(q),
+          thread_id(tid),
+          num_senders_remaining(queue.queues.size()),
+          num_read(0),
+          c_tok(q.queues[tid]),
+          num_threads(q.queues.size()) {}
+
+    bool get_item(T& item) {
+        if (num_read == SEMAPHORE_BLOCK_SIZE) {
+            queue.semaphore.signal(num_read);
+            num_read = 0;
+        }
+        std::optional<T> entry;
+
+        while (!entry) {
+            queue.queues[thread_id].wait_dequeue(c_tok, entry);
+            if (!entry) {
+                num_senders_remaining--;
+                if (num_senders_remaining == 0) {
+                    return false;
+                }
+            } else {
+                item = std::move(*entry);
+                num_read++;
+                return true;
+            }
+        }
+
+        abort();
+    }
+
+    ~CappedQueueReceiver() { queue.semaphore.signal(num_read); }
+
+    CappedQueue<T>& queue;
+    int thread_id;
+    int num_senders_remaining;
+    int num_read;
+    moodycamel::ConsumerToken c_tok;
+    int num_threads;
+};
+
+struct SharedFile {
+    SharedFile(const std::filesystem::path& path, int nt)
+        : num_threads(nt),
+          cvs(num_threads),
+          file(path, std::ios_base::out | std::ios_base::binary |
+                         std::ios_base::trunc),
+          next_shard(0) {}
+
+    template <typename F>
+    void run_with_file(size_t requested_shard, F func) {
+        size_t thread_index = requested_shard % num_threads;
+        std::unique_lock<std::mutex> lock(mutex);
+        while (next_shard != requested_shard) {
+            cvs[thread_index].wait(lock);
+        }
+
+        func(file);
+        next_shard++;
+        cvs[next_shard % num_threads].notify_one();
+    }
+
+    int num_threads;
+    std::mutex mutex;
+    std::vector<std::condition_variable> cvs;
+
+    std::ofstream file;
+    size_t next_shard;
+};
+
+void sort_concatenate_shards(int i, const std::filesystem::path& root_path,
+                             SharedFile& data_file, int num_patients_per_shard,
+                             int num_shards_per_thread) {
+    for (int j = 0; j < num_shards_per_thread; j++) {
+        int shard = i + j * data_file.num_threads;
+
+        std::filesystem::path shard_path =
+            root_path / (std::to_string(shard) + ".dat");
+
+        if (!std::filesystem::exists(shard_path)) {
+            throw std::runtime_error("Missing shard? " +
+                                     std::string(shard_path));
+        }
+
+        {
+            MmapFile shard_file(shard_path);
+
+            std::vector<std::pair<uint32_t, std::string_view>> entries;
+
+            const char* pointer = shard_file.bytes().begin();
+
+            while (pointer != shard_file.bytes().end()) {
+                const uint32_t* header = (const uint32_t*)pointer;
+                uint32_t offset = header[0];
+                uint32_t size = header[1];
+                pointer += sizeof(uint32_t) * 2;
+                entries.emplace_back(offset, std::string_view(pointer, size));
+                pointer += size;
+            }
+
+            pdqsort_branchless(std::begin(entries), std::end(entries));
+
+            data_file.run_with_file(shard, [&](std::ofstream& file) {
+                if (entries.size() > 0) {
+                    uint64_t offset = (uint64_t)file.tellp();
+                    std::vector<uint64_t> offsets;
+                    offsets.reserve(entries.size());
+
+                    for (const auto& entry : entries) {
+                        offset += entry.second.size();
+                        offsets.emplace_back(offset);
+
+                        file.write(entry.second.data(), entry.second.size());
+                    }
+
+                    file.seekp((entries[0].first + 1) * sizeof(uint64_t));
+                    file.write((const char*)offsets.data(),
+                               offsets.size() * sizeof(uint64_t));
+                    file.seekp(offset);
+                }
+            });
+        }
+
+        std::filesystem::remove(shard_path);
+    }
+}
+
+size_t get_num_shards(int num_threads, size_t estimated_size) {
+    int max_size_per_shard = 2000000000;
+    int num_shards =
+        (estimated_size + max_size_per_shard - 1) / max_size_per_shard;
+
+    if (num_shards < num_threads) {
+        num_shards = num_threads;
+    }
+
+    return num_shards;
+}
+
+void write_files(
+    int thread_index, const std::filesystem::path& root_path,
+    int num_patients_per_shard, int shards_per_thread,
+    CappedQueueReceiver<std::pair<uint32_t, std::vector<char>>>& receiver) {
+    std::pair<uint32_t, std::vector<char>> entry;
+
+    std::vector<std::ofstream> shard_files;
+
+    for (int i = 0; i < shards_per_thread; i++) {
+        int shard_index = i + thread_index * shards_per_thread;
+        shard_files.emplace_back(root_path /
+                                 (std::to_string(shard_index) + ".dat"));
+    }
+
+    while (true) {
+        if (!receiver.get_item(entry)) {
+            return;
+        }
+
+        int shard = entry.first / num_patients_per_shard;
+        int shard_offset = shard % shards_per_thread;
+
+        uint32_t header[2] = {entry.first, (uint32_t)entry.second.size()};
+        shard_files[shard_offset].write((const char*)&header, sizeof(header));
+        shard_files[shard_offset].write(entry.second.data(),
+                                        entry.second.size());
+    }
+}
+
 std::map<std::string, std::pair<std::shared_ptr<arrow::DataType>, int64_t>>
 get_properties(const parquet::arrow::SchemaManifest& manifest) {
     std::map<std::string, std::pair<std::shared_ptr<arrow::DataType>, int64_t>>
@@ -333,7 +543,7 @@ std::set<std::string> known_properties = {"code", "numeric_value"};
 template <typename T, typename F>
 void iterate_strings_helper(
     const std::filesystem::path& filename, const std::string& property_name,
-    const std::vector<uint32_t>& patient_lengths,
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
     const absl::flat_hash_map<std::string, size_t>& dictionary_entries,
     F func) {
     arrow::MemoryPool* pool = arrow::default_memory_pool();
@@ -364,6 +574,7 @@ void iterate_strings_helper(
 
     size_t next_patient_index = 0;
     size_t remaining_events = 0;
+    size_t current_position = 0;
     bool has_event = false;
 
     absl::flat_hash_map<std::string_view, uint32_t>
         per_patient_values;
@@ -410,7 +621,7 @@ void iterate_strings_helper(
         null_bytes.insert(std::end(null_bytes), std::begin(helper),
                           std::end(helper));
 
-        func(null_bytes);
+        func(current_position, null_bytes);
     };
 
     auto write_null = [&]() {
@@ -452,7 +663,9 @@ void iterate_strings_helper(
                 has_event = true;
             }
 
-            remaining_events = patient_lengths[next_patient_index++];
+            auto next = patient_positions[next_patient_index++];
+            current_position = next.first;
+            remaining_events = next.second;
 
             per_patient_values.clear();
             null_bytes.clear();
@@ -517,19 +730,19 @@ void iterate_strings(
     const std::shared_ptr<arrow::DataType>& type,
-    const std::vector<uint32_t>& patient_lengths,
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
     const absl::flat_hash_map<std::string, size_t>& dictionary_entries,
     F func) {
     switch (type->id()) {
         case arrow::Type::STRING:
             iterate_strings_helper<arrow::StringArray>(
-                filename, property_name, patient_lengths, dictionary_entries,
+                filename, property_name, patient_positions, dictionary_entries,
                 func);
             break;
 
         case arrow::Type::LARGE_STRING:
             iterate_strings_helper<arrow::LargeStringArray>(
-                filename, property_name, patient_lengths, dictionary_entries,
+                filename, property_name, patient_positions, dictionary_entries,
                 func);
             break;
 
@@ -538,101 +751,11 @@ void iterate_strings(
     };
 }
 
-std::vector<std::vector<char>> get_samples(
-    std::filesystem::path filename, std::string property_name,
-    const std::shared_ptr<arrow::DataType>& type,
-    const std::vector<uint32_t>& patient_lengths,
-    const absl::flat_hash_map<std::string, size_t>& dictionary_entries,
-    size_t num_samples) {
-    size_t sample_count = 0;
-    std::vector<std::vector<char>> samples;
-
-    iterate_strings(filename, property_name, type, patient_lengths,
-                    dictionary_entries, [&](std::vector<char> bytes) {
-                        sample_count++;
-
-                        if (samples.size() < num_samples) {
-                            samples.emplace_back(std::move(bytes));
-                        } else {
-                            size_t j = (size_t)(rand() % sample_count);
-                            if (j < num_samples) {
-                                samples[j] = std::move(bytes);
-                            }
-                        }
-                    });
-
-    return samples;
-}
-
-std::pair<uint64_t, std::vector<uint64_t>> write_files(
-    std::filesystem::path filename, std::string property_name,
-    const std::shared_ptr<arrow::DataType>& type,
-    const std::vector<uint32_t>& patient_lengths,
-    const absl::flat_hash_map<std::string, size_t>& dictionary_entries,
-    const std::vector<char>& dictionary,
-    std::filesystem::path target_filename) {
-    std::vector<uint64_t> offsets;
-
-    std::ofstream output_file(target_filename, std::ios_base::out |
-                                                   std::ios_base::binary |
-                                                   std::ios_base::trunc);
-
-    uint64_t num_bytes = 0;
-
-    auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); };
-
-    std::unique_ptr<ZSTD_CCtx, decltype(context_deleter)> context2{
-        ZSTD_createCCtx(), context_deleter};
-
-    size_t res =
-        ZSTD_CCtx_setParameter(context2.get(), ZSTD_c_compressionLevel, 22);
-
-    if (ZSTD_isError(res)) {
-        throw std::runtime_error("Could not set the compression level");
-    }
-
-    res = ZSTD_CCtx_loadDictionary(context2.get(), dictionary.data(),
-                                   dictionary.size());
-    if (ZSTD_isError(res)) {
-        throw std::runtime_error("Could not load the dictionary");
-    }
-
-    iterate_strings(
-        filename, property_name, type, patient_lengths, dictionary_entries,
-        [&](std::vector<char> bytes) {
-            std::vector<char> final_bytes(sizeof(uint32_t) +
-                                          ZSTD_compressBound(bytes.size()));
-            res = ZSTD_compress2(
-                context2.get(), final_bytes.data() + sizeof(uint32_t),
-                final_bytes.size(), bytes.data(), bytes.size());
-
-            if (ZSTD_isError(res)) {
-                throw std::runtime_error("Unable to compress");
-            }
-
-            final_bytes.resize(res + sizeof(uint32_t));
-
-            uint32_t* length_pointer = (uint32_t*)final_bytes.data();
-            *length_pointer = bytes.size();
-
-            offsets.push_back(num_bytes);
-
-            output_file.write(final_bytes.data(), final_bytes.size());
-
-            num_bytes += final_bytes.size();
-        });
-
-    return std::make_pair(num_bytes, std::move(offsets));
-}
-
 template <typename T, typename F>
 void string_reader_thread_helper(
     const std::filesystem::path& filename, const std::string& property_name,
-    const std::vector<uint32_t>& patient_lengths,
-    std::vector<
-        moodycamel::BlockingConcurrentQueue<std::pair<std::string, uint64_t>>>&
-        all_write_queues,
-    moodycamel::LightweightSemaphore& queue_semaphore) {
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
+    CappedQueueSender<std::pair<std::string, uint64_t>>& sender) {
     arrow::MemoryPool* pool = arrow::default_memory_pool();
 
     // Configure general Parquet reader settings
@@ -658,8 +781,6 @@ void string_reader_thread_helper(
     int64_t column = properties.find(property_name)->second.second;
     std::vector<int> columns = {(int)column};
 
-    ssize_t slots_to_write = queue_semaphore.waitMany(SEMAPHORE_BLOCK_SIZE);
-
     absl::flat_hash_map<std::string, uint64_t> items;
 
     size_t next_patient_index = 0;
@@ -669,16 +790,11 @@ void string_reader_thread_helper(
     auto flush_patient = [&]() {
         for (auto& item : items) {
             size_t h = std::hash<std::string>{}(item.first);
-            size_t partition = h % all_write_queues.size();
+            size_t partition = h % sender.num_threads;
 
             item.second += (((uint64_t)1) << 32);
 
-            if (slots_to_write == 0) {
-                slots_to_write = queue_semaphore.waitMany(SEMAPHORE_BLOCK_SIZE);
-            }
-            slots_to_write--;
-
-            all_write_queues[partition].enqueue(std::move(item));
+            sender.send_item(partition, std::move(item));
         }
     };
 
@@ -691,7 +807,7 @@ void string_reader_thread_helper(
             }
             items.clear();
 
-            remaining_events = patient_lengths[next_patient_index++];
+            remaining_events = patient_positions[next_patient_index++].second;
         }
 
         if (!item.empty()) {
@@ -728,29 +844,22 @@ void string_reader_thread_helper(
     if (items.size() != 0) {
         flush_patient();
     }
-
-    queue_semaphore.signal(slots_to_write);
 }
 
 void string_reader_thread(
     const std::filesystem::path& filename, const std::string& property_name,
     const std::shared_ptr<arrow::DataType>& type,
-    const std::vector<uint32_t>& patient_lengths,
-    std::vector<
-        moodycamel::BlockingConcurrentQueue<std::pair<std::string, uint64_t>>>&
-        all_write_queues,
-    moodycamel::LightweightSemaphore& queue_semaphore) {
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
+    CappedQueueSender<std::pair<std::string, uint64_t>>& sender) {
     switch (type->id()) {
         case arrow::Type::STRING:
             string_reader_thread_helper<arrow::StringArray>(
-                filename, property_name, patient_lengths, all_write_queues,
-                queue_semaphore);
+                filename, property_name, patient_positions, sender);
             break;
 
         case arrow::Type::LARGE_STRING:
             string_reader_thread_helper<arrow::LargeStringArray>(
-                filename, property_name, patient_lengths, all_write_queues,
-                queue_semaphore);
+                filename, property_name, patient_positions, sender);
             break;
 
         default:
@@ -760,11 +869,7 @@ void string_reader_thread(
 
 void string_writer_thread(
     std::filesystem::path folder_to_write_to,
-    moodycamel::BlockingConcurrentQueue<std::pair<std::string, uint64_t>>&
-        queue,
-    moodycamel::LightweightSemaphore& queue_semaphore, int num_threads) {
-    moodycamel::ConsumerToken c_tok(queue);
-
+    CappedQueueReceiver<std::pair<std::string, uint64_t>>& receiver) {
     int next_index = 0;
 
     auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); };
@@ -798,32 +903,11 @@ void string_writer_thread(
         bytes_written = 0;
     };
 
-    ssize_t num_read = 0;
-
-    int num_remaining_threads = num_threads;
-
-    std::pair<std::string, uint64_t> entry;
+    std::pair<std::string, uint64_t> entry;
 
     while (true) {
-        queue.wait_dequeue(c_tok, entry);
-
-        if (entry.first.empty()) {
-            num_remaining_threads--;
-            if (num_remaining_threads == 0) {
-                if (items.size() != 0) {
-                    flush();
-                }
-                queue_semaphore.signal(num_read);
-                return;
-            } else {
-                continue;
-            }
-        }
-
-        num_read++;
-
-        if (num_read == SEMAPHORE_BLOCK_SIZE) {
-            queue_semaphore.signal(SEMAPHORE_BLOCK_SIZE);
-            num_read = 0;
+        bool got_next = receiver.get_item(entry);
+        if (!got_next) {
+            break;
         }
 
         size_t item_size = entry.first.size();
@@ -918,9 +1002,9 @@ std::vector<std::pair<uint64_t, std::string>> merger_thread(
     return entries;
 }
 
-template <typename F, typename H, typename G>
+template <typename F>
 void run_all(const std::vector<WorkEntry>& work_entries, int num_threads,
-             F func, H end, G other) {
+             F func) {
     std::vector<std::thread> threads;
     moodycamel::BlockingConcurrentQueue<std::optional<WorkEntry>> work_queue;
 
@@ -931,14 +1015,12 @@ void run_all(const std::vector<WorkEntry>& work_entries, int num_threads,
     for (int i = 0; i < num_threads; i++) {
         work_queue.enqueue(std::nullopt);
 
-        threads.emplace_back([&work_queue, &func, &end]() {
+        threads.emplace_back([&work_queue, &func]() {
            std::optional<WorkEntry> next_entry;
             while (true) {
                 work_queue.wait_dequeue(next_entry);
 
                 if (!next_entry) {
-                    end();
-
                     return;
                 }
 
@@ -947,154 +1029,165 @@ void run_all(const std::vector<WorkEntry>& work_entries, int num_threads,
         });
     }
 
-    other(threads);
-
     for (auto& thread : threads) {
         thread.join();
     }
 }
 
-template <typename F>
-void run_all(const std::vector<WorkEntry>& work_entries, int num_threads,
-             F func) {
-    run_all(
-        work_entries, num_threads, func, []() {},
-        [](std::vector<std::thread>& threads) {});
-}
-
-void process_string_property(
-    const std::string& property_name,
-
-    const std::shared_ptr<arrow::DataType>& type,
-    const std::vector<std::vector<uint32_t>>& patient_lengths,
-    std::filesystem::path temp_path, const std::vector<WorkEntry>& work_entries,
-    int num_threads) {
-    std::filesystem::path string_path = temp_path / property_name;
-    std::filesystem::create_directories(string_path);
+template <typename T, typename R, typename W>
+void run_reader_writer(const std::vector<WorkEntry>& work_entries,
+                       int num_threads, R reader, W writer) {
+    CappedQueue<T> queue(num_threads);
 
-    // We have to start by reading everything in
+    std::vector<std::thread> threads;
+    moodycamel::BlockingConcurrentQueue<std::optional<WorkEntry>> work_queue;
 
-    std::vector<
-        moodycamel::BlockingConcurrentQueue<std::pair<std::string, uint64_t>>>
-        queues(num_threads);
+    for (const auto& entry : work_entries) {
+        work_queue.enqueue({entry});
+    }
 
-    moodycamel::LightweightSemaphore write_semaphore(QUEUE_SIZE * num_threads);
+    for (int i = 0; i < num_threads; i++) {
+        work_queue.enqueue(std::nullopt);
 
-    auto reader = [&property_name, &type, &patient_lengths, &queues,
-                   &write_semaphore](const std::filesystem::path fname,
-                                     size_t index) {
-        string_reader_thread(fname, property_name, type, patient_lengths[index],
-                             queues, write_semaphore);
-    };
+        threads.emplace_back([&work_queue, &reader, &queue]() {
+            std::optional<WorkEntry> next_entry;
+            CappedQueueSender<T> sender(queue);
+            while (true) {
+                work_queue.wait_dequeue(next_entry);
 
-    auto writer = [num_threads, &string_path, &queues,
-                   &write_semaphore](std::vector<std::thread>& threads) {
-        for (int i = 0; i < num_threads; i++) {
-            threads.emplace_back(
-                [i, &string_path, &queues, &write_semaphore, num_threads]() {
-                    std::filesystem::path writer_path =
-                        string_path / std::to_string(i);
-                    std::filesystem::create_directories(writer_path);
-                    string_writer_thread(writer_path, queues[i],
-                                         write_semaphore, num_threads);
-                });
-        }
-    };
+                if (!next_entry) {
+                    return;
+                }
 
-    auto end_reader = [&queues]() {
-        for (auto& queue : queues) {
-            queue.enqueue(std::make_pair("", 0));
-        }
-    };
-    run_all(work_entries, num_threads, reader, end_reader, writer);
+                std::apply(reader,
+                           std::tuple_cat(std::move(*next_entry),
+                                          std::forward_as_tuple(sender)));
+            }
+        });
 
-    std::vector<std::vector<std::pair<uint64_t, std::string>>> all_entries(
-        num_threads);
+        threads.emplace_back([i, &writer, &queue]() {
+            CappedQueueReceiver<T> receiver(queue, i);
+            writer(i, receiver);
+        });
+    }
 
-    {
-        std::vector<std::thread> threads;
+    for (auto& thread : threads) {
+        thread.join();
+    }
+}
 
-        for (int i = 0; i < num_threads; i++) {
-            threads.emplace_back([i, &string_path, &all_entries]() {
-                std::filesystem::path writer_path =
-                    string_path / std::to_string(i);
-                all_entries[i] = merger_thread(writer_path);
-                std::filesystem::remove_all(writer_path);
-            });
-        }
+template <typename F>
+std::pair<std::vector<std::vector<char>>, size_t> get_samples(
+    std::filesystem::path filename,
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
+    size_t num_samples, F func) {
+    size_t sample_count = 0;
+    std::vector<std::vector<char>> samples;
 
-        for (auto& thread : threads) {
-            thread.join();
-        }
-    }
+    size_t estimated_size = 0;
 
-    absl::flat_hash_map<std::string, std::pair<size_t, size_t>> found;
+    func(filename, patient_positions,
+         [&](uint32_t position, std::vector<char> bytes) {
+             sample_count++;
+             estimated_size += bytes.size();
 
-    std::vector<std::pair<uint64_t, std::string>> entries;
-    size_t e_index = 0;
-    for (auto& e : all_entries) {
-        size_t ei_index = 0;
-        for (auto& ei : e) {
-#ifndef NDEBUG
-            if (found.count(ei.second) != 0) {
-                std::cout << "Got duplicate! " << ei.second << " " << ei.first
-                          << " " << e_index << " " << ei_index << " "
-                          << found[ei.second].first << " "
-                          << found[ei.second].second << std::endl;
-                abort();
-            }
-            found[ei.second].first = e_index;
-            found[ei.second].second = ei_index;
-#endif
+             if (samples.size() < num_samples) {
+                 samples.emplace_back(std::move(bytes));
+             } else {
+                 size_t j = (size_t)(rand() % sample_count);
+                 if (j < num_samples) {
+                     samples[j] = std::move(bytes);
+                 }
+             }
+         });
 
-            entries.emplace_back(std::move(ei));
-            ei_index++;
-        }
-        e_index++;
-    }
+    return {samples, estimated_size};
+}
 
+template <typename F>
+void read_files(
+    const std::filesystem::path& filename,
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
+    const std::vector<char>& dictionary, int num_patients_per_shard,
+    int num_shards_per_thread,
+    CappedQueueSender<std::pair<uint32_t, std::vector<char>>>& sender, F func) {
     auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); };
 
-    pdqsort(std::begin(entries), std::end(entries),
-            std::greater<std::pair<uint64_t, std::string>>());
-
-    absl::flat_hash_map<std::string, size_t> dictionary_entries;
+    std::unique_ptr<ZSTD_CCtx, decltype(context_deleter)> context2{
+        ZSTD_createCCtx(), context_deleter};
 
-    std::unique_ptr<ZSTD_CCtx, decltype(context_deleter)> context{
-        ZSTD_createCCtx(), context_deleter};
+    size_t res =
+        ZSTD_CCtx_setParameter(context2.get(), ZSTD_c_compressionLevel, 22);
 
-    {
-        ZstdRowWriter writer((string_path / "dictionary").string(),
-                             context.get());
+    if (ZSTD_isError(res)) {
+        throw std::runtime_error("Could not set the compression level");
+    }
 
-        for (size_t i = 0; i < entries.size(); i++) {
-            const auto& e = entries[i];
-            auto iter = dictionary_entries.try_emplace(std::move(e.second), i);
-            if (!iter.second) {
" + e.second); - } - writer.add_next(e.second, e.first); - } + res = ZSTD_CCtx_loadDictionary(context2.get(), dictionary.data(), + dictionary.size()); + if (ZSTD_isError(res)) { + throw std::runtime_error("Could not load the dictionary"); } + func(filename, patient_positions, + [&](uint32_t position, std::vector bytes) { + std::vector final_bytes(sizeof(uint32_t) + + ZSTD_compressBound(bytes.size())); + res = ZSTD_compress2( + context2.get(), final_bytes.data() + sizeof(uint32_t), + final_bytes.size(), bytes.data(), bytes.size()); + + if (ZSTD_isError(res)) { + throw std::runtime_error("Unable to compress"); + } + + final_bytes.resize(res + sizeof(uint32_t)); + + uint32_t* length_pointer = (uint32_t*)final_bytes.data(); + *length_pointer = bytes.size(); + + int shard = position / num_patients_per_shard; + + int thread = shard / num_shards_per_thread; + + sender.send_item(thread, {position, std::move(final_bytes)}); + }); +} + +template +void process_generic_property( + const std::filesystem::path& string_path, const std::string& property_name, + const std::vector>>& + patient_positions, + const std::vector& work_entries, int num_threads, + int num_patients, I iterate) { + // We have to start by reading everything in + std::vector>> all_samples( work_entries.size()); size_t num_samples_per_entry = (10 * 1000 + work_entries.size() - 1) / work_entries.size(); + std::atomic estimated_size; + run_all(work_entries, num_threads, - [&all_samples, &property_name, &type, &patient_lengths, - &dictionary_entries, - num_samples_per_entry](std::filesystem::path path, size_t index) { - all_samples[index] = get_samples( - path, property_name, type, patient_lengths[index], - dictionary_entries, num_samples_per_entry); + [&all_samples, &property_name, &patient_positions, &estimated_size, + num_samples_per_entry, + &iterate](std::filesystem::path path, size_t index) { + auto res = get_samples(path, patient_positions[index], + num_samples_per_entry, iterate); + all_samples[index] = std::move(res.first); + estimated_size.fetch_add(res.second); }); std::vector sample_sizes; std::vector sample_buffer; + int num_shards = get_num_shards(num_threads, estimated_size); + int num_patients_per_shard = (num_patients + num_shards - 1) / num_shards; + int num_shards_per_thread = (num_shards + num_threads - 1) / num_threads; + for (const auto& samples : all_samples) { for (const auto& sample : samples) { sample_sizes.push_back(sample.size()); @@ -1123,77 +1216,167 @@ void process_string_property( zdict.write(dictionary.data(), dictionary.size()); } - std::vector>> all_lengths( - work_entries.size()); - - run_all(work_entries, num_threads, - [&all_lengths, &string_path, &property_name, &type, - &patient_lengths, &dictionary_entries, - &dictionary](std::filesystem::path path, size_t index) { - std::filesystem::path target_path = - string_path / (std::to_string(index) + ".data"); - all_lengths[index] = write_files( - path, property_name, type, patient_lengths[index], - dictionary_entries, dictionary, target_path); - }); + auto reader = + [&string_path, &property_name, &patient_positions, &dictionary, + num_patients_per_shard, num_shards_per_thread, &iterate]( + const std::filesystem::path& path, size_t index, + CappedQueueSender>>& sender) { + read_files(path, patient_positions[index], dictionary, + num_patients_per_shard, num_shards_per_thread, sender, + iterate); + }; + + auto writer = + [&string_path, num_patients_per_shard, num_shards_per_thread]( + int i, CappedQueueReceiver>>& + receiver) { + write_files(i, string_path, 
+            write_files(i, string_path, num_patients_per_shard,
+                        num_shards_per_thread, receiver);
+        };
+
+    run_reader_writer<std::pair<uint32_t, std::vector<char>>>(
+        work_entries, num_threads, reader, writer);
+
+    uint64_t starting_offset = (num_patients + 1) * sizeof(uint64_t);
+    SharedFile data_file(string_path / "data", num_threads);
+    data_file.file.write((const char*)&starting_offset,
+                         sizeof(starting_offset));
+    data_file.file.seekp(starting_offset);
 
-    size_t num_patients = 0;
+    std::vector<std::thread> threads;
+    for (int i = 0; i < num_threads; i++) {
+        threads.emplace_back([i, &string_path, &data_file,
+                              num_patients_per_shard, num_shards_per_thread]() {
+            sort_concatenate_shards(i, string_path, data_file,
+                                    num_patients_per_shard,
+                                    num_shards_per_thread);
+        });
+    }
 
-    for (const auto& length : all_lengths) {
-        num_patients += length.second.size();
+    for (auto& thread : threads) {
+        thread.join();
     }
+}
 
-    uint64_t current_offset = (num_patients + 1) * sizeof(uint64_t);
+void process_string_property(
+    const std::string& property_name,
+    const std::shared_ptr<arrow::DataType>& type,
+    const std::vector<std::vector<std::pair<uint32_t, uint32_t>>>&
+        patient_positions,
+    std::filesystem::path temp_path, const std::vector<WorkEntry>& work_entries,
+    int num_threads, int num_patients) {
+    std::filesystem::path string_path = temp_path / property_name;
+    std::filesystem::create_directories(string_path);
 
-    for (auto& length : all_lengths) {
-        size_t temp = length.first;
+    // We have to start by reading everything in
+    auto reader =
+        [&property_name, &type, &patient_positions](
+            const std::filesystem::path& fname, size_t index,
+            CappedQueueSender<std::pair<std::string, uint64_t>>& sender) {
+            string_reader_thread(fname, property_name, type,
+                                 patient_positions[index], sender);
+        };
 
-        length.first = current_offset;
+    auto writer =
+        [&string_path](
+            size_t index,
+            CappedQueueReceiver<std::pair<std::string, uint64_t>>& receiver) {
+            std::filesystem::path writer_path =
+                string_path / std::to_string(index);
+            std::filesystem::create_directories(writer_path);
+            string_writer_thread(writer_path, receiver);
+        };
 
-        current_offset += temp;
-    }
+    run_reader_writer<std::pair<std::string, uint64_t>>(
+        work_entries, num_threads, reader, writer);
 
-    run_all(work_entries, num_threads,
-            [&all_lengths](std::filesystem::path path, size_t index) {
-                auto& item = all_lengths[index];
-                for (auto& val : item.second) {
-                    val += item.first;
-                }
+    std::vector<std::vector<std::pair<uint64_t, std::string>>> all_entries(
+        num_threads);
 
-    std::ofstream data_file((string_path / "data"), std::ios_base::out |
-                                                        std::ios_base::binary |
-                                                        std::ios_base::trunc);
-    data_file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
+    {
+        std::vector<std::thread> threads;
 
-    for (const auto& entry : work_entries) {
-        auto& item = all_lengths[entry.second];
-        ssize_t num_to_write = item.second.size() * sizeof(uint64_t);
-        const char* buffer = (const char*)item.second.data();
-        data_file.write(buffer, num_to_write);
+        for (int i = 0; i < num_threads; i++) {
+            threads.emplace_back([i, &string_path, &all_entries]() {
+                std::filesystem::path writer_path =
+                    string_path / std::to_string(i);
+                all_entries[i] = merger_thread(writer_path);
+                std::filesystem::remove_all(writer_path);
             });
+        }
 
-    data_file.write((const char*)&current_offset, sizeof(current_offset));
+        for (auto& thread : threads) {
+            thread.join();
+        }
+    }
 
-    for (const auto& entry : work_entries) {
-        std::filesystem::path entry_path =
-            string_path / (std::to_string(entry.second) + ".data");
+    absl::flat_hash_map<std::string, std::pair<size_t, size_t>> found;
+
+    std::vector<std::pair<uint64_t, std::string>> entries;
+    size_t e_index = 0;
+    for (auto& e : all_entries) {
+        size_t ei_index = 0;
+        for (auto& ei : e) {
+#ifndef NDEBUG
+            if (found.count(ei.second) != 0) {
" << ei.second << " " << ei.first + << " " << e_index << " " << ei_index << " " + << found[ei.second].first << " " + << found[ei.second].second << std::endl; + abort(); + } + found[ei.second].first = e_index; + found[ei.second].second = ei_index; +#endif + + entries.emplace_back(std::move(ei)); + ei_index++; + } + e_index++; } - data_file.write((const char*)¤t_offset, sizeof(current_offset)); + auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); }; - for (const auto& entry : work_entries) { - std::filesystem::path entry_path = - string_path / (std::to_string(entry.second) + ".data"); + pdqsort(std::begin(entries), std::end(entries), + std::greater>()); - if (all_lengths[entry.second].second.size() > 0) { - std::ifstream entry_file(entry_path, - std::ios_base::in | std::ios_base::binary); + absl::flat_hash_map dictionary_entries; - data_file << entry_file.rdbuf(); - } + std::unique_ptr context{ + ZSTD_createCCtx(), context_deleter}; - std::filesystem::remove(entry_path); + { + ZstdRowWriter writer((string_path / "dictionary").string(), + context.get()); + + for (size_t i = 0; i < entries.size(); i++) { + const auto& e = entries[i]; + auto iter = dictionary_entries.try_emplace(std::move(e.second), i); + if (!iter.second) { + throw std::runtime_error("Already inserted? " + e.second); + } + writer.add_next(e.second, e.first); + } } + + auto iterate = + [&property_name, &type, &dictionary_entries]( + const std::filesystem::path& path, + const std::vector>& patient_positions, + auto func) { + return iterate_strings(path, property_name, type, patient_positions, + dictionary_entries, func); + }; + + process_generic_property(string_path, property_name, patient_positions, + work_entries, num_threads, num_patients, iterate); } template -void iterate_primitive(std::filesystem::path filename, - std::string property_name, - const std::vector& patient_lengths, F func) { +void iterate_primitive( + std::filesystem::path filename, std::string property_name, + const std::vector>& patient_positions, + F func) { arrow::MemoryPool* pool = arrow::default_memory_pool(); // Configure general Parquet reader settings @@ -1220,7 +1403,8 @@ void iterate_primitive(std::filesystem::path filename, std::vector columns = {(int)column}; size_t next_patient_index = 0; - size_t remaining_events; + size_t remaining_events = 0; + size_t current_position = 0; bool has_event = false; std::vector null_bytes; @@ -1244,7 +1428,7 @@ void iterate_primitive(std::filesystem::path filename, null_bytes.insert(std::end(null_bytes), std::begin(value_bytes), std::end(value_bytes)); - func(std::move(null_bytes)); + func(current_position, std::move(null_bytes)); }; auto write_null = [&]() { @@ -1273,7 +1457,9 @@ void iterate_primitive(std::filesystem::path filename, has_event = true; } - remaining_events = patient_lengths[next_patient_index++]; + auto next = patient_positions[next_patient_index++]; + current_position = next.first; + remaining_events = next.second; null_bytes.clear(); value_bytes.clear(); @@ -1288,256 +1474,78 @@ void iterate_primitive(std::filesystem::path filename, nullmap_offset = (uint64_t*)null_bytes.data(); } - if (value.empty()) { - write_null(); - } else { - write_value(value); - } - - remaining_events--; - }; - - for (int64_t row_group = 0; row_group < arrow_reader->num_row_groups(); - row_group++) { - std::shared_ptr table; - PARQUET_THROW_NOT_OK( - arrow_reader->ReadRowGroup(row_group, columns, &table)); - - auto chunked_values = table->GetColumnByName(property_name); - - for (const auto& array : 
-        for (const auto& array : chunked_values->chunks()) {
-            auto primitive_array =
-                std::dynamic_pointer_cast<arrow::PrimitiveArray>(array);
-            if (primitive_array == nullptr) {
-                throw std::runtime_error("Could not cast property");
-            }
-
-            int32_t type_bytes = primitive_array->type()->byte_width();
-
-            for (int64_t i = 0; i < primitive_array->length(); i++) {
-                if (primitive_array->IsNull(i)) {
-                    add_value(std::string_view());
-                } else {
-                    std::string_view item(
-                        (const char*)primitive_array->values()->data() +
-                            (primitive_array->offset() + i) * type_bytes,
-                        (size_t)type_bytes);
-                    add_value(item);
-                }
-            }
-        }
-    }
-
-    if (has_event) {
-        flush_patient();
-    }
-}
-
-std::vector<std::vector<char>> get_primitive_samples(
-    std::filesystem::path filename, std::string property_name,
-    const std::vector<uint32_t>& patient_lengths, size_t num_samples) {
-    std::vector<std::vector<char>> samples;
-    size_t sample_count = 0;
-
-    iterate_primitive(filename, property_name, patient_lengths,
-                      [&](std::vector<char> bytes) {
-                          sample_count++;
-
-                          if (samples.size() < num_samples) {
-                              samples.emplace_back(std::move(bytes));
-                          } else {
-                              size_t j = (size_t)(rand() % sample_count);
-                              if (j < num_samples) {
-                                  samples[j] = std::move(bytes);
-                              }
-                          }
-                      });
-
-    return samples;
-}
-
-std::pair<uint64_t, std::vector<uint64_t>> write_primitive_files(
-    std::filesystem::path filename, std::string property_name,
-    const std::vector<uint32_t>& patient_lengths,
-    const std::vector<char>& dictionary,
-    std::filesystem::path target_filename) {
-    std::vector<uint64_t> offsets;
-
-    std::ofstream output_file(target_filename, std::ios_base::out |
-                                                   std::ios_base::binary |
-                                                   std::ios_base::trunc);
-    output_file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
-
-    uint64_t num_bytes = 0;
-
-    auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); };
-
-    std::unique_ptr<ZSTD_CCtx, decltype(context_deleter)> context2{
-        ZSTD_createCCtx(), context_deleter};
-
-    size_t res =
-        ZSTD_CCtx_setParameter(context2.get(), ZSTD_c_compressionLevel, 22);
-
-    if (ZSTD_isError(res)) {
-        throw std::runtime_error("Could not set the compression level");
-    }
-
-    res = ZSTD_CCtx_loadDictionary(context2.get(), dictionary.data(),
-                                   dictionary.size());
-    if (ZSTD_isError(res)) {
-        throw std::runtime_error("Could not load the dictionary");
-    }
-
-    iterate_primitive(
-        filename, property_name, patient_lengths, [&](std::vector<char> bytes) {
-            std::vector<char> final_bytes(sizeof(uint32_t) +
-                                          ZSTD_compressBound(bytes.size()));
-            res = ZSTD_compress2(
-                context2.get(), final_bytes.data() + sizeof(uint32_t),
-                final_bytes.size(), bytes.data(), bytes.size());
-
-            if (ZSTD_isError(res)) {
-                throw std::runtime_error("Unable to compress");
-            }
-
-            final_bytes.resize(res + sizeof(uint32_t));
-
-            uint32_t* length_pointer = (uint32_t*)final_bytes.data();
-            *length_pointer = bytes.size();
-
-            offsets.push_back(num_bytes);
-
-            output_file.write(final_bytes.data(), final_bytes.size());
-
-            num_bytes += final_bytes.size();
-        });
-
-    return std::make_pair(num_bytes, std::move(offsets));
-}
-
-void process_primitive_property(
-    const std::string& property_name,
-    const std::vector<std::vector<uint32_t>>& patient_lengths,
-    std::filesystem::path temp_path, const std::vector<WorkEntry>& work_entries,
-    int num_threads) {
-    std::filesystem::path string_path = temp_path / property_name;
-    std::filesystem::create_directories(string_path);
-
-    // We have to start by reading everything in
-
-    std::vector<std::vector<std::vector<char>>> all_samples(
-        work_entries.size());
-
-    size_t num_samples_per_entry =
-        (10 * 1000 + work_entries.size() - 1) / work_entries.size();
-
-    run_all(work_entries, num_threads,
-            [&all_samples, &property_name, &patient_lengths,
-             num_samples_per_entry](std::filesystem::path path, size_t index) {
-                all_samples[index] = get_primitive_samples(
-                    path, property_name, patient_lengths[index],
-                    num_samples_per_entry);
-            });
-
-    std::vector<size_t> sample_sizes;
-    std::vector<char> sample_buffer;
-
-    for (const auto& samples : all_samples) {
-        for (const auto& sample : samples) {
-            sample_sizes.push_back(sample.size());
-            sample_buffer.insert(std::end(sample_buffer), std::begin(sample),
-                                 std::end(sample));
-        }
-    }
-
-    size_t dictionary_size = 100 * 1000;  // 100 kilobytes
-    std::vector<char> dictionary(dictionary_size);
-
-    size_t dict_size = ZDICT_trainFromBuffer(
-        dictionary.data(), dictionary.size(), sample_buffer.data(),
-        sample_sizes.data(), sample_sizes.size());
-
-    if (ZDICT_isError(dict_size)) {
-        dict_size = 0;
-    }
-
-    dictionary.resize(dict_size);
-
-    {
-        std::ofstream zdict(
-            string_path / std::string("zdict"),
-            std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
-        zdict.write(dictionary.data(), dictionary.size());
-    }
-
-    std::vector<std::pair<uint64_t, std::vector<uint64_t>>> all_lengths(
-        work_entries.size());
-
-    run_all(work_entries, num_threads,
-            [&all_lengths, &string_path, &property_name, &patient_lengths,
-             &dictionary](std::filesystem::path path, size_t index) {
-                std::filesystem::path target_path =
-                    string_path / (std::to_string(index) + ".data");
-                all_lengths[index] = write_primitive_files(
-                    path, property_name, patient_lengths[index], dictionary,
-                    target_path);
-            });
-
-    uint64_t num_patients = 0;
-
-    for (const auto& length : all_lengths) {
-        num_patients += length.second.size();
-    }
-
-    uint64_t current_offset = (num_patients + 1) * sizeof(uint64_t);
-
-    for (auto& length : all_lengths) {
-        size_t temp = length.first;
-
-        length.first = current_offset;
-
-        current_offset += temp;
-    }
-
-    run_all(work_entries, num_threads,
-            [&all_lengths](std::filesystem::path path, size_t index) {
-                auto& item = all_lengths[index];
-                for (auto& val : item.second) {
-                    val += item.first;
-                }
-            });
+        if (value.empty()) {
+            write_null();
+        } else {
+            write_value(value);
+        }
 
-    std::ofstream data_file((string_path / "data"), std::ios_base::out |
-                                                        std::ios_base::binary |
-                                                        std::ios_base::trunc);
+        remaining_events--;
+    };
 
-    for (const auto& entry : work_entries) {
-        auto& item = all_lengths[entry.second];
-        ssize_t num_to_write = item.second.size() * sizeof(uint64_t);
-        const char* buffer = (const char*)item.second.data();
-        data_file.write(buffer, num_to_write);
-    }
+    for (int64_t row_group = 0; row_group < arrow_reader->num_row_groups();
+         row_group++) {
+        std::shared_ptr<arrow::Table> table;
+        PARQUET_THROW_NOT_OK(
+            arrow_reader->ReadRowGroup(row_group, columns, &table));
 
-    data_file.write((const char*)&current_offset, sizeof(current_offset));
+        auto chunked_values = table->GetColumnByName(property_name);
 
-    for (const auto& entry : work_entries) {
-        std::filesystem::path entry_path =
-            string_path / (std::to_string(entry.second) + ".data");
+        for (const auto& array : chunked_values->chunks()) {
+            auto primitive_array =
+                std::dynamic_pointer_cast<arrow::PrimitiveArray>(array);
+            if (primitive_array == nullptr) {
+                throw std::runtime_error("Could not cast property");
+            }
 
-        if (all_lengths[entry.second].second.size() > 0) {
-            std::ifstream entry_file(entry_path,
-                                     std::ios_base::in | std::ios_base::binary);
+            int32_t type_bytes = primitive_array->type()->byte_width();
 
-            data_file << entry_file.rdbuf();
+            for (int64_t i = 0; i < primitive_array->length(); i++) {
+                if (primitive_array->IsNull(i)) {
+                    add_value(std::string_view());
+                } else {
+                    std::string_view item(
+                        (const char*)primitive_array->values()->data() +
+                            (primitive_array->offset() + i) * type_bytes,
+                        (size_t)type_bytes);
+                    add_value(item);
+                }
+            }
         }
+    }
 
-        std::filesystem::remove(entry_path);
+    if (has_event) {
+        flush_patient();
     }
 }
 
+void process_primitive_property(
+    const std::string& property_name,
+    const std::vector<std::vector<std::pair<uint32_t, uint32_t>>>&
+        patient_positions,
+    std::filesystem::path temp_path, const std::vector<WorkEntry>& work_entries,
+    int num_threads, int num_patients) {
+    std::filesystem::path string_path = temp_path / property_name;
+    std::filesystem::create_directories(string_path);
+
+    auto iterate =
+        [&property_name](
+            const std::filesystem::path& path,
+            const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
+            auto func) {
+            return iterate_primitive(path, property_name, patient_positions,
+                                     func);
+        };
+
+    process_generic_property(string_path, property_name, patient_positions,
+                             work_entries, num_threads, num_patients, iterate);
+}
+
 template <typename F>
-void iterate_time(std::filesystem::path filename, std::string property_name,
-                  const std::vector<uint32_t>& patient_lengths, F func) {
+void iterate_time(
+    std::filesystem::path filename, std::string property_name,
+    const std::vector<std::pair<uint32_t, uint32_t>>& patient_positions,
+    F func) {
     arrow::MemoryPool* pool = arrow::default_memory_pool();
 
     // Configure general Parquet reader settings
@@ -1564,7 +1572,8 @@ void iterate_time(std::filesystem::path filename, std::string property_name,
     std::vector<int> columns = {(int)column};
 
     size_t next_patient_index = 0;
-    size_t remaining_events;
+    size_t remaining_events = 0;
+    uint32_t current_position = 0;
     bool has_event = false;
 
     constexpr int64_t micros_per_second = ((int64_t)1000) * 1000;
@@ -1634,7 +1643,7 @@ void iterate_time(std::filesystem::path filename, std::string property_name,
 
         helper.resize(sizeof(int64_t) + sizeof(uint32_t) + count);
 
-        func(std::move(helper));
+        func(current_position, std::move(helper));
     };
 
     auto add_time = [&](std::optional<int64_t> time) {
@@ -1645,7 +1654,9 @@
                 has_event = true;
             }
 
-            remaining_events = patient_lengths[next_patient_index++];
+            auto next = patient_positions[next_patient_index++];
+            current_position = next.first;
+            remaining_events = next.second;
 
             values.clear();
             values.push_back(0);
@@ -1717,209 +1728,28 @@ void iterate_time(std::filesystem::path filename, std::string property_name,
     }
 }
 
-std::vector<std::vector<char>> get_time_samples(
-    std::filesystem::path filename, std::string property_name,
-    const std::vector<uint32_t>& patient_lengths, size_t num_samples) {
-    std::vector<std::vector<char>> samples;
-    size_t sample_count = 0;
-
-    iterate_time(filename, property_name, patient_lengths,
-                 [&](std::vector<char> bytes) {
-                     sample_count++;
-
-                     if (samples.size() < num_samples) {
-                         samples.emplace_back(std::move(bytes));
-                     } else {
-                         size_t j = (size_t)(rand() % sample_count);
-                         if (j < num_samples) {
-                             samples[j] = std::move(bytes);
-                         }
-                     }
-                 });
-
-    return samples;
-}
-
-std::pair<uint64_t, std::vector<uint64_t>> write_time_files(
-    std::filesystem::path filename, std::string property_name,
-    const std::vector<uint32_t>& patient_lengths,
-    const std::vector<char>& dictionary,
-    std::filesystem::path target_filename) {
-    std::vector<uint64_t> offsets;
-
-    std::ofstream output_file(target_filename, std::ios_base::out |
-                                                   std::ios_base::binary |
-                                                   std::ios_base::trunc);
-    output_file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
-
-    uint64_t num_bytes = 0;
-
-    auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); };
-
-    std::unique_ptr<ZSTD_CCtx, decltype(context_deleter)> context2{
-        ZSTD_createCCtx(), context_deleter};
-
-    size_t res =
-        ZSTD_CCtx_setParameter(context2.get(), ZSTD_c_compressionLevel, 22);
-
-    if (ZSTD_isError(res)) {
std::runtime_error("Could not set the compression level"); - } - - res = ZSTD_CCtx_loadDictionary(context2.get(), dictionary.data(), - dictionary.size()); - if (ZSTD_isError(res)) { - throw std::runtime_error("Could not load the dictionary"); - } - - iterate_time( - filename, property_name, patient_lengths, [&](std::vector bytes) { - std::vector final_bytes(sizeof(uint32_t) + - ZSTD_compressBound(bytes.size())); - res = ZSTD_compress2( - context2.get(), final_bytes.data() + sizeof(uint32_t), - final_bytes.size(), bytes.data(), bytes.size()); - - if (ZSTD_isError(res)) { - throw std::runtime_error("Unable to compress"); - } - - final_bytes.resize(res + sizeof(uint32_t)); - - uint32_t* length_pointer = (uint32_t*)final_bytes.data(); - *length_pointer = bytes.size(); - - offsets.push_back(num_bytes); - - output_file.write(final_bytes.data(), final_bytes.size()); - - num_bytes += final_bytes.size(); - }); - - return std::make_pair(num_bytes, std::move(offsets)); -} - void process_time_property( const std::string& property_name, - const std::vector>& patient_lengths, + const std::vector>>& + patient_positions, std::filesystem::path temp_path, const std::vector& work_entries, - int num_threads) { + int num_threads, int num_patients) { std::filesystem::path string_path = temp_path / property_name; std::filesystem::create_directories(string_path); - // We have to start by reading everything in - - std::vector>> all_samples( - work_entries.size()); - - size_t num_samples_per_entry = - (10 * 1000 + work_entries.size() - 1) / work_entries.size(); - - run_all(work_entries, num_threads, - [&all_samples, &property_name, &patient_lengths, - num_samples_per_entry](std::filesystem::path path, size_t index) { - all_samples[index] = get_time_samples(path, property_name, - patient_lengths[index], - num_samples_per_entry); - }); - - std::vector sample_sizes; - std::vector sample_buffer; - - for (const auto& samples : all_samples) { - for (const auto& sample : samples) { - sample_sizes.push_back(sample.size()); - sample_buffer.insert(std::end(sample_buffer), std::begin(sample), - std::end(sample)); - } - } - - size_t dictionary_size = 100 * 1000; // 100 kilobytes - std::vector dictionary(dictionary_size); - - size_t dict_size = ZDICT_trainFromBuffer( - dictionary.data(), dictionary.size(), sample_buffer.data(), - sample_sizes.data(), sample_sizes.size()); - - if (ZDICT_isError(dict_size)) { - dict_size = 0; - } - - dictionary.resize(dict_size); - - { - std::ofstream zdict( - string_path / std::string("zdict"), - std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); - zdict.write(dictionary.data(), dictionary.size()); - } - - std::vector>> all_lengths( - work_entries.size()); - - run_all(work_entries, num_threads, - [&all_lengths, &string_path, &property_name, &patient_lengths, - &dictionary](std::filesystem::path path, size_t index) { - std::filesystem::path target_path = - string_path / (std::to_string(index) + ".data"); - all_lengths[index] = write_time_files(path, property_name, - patient_lengths[index], - dictionary, target_path); - }); - - uint64_t num_patients = 0; - - for (const auto& length : all_lengths) { - num_patients += length.second.size(); - } - - uint64_t current_offset = (num_patients + 1) * sizeof(uint64_t); - - for (auto& length : all_lengths) { - size_t temp = length.first; - - length.first = current_offset; - - current_offset += temp; - } - - run_all(work_entries, num_threads, - [&all_lengths](std::filesystem::path path, size_t index) { - auto& item = all_lengths[index]; - for 
(auto& val : item.second) { - val += item.first; - } - }); - - std::ofstream data_file((string_path / "data"), std::ios_base::out | - std::ios_base::binary | - std::ios_base::trunc); - - for (const auto& entry : work_entries) { - auto& item = all_lengths[entry.second]; - ssize_t num_to_write = item.second.size() * sizeof(uint64_t); - const char* buffer = (const char*)item.second.data(); - data_file.write(buffer, num_to_write); - } - - data_file.write((const char*)¤t_offset, sizeof(current_offset)); - - for (const auto& entry : work_entries) { - std::filesystem::path entry_path = - string_path / (std::to_string(entry.second) + ".data"); - - if (all_lengths[entry.second].second.size() > 0) { - std::ifstream entry_file(entry_path, - std::ios_base::in | std::ios_base::binary); - - data_file << entry_file.rdbuf(); - } + auto iterate = + [&property_name]( + const std::filesystem::path& path, + const std::vector>& patient_positions, + auto func) { + return iterate_time(path, property_name, patient_positions, func); + }; - std::filesystem::remove(entry_path); - } + process_generic_property(string_path, property_name, patient_positions, + work_entries, num_threads, num_patients, iterate); } -std::pair, std::vector> get_patient_ids( +std::vector> get_patient_ids( std::filesystem::path filename) { arrow::MemoryPool* pool = arrow::default_memory_pool(); @@ -1946,18 +1776,14 @@ std::pair, std::vector> get_patient_ids( int64_t patient_id_column = properties.find("patient_id")->second.second; std::vector columns = {(int)patient_id_column}; - std::vector patient_ids; - - std::vector unique_patient_ids; - std::vector lengths; + std::vector> result; bool has_event = false; int64_t current_patient_id = 0; size_t current_patient_length = 0; auto flush_patient = [&]() { - unique_patient_ids.push_back(current_patient_id); - lengths.push_back(current_patient_length); + result.emplace_back(current_patient_id, current_patient_length); }; auto add_patient_id = [&](int64_t patient_id) { @@ -2006,58 +1832,21 @@ std::pair, std::vector> get_patient_ids( flush_patient(); } - return {std::move(unique_patient_ids), std::move(lengths)}; + return result; } -std::pair>, uint32_t> process_patient_id( +std::vector>> process_patient_id( std::filesystem::path temp_path, const std::vector& work_entries, int num_threads) { - std::vector, std::vector>> - all_lengths(work_entries.size()); - - std::vector> result_lengths(work_entries.size()); - - run_all( - work_entries, num_threads, - [&temp_path, &all_lengths](std::filesystem::path path, size_t index) { - all_lengths[index] = get_patient_ids(path); - }); - - size_t num_patients = 0; - - for (const auto& length : all_lengths) { - num_patients += length.first.size(); - } - - { - std::ofstream pids_file( - temp_path / "patient_id", - std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); - std::ofstream lengths_file( - temp_path / "meds_reader.length", - std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); - - absl::flat_hash_set seen_patient_ids; - - for (const auto& entry : work_entries) { - auto& item = all_lengths[entry.second]; - pids_file.write((const char*)item.first.data(), - sizeof(int64_t) * item.first.size()); - for (int64_t pid : item.first) { - if (seen_patient_ids.count(pid) != 0) { - throw std::runtime_error("Had duplicate patient ids " + - std::to_string(pid)); - } - seen_patient_ids.insert(pid); - } - lengths_file.write((const char*)item.second.data(), - sizeof(uint32_t) * item.second.size()); + std::vector>> result( + work_entries.size()); 
-            result_lengths[entry.second] = std::move(item.second);
-        }
-    }
+    run_all(work_entries, num_threads,
+            [&temp_path, &result](std::filesystem::path path, size_t index) {
+                result[index] = get_patient_ids(path);
+            });
 
-    return {result_lengths, num_patients};
+    return result;
 }
 
 DataType convert_to_datatype(const std::shared_ptr<arrow::DataType>& type) {
@@ -2103,23 +1892,26 @@ DataType convert_to_datatype(const std::shared_ptr<arrow::DataType>& type) {
     abort();
 }
 
-void process_property(const std::string& property_name,
-                      const std::shared_ptr<arrow::DataType>& type,
-                      const std::vector<std::vector<uint32_t>>& patient_lengths,
-                      const std::filesystem::path& temp_path,
-                      const std::vector<WorkEntry>& work_entries,
-                      int num_threads) {
+void process_property(
+    const std::string& property_name,
+    const std::shared_ptr<arrow::DataType>& type,
+    const std::vector<std::vector<std::pair<uint32_t, uint32_t>>>&
+        patient_positions,
+    const std::filesystem::path& temp_path,
+    const std::vector<WorkEntry>& work_entries, int num_threads,
+    int num_patients) {
     if (property_name == "time") {
-        process_time_property(property_name, patient_lengths, temp_path,
-                              work_entries, num_threads);
+        process_time_property(property_name, patient_positions, temp_path,
+                              work_entries, num_threads, num_patients);
         return;
     }
 
     switch (type->id()) {
         case arrow::Type::STRING:
         case arrow::Type::LARGE_STRING:
-            process_string_property(property_name, type, patient_lengths,
-                                    temp_path, work_entries, num_threads);
+            process_string_property(property_name, type, patient_positions,
+                                    temp_path, work_entries, num_threads,
+                                    num_patients);
             return;
 
         case arrow::Type::TIMESTAMP:
@@ -2133,8 +1925,9 @@ void process_property(
         case arrow::Type::UINT16:
         case arrow::Type::UINT32:
         case arrow::Type::UINT64:
-            process_primitive_property(property_name, patient_lengths,
-                                       temp_path, work_entries, num_threads);
+            process_primitive_property(property_name, patient_positions,
+                                       temp_path, work_entries, num_threads,
+                                       num_patients);
             return;
 
         default:
@@ -2261,7 +2054,6 @@ struct PropertyNullReader {
         }
     }
 
-    MmapFile& zdict_file;
     MmapFile& data_file;
 
     std::unique_ptr context;
@@ -2401,10 +2193,10 @@ void process_null_map(
     int num_patients) {
     MmapFile length_file(destination_path / "meds_reader.length");
     absl::Span patient_lengths = length_file.data();
-    
+
     std::vector<MmapFile> zdict_files;
     std::vector<MmapFile> data_files;
-    
+
     for (const auto& entry : properties) {
         zdict_files.emplace_back(destination_path / entry.first / "zdict");
         data_files.emplace_back(destination_path / entry.first / "data");
@@ -2413,9 +2205,10 @@ void process_null_map(
     std::vector<std::vector<PropertyNullReader>> property_readers(num_threads);
     for (int i = 0; i < num_threads; i++) {
         for (size_t j = 0; j < properties.size(); j++) {
-const auto& entry = properties[j];
+            const auto& entry = properties[j];
             property_readers[i].emplace_back(destination_path / entry.first,
-                entry.first, zdict_files[j], data_files[j]);
+                                             entry.first, zdict_files[j],
+                                             data_files[j]);
         }
     }
 
@@ -2573,15 +2366,58 @@ void create_database(const char* source, const char* destination,
         }
     }
 
-    auto patient_lengths_and_num_patients =
+    auto patient_ids_and_lengths =
         process_patient_id(destination_path, work_entries, num_threads);
 
-    const auto& patient_lengths = patient_lengths_and_num_patients.first;
-    const auto& num_patients = patient_lengths_and_num_patients.second;
+    std::vector<std::tuple<int64_t, uint32_t, size_t, size_t>>
+        patient_ids_and_location;
+    std::vector<std::vector<std::pair<uint32_t, uint32_t>>> patient_positions;
+
+    for (size_t i = 0; i < patient_ids_and_lengths.size(); i++) {
+        patient_positions.emplace_back(patient_ids_and_lengths[i].size());
+        for (size_t j = 0; j < patient_ids_and_lengths[i].size(); j++) {
+            auto e = patient_ids_and_lengths[i][j];
+            patient_ids_and_location.emplace_back(e.first, e.second, i, j);
+        }
+    }
+
+    pdqsort(std::begin(patient_ids_and_location),
+            std::end(patient_ids_and_location));
+
+    std::vector<int64_t> flat_patient_ids;
+    std::vector<uint32_t> flat_patient_lengths;
+
+    for (size_t i = 0; i < patient_ids_and_location.size(); i++) {
+        auto entry = patient_ids_and_location[i];
+        flat_patient_ids.emplace_back(std::get<0>(entry));
+        flat_patient_lengths.emplace_back(std::get<1>(entry));
+        patient_positions[std::get<2>(entry)][std::get<3>(entry)] = {
+            i, std::get<1>(entry)};
+    }
+
+    {
+        std::ofstream patient_ids_file(
+            destination_path / "patient_id",
+            std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
+
+        patient_ids_file.write((const char*)flat_patient_ids.data(),
+                               sizeof(int64_t) * flat_patient_ids.size());
+
+        std::ofstream patient_lengths_file(
+            destination_path / "meds_reader.length",
+            std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
+
+        patient_lengths_file.write(
+            (const char*)flat_patient_lengths.data(),
+            sizeof(uint32_t) * flat_patient_lengths.size());
+    }
+
+    size_t num_patients = patient_ids_and_location.size();
 
     for (const auto& property : properties) {
-        process_property(property.first, property.second, patient_lengths,
-                         destination_path, work_entries, num_threads);
+        process_property(property.first, property.second, patient_positions,
+                         destination_path, work_entries, num_threads,
+                         num_patients);
     }
 
     if (properties.size() > 64) {

diff --git a/src/meds_reader/__init__.py b/src/meds_reader/__init__.py
index 4f88bc9..ce8e14b 100644
--- a/src/meds_reader/__init__.py
+++ b/src/meds_reader/__init__.py
@@ -145,24 +145,6 @@ def _runner(
     result_queue.put(result)
 
 
-def _filter_patients(
-    all_patient_ids: np.ndarray, filter_list: Sequence[int]
-) -> np.ndarray:
-    found_patients = all_patient_ids[np.isin(all_patient_ids, filter_list)]
-    if len(found_patients) != len(filter_list):
-        if len(set(filter_list)) != len(filter_list):
-            raise ValueError(
-                f"Called filter with a set of patient ids with duplicates {len(set(filter_list))} {len(filter_list)}"
-            )
-
-        missing_patients = [a for a in filter_list if a not in all_patient_ids]
-        raise ValueError(
-            f"Called filter, but couldn't find patients {repr(missing_patients)} {len(filter_list)} {len(found_patients)}"
-        )
-
-    return found_patients
-
-
 class _PatientDatabaseWrapper:
     def __init__(self, db: PatientDatabase, patients_ids: np.ndarray):
         self._db = db
@@ -186,10 +168,7 @@ def __iter__(self) -> Iterator[int]:
 
     def filter(self, patient_ids: Sequence[int]):
         return cast(
-            PatientDatabase,
-            _PatientDatabaseWrapper(
-                self._db, _filter_patients(self._selected_patients, patient_ids)
-            ),
+            PatientDatabase, _PatientDatabaseWrapper(self._db, np.sort(patient_ids))
         )
 
     def map(self, map_func: Callable[[Iterator[Patient]], A]) -> Iterator[A]:
@@ -245,9 +224,7 @@ def filter(self, patient_ids: Sequence[int]) -> PatientDatabase:
         """Filter to a provided set of patient ids"""
         return cast(
             PatientDatabase,
-            _PatientDatabaseWrapper(
-                self, _filter_patients(self._all_patient_ids, patient_ids)
-            ),
+            _PatientDatabaseWrapper(self, np.sort(patient_ids)),
         )
 
     def map(
diff --git a/tests/api_test.py b/tests/api_test.py
index 5ae22c7..1724d4a 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -171,9 +171,6 @@ def test_lookup(patient_database):
 
 
 def test_filter(patient_database):
-    with pytest.raises(ValueError):
-        sub_database = patient_database.filter([32, 45345])
-
     sub_database = patient_database.filter([32])
 
     assert len(sub_database) == 1