
Commit

Draft refactor
Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Jul 24, 2024
1 parent 5806b33 commit 72f8945
Showing 1 changed file with 101 additions and 149 deletions.
dali/operators/imgcodec/image_decoder.h (250 changes: 101 additions & 149 deletions)
@@ -468,7 +468,7 @@ class ImageDecoder : public StatelessOperator<Backend> {
}

bool CanInferOutputs() const override {
return true;
return false;
}

void ParseSample(ParsedSample &parsed_sample, span<const uint8_t> encoded) {
@@ -503,90 +503,8 @@ class ImageDecoder : public StatelessOperator<Backend> {
return std::is_same<MixedBackend, Backend>::value ? thread_pool_.get() : &ws.GetThreadPool();
}


bool SetupImpl(std::vector<OutputDesc> &output_descs, const Workspace &ws) override {
DomainTimeRange tr("Setup", DomainTimeRange::kOrange);
tp_ = GetThreadPool(ws);
assert(tp_ != nullptr);
auto auto_cleanup = AtScopeExit([&] {
tp_ = nullptr;
});

output_descs.resize(1);
auto &input = ws.template Input<CPUBackend>(0);
int nsamples = input.num_samples();

SetupRoiGenerator(spec_, ws);
TensorListShape<> shapes;
shapes.resize(nsamples, 3);
while (static_cast<int>(state_.size()) < nsamples)
state_.push_back(std::make_unique<SampleState>());
rois_.resize(nsamples);

const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8;
auto get_task = [&](int block_idx, int nblocks) {
return [&, block_idx, nblocks](int tid) {
int i_start = nsamples * block_idx / nblocks;
int i_end = nsamples * (block_idx + 1) / nblocks;
for (int i = i_start; i < i_end; i++) {
auto *st = state_[i].get();
assert(st != nullptr);
const auto &input_sample = input[i];

auto src_info = input.GetMeta(i).GetSourceInfo();
if (use_cache && cache_->IsInCache(src_info)) {
auto cached_shape = cache_->CacheImageShape(src_info);
auto roi = GetRoi(spec_, ws, i, cached_shape);
if (!roi.use_roi()) {
shapes.set_tensor_shape(i, cached_shape);
continue;
}
}
ParseSample(st->parsed_sample,
span<const uint8_t>{static_cast<const uint8_t *>(input_sample.raw_data()),
volume(input_sample.shape())});
st->out_shape = st->parsed_sample.dali_img_info.shape;
st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]);
if (use_orientation_ &&
(st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) {
std::swap(st->out_shape[0], st->out_shape[1]);
}
ROI &roi = rois_[i] = GetRoi(spec_, ws, i, st->out_shape);
if (roi.use_roi()) {
auto roi_sh = roi.shape();
if (roi.end.size() >= 2) {
DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] &&
0 <= roi.end[1] && roi.end[1] <= st->out_shape[1],
"ROI end must fit within the image bounds");
}
if (roi.begin.size() >= 2) {
DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] &&
0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1],
"ROI begin must fit within the image bounds");
}
st->out_shape[0] = roi_sh[0];
st->out_shape[1] = roi_sh[1];
}
shapes.set_tensor_shape(i, st->out_shape);
}
};
};

int nblocks = tp_->NumThreads() + 1;
if (nsamples > nblocks * 4) {
int block_idx = 0;
for (; block_idx < tp_->NumThreads(); block_idx++) {
tp_->AddWork(get_task(block_idx, nblocks), -block_idx);
}
tp_->RunAll(false); // start work but not wait
get_task(block_idx, nblocks)(-1); // run last block
tp_->WaitForWork(); // wait for the other threads
} else { // not worth parallelizing
get_task(0, 1)(-1); // run all in current thread
}

output_descs[0] = {std::move(shapes), dtype_};
return true;
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
return false;
}

/**
@@ -710,6 +628,8 @@ class ImageDecoder : public StatelessOperator<Backend> {
const auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.template Output<typename OutBackend<Backend>::type>(0);
output.SetLayout("HWC");
output.SetContiguity(BatchContiguity::Noncontiguous);
output.set_type(dtype_);
int nsamples = input.num_samples();
assert(output.num_samples() == nsamples);

@@ -719,14 +639,12 @@ class ImageDecoder : public StatelessOperator<Backend> {
tp_ = nullptr;
});

bool has_any_roi = false;
for (auto &roi : rois_)
has_any_roi |= roi.use_roi();

nvimgcodecDecodeParams_t decode_params = {NVIMGCODEC_STRUCTURE_TYPE_DECODE_PARAMS,
sizeof(nvimgcodecDecodeParams_t), nullptr};
decode_params.apply_exif_orientation = static_cast<int>(use_orientation_);
decode_params.enable_roi = static_cast<int>(has_any_roi);
SetupRoiGenerator(spec_, ws);
TensorListShape<> shapes;
shapes.resize(nsamples, 3);
while (static_cast<int>(state_.size()) < nsamples)
state_.push_back(std::make_unique<SampleState>());
rois_.resize(nsamples);

assert(static_cast<int>(state_.size()) >= nsamples);
batch_encoded_streams_.clear();
@@ -735,75 +653,107 @@ class ImageDecoder : public StatelessOperator<Backend> {
batch_images_.reserve(nsamples);
decode_sample_idxs_.clear();
decode_sample_idxs_.reserve(nsamples);
load_from_cache_.resize(nsamples);

// TODO(janton): consider extending cache to different dtype as well
const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8;
if (use_cache) {
int samples_to_load = 0;
DomainTimeRange tr(make_string("CacheLoad"), DomainTimeRange::kOrange);
for (int orig_idx = 0; orig_idx < nsamples; orig_idx++) {
auto src_info = input.GetMeta(orig_idx).GetSourceInfo();
// To simplify things, we do not allow caching ROIs
bool has_roi = rois_[orig_idx].use_roi();
if (cache_->IsInCache(src_info) && !has_roi) {
cache_->DeferCacheLoad(src_info, output.template mutable_tensor<uint8_t>(orig_idx));
samples_to_load++;
} else {
decode_sample_idxs_.push_back(orig_idx);
}
}
if (samples_to_load > 0)
cache_->LoadDeferred(ws.stream());
} else {
decode_sample_idxs_.resize(nsamples);
std::iota(decode_sample_idxs_.begin(), decode_sample_idxs_.end(), 0);
}
auto get_task = [&](int block_idx, int nblocks) {
return [&, block_idx, nblocks](int tid) {
int i_start = nsamples * block_idx / nblocks;
int i_end = nsamples * (block_idx + 1) / nblocks;
for (int i = i_start; i < i_end; i++) {
auto *st = state_[i].get();
assert(st != nullptr);
const auto &input_sample = input[i];

int decode_nsamples = decode_sample_idxs_.size();
{
DomainTimeRange tr(make_string("Prepare descs"), DomainTimeRange::kOrange);
auto get_task = [&](int block_idx, int nblocks) {
return [&, block_idx, nblocks](int tid) {
int i_start = decode_nsamples * block_idx / nblocks;
int i_end = decode_nsamples * (block_idx + 1) / nblocks;
for (int i = i_start; i < i_end; i++) {
int orig_idx = decode_sample_idxs_[i];
PrepareOutput(*state_[orig_idx], output[orig_idx], rois_[orig_idx], ws);
auto src_info = input.GetMeta(i).GetSourceInfo();
if (use_cache && cache_->IsInCache(src_info)) {
auto cached_shape = cache_->CacheImageShape(src_info);
auto roi = GetRoi(spec_, ws, i, cached_shape);
if (!roi.use_roi()) {
output.ResizeSample(i, cached_shape);
cache_->DeferCacheLoad(src_info, output.template mutable_tensor<uint8_t>(i));
load_from_cache_[i] = true;
// shapes.set_tensor_shape(i, cached_shape);
continue;
}
}

load_from_cache_[i] = false;

ParseSample(st->parsed_sample,
span<const uint8_t>{static_cast<const uint8_t *>(input_sample.raw_data()),
volume(input_sample.shape())});
st->out_shape = st->parsed_sample.dali_img_info.shape;
st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]);
if (use_orientation_ &&
(st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) {
std::swap(st->out_shape[0], st->out_shape[1]);
}
};
ROI &roi = rois_[i] = GetRoi(spec_, ws, i, st->out_shape);
if (roi.use_roi()) {
auto roi_sh = roi.shape();
if (roi.end.size() >= 2) {
DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] &&
0 <= roi.end[1] && roi.end[1] <= st->out_shape[1],
"ROI end must fit within the image bounds");
}
if (roi.begin.size() >= 2) {
DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] &&
0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1],
"ROI begin must fit within the image bounds");
}
st->out_shape[0] = roi_sh[0];
st->out_shape[1] = roi_sh[1];
}
output.ResizeSample(i, st->out_shape);
PrepareOutput(*state_[i], output[i], rois_[i], ws);
// shapes.set_tensor_shape(i, st->out_shape);
}
};
};

int nblocks = tp_->NumThreads() + 1;
if (decode_nsamples > nblocks * 4) {
{
DomainTimeRange tr("Setup", DomainTimeRange::kOrange);
if (nsamples <= 16) {
get_task(0, 1)(-1); // run all in current thread
} else {
int nblocks = std::min(tp_->NumThreads() + 1, nsamples / 4);
int block_idx = 0;
for (; block_idx < tp_->NumThreads(); block_idx++) {
for (; block_idx < nblocks; block_idx++) {
tp_->AddWork(get_task(block_idx, nblocks), -block_idx);
}
tp_->RunAll(false); // start work but not wait
get_task(block_idx, nblocks)(-1); // run last block
tp_->WaitForWork(); // wait for the other threads
} else { // not worth parallelizing
get_task(0, 1)(-1); // run all in current thread
tp_->RunAll(false); // start work but not wait
get_task(block_idx, nblocks)(-1); // run last block in current thread
tp_->WaitForWork(); // wait for the other threads
}
}
bool has_any_roi = false;
int samples_to_load = 0;
for (int i = 0; i < nsamples; i++) {
if (!load_from_cache_[i])
decode_sample_idxs_.push_back(i);
else
samples_to_load++;
has_any_roi |= rois_[i].use_roi();

for (int orig_idx : decode_sample_idxs_) {
auto &st = *state_[orig_idx];
batch_encoded_streams_.push_back(st.parsed_sample.encoded_stream);
batch_images_.push_back(st.image);
}
}

// This is a workaround for nvImageCodec <= 0.2
auto any_need_processing = [&]() {
for (int orig_idx : decode_sample_idxs_) {
auto& st = state_[orig_idx];
assert(ws.stream() == st->image_info.cuda_stream); // assuming this is true
if (st->need_processing)
return true;
}
return false;
};
if (ws.has_stream() && need_host_sync_alloc() && any_need_processing()) {
nvimgcodecDecodeParams_t decode_params = {NVIMGCODEC_STRUCTURE_TYPE_DECODE_PARAMS,
sizeof(nvimgcodecDecodeParams_t), nullptr};
decode_params.apply_exif_orientation = static_cast<int>(use_orientation_);
decode_params.enable_roi = static_cast<int>(has_any_roi);

int decode_nsamples = decode_sample_idxs_.size();
bool any_need_processing = false;
for (int orig_idx : decode_sample_idxs_) {
auto &st = *state_[orig_idx];
assert(ws.stream() == st.image_info.cuda_stream); // assuming this is true
any_need_processing |= st.need_processing;
batch_encoded_streams_.push_back(st.parsed_sample.encoded_stream);
batch_images_.push_back(st.image);
}

if (ws.has_stream() && need_host_sync_alloc() && any_need_processing) {
DomainTimeRange tr("alloc sync", DomainTimeRange::kOrange);
CUDA_CALL(cudaStreamSynchronize(ws.stream()));
}
Expand Down Expand Up @@ -911,6 +861,8 @@ class ImageDecoder : public StatelessOperator<Backend> {
// In case of cache, the batch we send to the decoder might have fewer samples than the full batch
// This vector is used to get the original index of the decoded samples
std::vector<size_t> decode_sample_idxs_;
// True if the sample is loaded from cache, false otherwise
std::vector<bool> load_from_cache_;

// Manually loaded extensions
std::vector<nvimgcodecExtensionDesc_t> extensions_descs_;
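For reference, below is a minimal, self-contained sketch of the block-partitioned thread-pool scheduling pattern used in the setup path of this patch: the sample range is split into contiguous blocks, blocks are queued on the thread pool, and one block runs on the calling thread before waiting for the workers. This is a sketch under assumptions, not part of the commit: the ThreadPool interface (NumThreads, AddWork, RunAll, WaitForWork) is modeled on DALI's thread pool, and ParallelForSamples / process_sample are illustrative names.

#include <algorithm>

// Split [0, nsamples) into contiguous blocks, queue all but the last block on
// the thread pool, run the last block inline, then join. ThreadPool is assumed
// to expose NumThreads(), AddWork(task, priority), RunAll(wait) and
// WaitForWork(), mirroring DALI's thread pool.
template <typename ThreadPool, typename Func>
void ParallelForSamples(ThreadPool &tp, int nsamples, Func &&process_sample) {
  auto make_task = [&](int block_idx, int nblocks) {
    return [&, block_idx, nblocks](int /*thread_id*/) {
      int i_start = nsamples * block_idx / nblocks;
      int i_end = nsamples * (block_idx + 1) / nblocks;
      for (int i = i_start; i < i_end; i++)
        process_sample(i);
    };
  };

  if (nsamples <= 16) {   // too small to be worth parallelizing
    make_task(0, 1)(-1);  // run everything in the current thread
    return;
  }

  // One block per worker plus one for the calling thread, but keep each block
  // at least a few samples long.
  int nblocks = std::min(tp.NumThreads() + 1, nsamples / 4);
  int block_idx = 0;
  for (; block_idx < nblocks - 1; block_idx++)
    tp.AddWork(make_task(block_idx, nblocks), -block_idx);
  tp.RunAll(false);                   // start the queued work, don't block yet
  make_task(block_idx, nblocks)(-1);  // last block runs on the calling thread
  tp.WaitForWork();                   // wait for the worker threads
}

In this sketch the final block is executed inline after RunAll(false), so the calling thread contributes work instead of idling until WaitForWork() returns.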
