-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcpu_caching.py
65 lines (48 loc) · 1.62 KB
/
cpu_caching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from monai import data
import torch
from torch.utils.data import Dataset
from src.data.transforms import loading_transforms
import argparse
import pandas as pd
from tqdm import tqdm
class HeadDatasetCache(Dataset):
    """Dataset that populates a MONAI ``PersistentDataset`` cache of images.

    Each row of ``csv_file`` becomes one sample: its ``img_path`` value is
    wrapped as ``{"image": path}`` and run through
    ``loading_transforms(roi, in_channels)`` via ``PersistentDataset``, which
    stores transform results under ``cache_dir`` for reuse by later runs.

    Args:
        roi: Region-of-interest size forwarded to ``loading_transforms``.
        in_channels: Channel count forwarded to ``loading_transforms``.
        csv_file: Path to a CSV file that has an ``img_path`` column.
        cache_dir: Cache directory passed straight to ``PersistentDataset``.
            NOTE(review): with the default ``None`` MONAI does not persist
            anything to disk -- confirm callers always pass a directory.
    """

    def __init__(self, roi, in_channels, csv_file, cache_dir=None):
        self.data = pd.read_csv(csv_file)
        self.load = loading_transforms(roi, in_channels)
        self.cache_dir = cache_dir
        # One {"image": path} record per CSV row, in CSV order, so index i of
        # the cache dataset corresponds to row i of self.data.
        self.cache_dataset = data.PersistentDataset(
            data=[{"image": path} for path in self.data['img_path']],
            transform=self.load,
            cache_dir=self.cache_dir,
        )

    def __len__(self):
        """Return the number of CSV rows (one sample per row)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` through the persistent cache.

        Returns:
            The transformed sample dict, or ``None`` if loading failed.
            Failures are reported on stdout instead of raised so that one bad
            file does not abort a long caching run.  Only ``Exception``
            subclasses are caught: ``KeyboardInterrupt``/``SystemExit`` must
            propagate so the run can still be interrupted.
        """
        try:
            image = self.cache_dataset[idx]
            print(f"image: {image['image'].shape}")
            return image
        except Exception as err:  # best-effort: report and skip this sample
            print("Error: {} ({}: {})".format(idx, type(err).__name__, err))
            return None
# Default locations; replace the <path-to> placeholders or override them with
# --csv_file / --cache_dir on the command line.
ROI = [96, 96, 96]
CSV_FILE = '<path-to>/datasets/dataset.csv'
CACHE_DIR = '<path-to>/embedding_cache'


def parse_args(argv=None):
    """Parse the command-line options for a caching run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``start_idx``, ``end_idx``,
        ``csv_file`` and ``cache_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Populate the persistent image cache for a range of CSV rows.'
    )
    parser.add_argument('--start_idx', type=int, default=0,
                        help='First row index to cache (inclusive, default: 0)')
    parser.add_argument('--end_idx', type=int, default=None,
                        help='Row index to stop before (exclusive, default: dataset length)')
    parser.add_argument('--csv_file', default=CSV_FILE,
                        help='CSV file with an img_path column')
    parser.add_argument('--cache_dir', default=CACHE_DIR,
                        help='Directory for the persistent cache')
    return parser.parse_args(argv)


def main(argv=None):
    """Load rows ``[start_idx, end_idx)`` so their transforms get cached."""
    args = parse_args(argv)
    train_ds = HeadDatasetCache(
        ROI,
        in_channels=3,
        csv_file=args.csv_file,
        cache_dir=args.cache_dir,
    )
    # Clamp so an oversized --end_idx does not walk past the last row.
    n_rows = len(train_ds)
    end_idx = n_rows if args.end_idx is None else min(args.end_idx, n_rows)
    for idx in tqdm(range(args.start_idx, end_idx)):
        try:
            # Result is discarded: loading the sample is what fills the cache.
            train_ds[idx]
        except Exception as err:  # keep going; KeyboardInterrupt still stops the run
            print("Error: {} ({}: {})".format(idx, type(err).__name__, err))


if __name__ == "__main__":
    main()