Skip to content

Commit

Permalink
Attempt fix large number of threads
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Jul 23, 2024
1 parent 0a35363 commit f156720
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
3 changes: 3 additions & 0 deletions dali/core/cuda_event_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ CUDAEventPool::CUDAEventPool(unsigned event_flags) {
int num_devices = 0;
CUDA_CALL(cudaGetDeviceCount(&num_devices));
dev_events_.resize(num_devices);
for (int i = 0; i < 20000; i++) {
Put(CUDAEvent::CreateWithFlags(cudaEventDisableTiming));
}
}

CUDAEvent CUDAEventPool::Get(int device_id) {
Expand Down
26 changes: 13 additions & 13 deletions dali/operators/imgcodec/image_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,17 +572,17 @@ class ImageDecoder : public StatelessOperator<Backend> {
};
};

int nblocks = tp_->NumThreads() + 1;
if (nsamples > nblocks * 4) {
if (nsamples < 16) {
get_task(0, 1)(-1); // run all in current thread
} else {
int nblocks = std::max(tp_->NumThreads() + 1, 8);
int block_idx = 0;
for (; block_idx < tp_->NumThreads(); block_idx++) {
for (; block_idx < nblocks - 1; block_idx++) {
tp_->AddWork(get_task(block_idx, nblocks), -block_idx);
}
tp_->RunAll(false); // start work but not wait
get_task(block_idx, nblocks)(-1); // run last block
tp_->WaitForWork(); // wait for the other threads
} else { // not worth parallelizing
get_task(0, 1)(-1); // run all in current thread
tp_->RunAll(false); // start work but not wait
get_task(block_idx, nblocks)(-1); // run last block
tp_->WaitForWork(); // wait for the other threads
}

output_descs[0] = {std::move(shapes), dtype_};
Expand Down Expand Up @@ -773,17 +773,17 @@ class ImageDecoder : public StatelessOperator<Backend> {
};
};

int nblocks = tp_->NumThreads() + 1;
if (decode_nsamples > nblocks * 4) {
if (decode_nsamples < 16) {
get_task(0, 1)(-1); // run all in current thread
} else {
int nblocks = std::max(tp_->NumThreads() + 1, 8);
int block_idx = 0;
for (; block_idx < tp_->NumThreads(); block_idx++) {
for (; block_idx < nblocks - 1; block_idx++) {
tp_->AddWork(get_task(block_idx, nblocks), -block_idx);
}
tp_->RunAll(false); // start work but not wait
get_task(block_idx, nblocks)(-1); // run last block
tp_->WaitForWork(); // wait for the other threads
} else { // not worth parallelizing
get_task(0, 1)(-1); // run all in current thread
}

for (int orig_idx : decode_sample_idxs_) {
Expand Down

0 comments on commit f156720

Please sign in to comment.