Skip to content

Commit

Permalink
Fix minibatch counting in python dataset reader
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedorowicz1 committed Jul 2, 2024
1 parent 93e04a8 commit 6dbae1a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class python_dataset_reader : public generic_data_reader
#ifdef LBANN_HAS_DISTCONV
/** @brief Whether or not tensor needs shuffling for distconv. */
bool m_tensor_shuffle_required = true;
+ /** @brief The current number of minibatches in the epoch that have been
+ * fetched and returned by fetch_data_block. */
+ uint64_t m_fetched_minibatch_count;
#endif // LBANN_HAS_DISTCONV
};

Expand Down
8 changes: 6 additions & 2 deletions src/data_ingestion/readers/data_reader_python_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ void python_dataset_reader::shuffle_responses(DataType* responses_ptr)
execution_mode mode = exec_mode_from_string(get_role());
dataset& ds = get_trainer().get_data_coordinator().get_dataset(mode);
uint64_t global_mb_size{};
- if (m_dataset_minibatch_offset < (ds.get_num_iterations_per_epoch() - 1)) {
+ if (m_fetched_minibatch_count < (ds.get_num_iterations_per_epoch() - 1)) {
global_mb_size = ds.get_mini_batch_size();
}
- else if (m_dataset_minibatch_offset ==
+ else if (m_fetched_minibatch_count ==
(ds.get_num_iterations_per_epoch() - 1)) {
global_mb_size = ds.get_last_mini_batch_size();
}
+ m_fetched_minibatch_count++;

uint64_t local_mb_size = global_mb_size / nprocs;
uint64_t extra_samples = global_mb_size % nprocs;
Expand Down Expand Up @@ -344,6 +345,9 @@ void python_dataset_reader::queue_epoch()
m_dataset_minibatch_offset = 0;
m_dataset_sample_offset = 0;
m_queued_samples = 0;
+ #ifdef LBANN_HAS_DISTCONV
+ m_fetched_minibatch_count = 0;
+ #endif // LBANN_HAS_DISTCONV

// Prefetch the first set of samples (if less than minibatch size, the first
// minibatch read will take care of the rest)
Expand Down

0 comments on commit 6dbae1a

Please sign in to comment.