-
Notifications
You must be signed in to change notification settings - Fork 9
/
custom_dataset.h
48 lines (35 loc) · 1.67 KB
/
custom_dataset.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#pragma once
#include <vector>
#include <tuple>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include "utils.h"
/// Dataset that lazily loads images listed in a csv file.
///
/// The csv is parsed once in the constructor (via ReadCsv) into a list of
/// (file location, label) pairs; the image files themselves are only read
/// from disk when get() is called for the corresponding index.
class CustomDataset : public torch::data::Dataset<CustomDataset>
{
private:
    std::vector<std::tuple<std::string /*file location*/, int64_t /*label*/>> csv_;

public:
    /// Load the csv file with file locations and labels.
    ///
    /// @param file_names_csv Path handed to ReadCsv.
    // NOTE(review): the parameter stays a non-const reference because it is
    // forwarded to ReadCsv, whose signature is not visible here — confirm
    // before tightening this to `const std::string&`.
    explicit CustomDataset(std::string& file_names_csv)
        : csv_(ReadCsv(file_names_csv)) {
    }

    /// Override the get method to load custom data.
    ///
    /// @param index Position of the sample inside the csv listing.
    /// @return Example whose data is a CxHxW uint8 image tensor and whose
    ///         target is a 1-element int64 tensor holding the label.
    /// @throws std::out_of_range if index is not a valid row of the csv.
    /// @throws c10::Error (raised by TORCH_CHECK) if the image cannot be read.
    torch::data::Example<> get(size_t index) override {
        // at() instead of operator[] so that a bad index throws rather than
        // reading past the end of the vector.
        const auto& entry = csv_.at(index);
        const std::string& file_location = std::get<0>(entry);
        const int64_t label = std::get<1>(entry);

        // Load image with OpenCV.
        cv::Mat img = cv::imread(file_location);

        // imread signals failure (missing file, unsupported format, ...) by
        // returning an empty matrix. Fail loudly here, otherwise an empty
        // tensor would be produced and the error would only surface later.
        TORCH_CHECK(!img.empty(), "CustomDataset: failed to load image '", file_location, "'");

        // Convert the image and label to a tensor.
        // Here we need to clone the data, as from_blob does not change the ownership of the underlying memory,
        // which, therefore, still belongs to OpenCV. If we did not clone the data at this point, the memory
        // would be deallocated after leaving the scope of this get method, which results in undefined behavior.
        // The hard-coded 3 relies on imread's default flag returning a 3-channel, 8-bit image.
        torch::Tensor img_tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte).clone();
        img_tensor = img_tensor.permute({2, 0, 1}); // convert to CxHxW

        // Pin the dtype explicitly so the label tensor is int64 regardless of
        // how the installed libtorch infers a dtype from an integral fill value.
        torch::Tensor label_tensor = torch::full({1}, label, torch::kInt64);

        return {img_tensor, label_tensor};
    }

    /// Override the size method to infer the size of the data set.
    ///
    /// @return Number of (file location, label) rows read from the csv.
    torch::optional<size_t> size() const override {
        return csv_.size();
    }
};