-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModelNetDataLoader.py
128 lines (108 loc) · 3.85 KB
/
ModelNetDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import warnings
import numpy as np
from torch.utils.data import Dataset
# NOTE(review): importing this module silences *every* warning for the whole
# process (warnings.filterwarnings mutates the global filter list), including
# e.g. NumPy RuntimeWarnings from invalid arithmetic. Consider scoping this
# with warnings.catch_warnings() around the code that actually needs it.
warnings.filterwarnings("ignore")
def pc_normalize(pc):
    """
    Center a point cloud on its centroid and scale it into the unit sphere.

    Input:
        pc: point coordinates, [N, C] (this file passes the xyz columns)
    Return:
        a new array with the centroid subtracted and every row divided by the
        largest row norm, so the farthest point ends up at distance 1.
        If all points coincide (largest norm is 0) the centered cloud, i.e.
        all zeros, is returned instead of dividing by zero (which would
        produce NaNs).
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    radius = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    # Degenerate cloud: nothing to scale, and 0/0 would turn it into NaNs.
    if radius > 0:
        pc = pc / radius
    return pc
def farthest_point_sample(point, npoint, rng=None):
    """
    Subsample a point cloud with farthest point sampling.

    Farthest point sampler works as follows:
    1. Initialize the sample set S with a random point
    2. Pick point P not in S, which maximizes the distance d(P, S)
    3. Repeat step 2 until |S| = npoint
    Distances use only the first three columns (xyz); any further columns
    are carried along with the selected rows.

    Input:
        point: pointcloud data, [N, D] with D >= 3
        npoint: number of samples
        rng: optional object with an ``integers(low, high)`` method, such as
            ``numpy.random.Generator``, used to pick the starting point.
            When None the global ``np.random`` state is used.
    Return:
        the sampled rows of ``point``, [npoint, D]
    """
    n_points = point.shape[0]
    xyz = point[:, :3]
    indices = np.zeros(npoint, dtype=np.int64)
    # Squared distance from every point to its nearest selected point so far.
    # Start at +inf: a finite sentinel (e.g. 1e10) is never beaten by points
    # whose squared distance exceeds it, which corrupts the sampling for
    # large-coordinate (un-normalized) clouds.
    min_dist = np.full(n_points, np.inf)
    if rng is None:
        farthest = np.random.randint(0, n_points)
    else:
        farthest = int(rng.integers(0, n_points))
    for i in range(npoint):
        indices[i] = farthest
        dist = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        np.minimum(min_dist, dist, out=min_dist)
        farthest = int(np.argmax(min_dist))
    return point[indices]
class ModelNetDataLoader(Dataset):
    """
    ModelNet40 classification dataset backed by per-shape text files.

    Files read under ``root``:
        modelnet40_shape_names.txt   one class name per line
        modelnet40_train.txt         one shape id per line, e.g. ``chair_0001``
        modelnet40_test.txt          same, for the test split
        <shape_name>/<shape_id>.txt  comma-separated point rows

    Each item is ``(point_set, cls)``. ``point_set`` is a float32 array of at
    most ``npoint`` rows. ``cls`` is a one-element int32 array holding the
    class index.
    """

    def __init__(
        self,
        root,
        npoint=1024,
        split="train",
        fps=False,
        normal_channel=True,
        cache_size=15000,
    ):
        """
        Input:
            root: the root path to the local data files
            npoint: number of points from each cloud
            split: which split of the data, 'train' or 'test'
            fps: whether to sample points with farthest point sampler
            normal_channel: whether to keep the columns after xyz
                (presumably point normals - confirm against the data files)
            cache_size: maximum number of loaded samples kept in memory
        Raises:
            ValueError: if ``split`` is not 'train' or 'test'.
        """
        # Validate before touching the filesystem; a bare ``assert`` would be
        # stripped when Python runs with -O.
        if split not in ("train", "test"):
            raise ValueError(
                "split must be 'train' or 'test', got %r" % (split,)
            )
        self.root = root
        self.npoints = npoint
        self.fps = fps
        self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
        self.cat = self._read_lines(self.catfile)
        self.classes = {name: idx for idx, name in enumerate(self.cat)}
        self.normal_channel = normal_channel

        # Only the requested split's id list is needed.
        shape_ids = self._read_lines(
            os.path.join(self.root, "modelnet40_%s.txt" % split)
        )
        # The shape (class) name is the id with its trailing "_<suffix>"
        # removed, e.g. "night_stand_0002" -> "night_stand".
        shape_names = [shape_id.rpartition("_")[0] for shape_id in shape_ids]
        # list of (shape_name, shape_txt_file_path) tuples
        self.datapath = [
            (name, os.path.join(self.root, name, shape_id) + ".txt")
            for name, shape_id in zip(shape_names, shape_ids)
        ]
        print("The size of %s data is %d" % (split, len(self.datapath)))
        self.cache_size = cache_size
        self.cache = {}  # index -> (point_set, cls)

    @staticmethod
    def _read_lines(path):
        """Return the right-stripped, non-blank lines of the text file at ``path``."""
        # ``with`` closes the handle; blank lines (e.g. a stray trailing empty
        # line) would otherwise become bogus class names / shape ids.
        with open(path, encoding="utf-8") as f:
            stripped = (line.rstrip() for line in f)
            return [text for text in stripped if text]

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        """
        Load sample ``index`` (or fetch it from the cache) as ``(point_set, cls)``.

        Samples are cached until ``cache_size`` entries exist. Cached arrays are
        returned as-is (not copied), so in-place edits by a caller are visible
        on later reads of the same index. With ``fps=True`` the randomly
        started sampling result is computed once and then reused from cache.
        """
        if index in self.cache:
            return self.cache[index]
        shape_name, path = self.datapath[index]
        cls = np.array([self.classes[shape_name]]).astype(np.int32)
        # ndmin=2 keeps a single-row file 2-D so the column slicing works.
        point_set = np.loadtxt(path, delimiter=",", ndmin=2).astype(np.float32)
        if self.fps:
            point_set = farthest_point_sample(point_set, self.npoints)
        else:
            # Deterministic subsample: keep the first ``npoints`` rows.
            point_set = point_set[0 : self.npoints, :]
        # Normalize only the xyz columns; any extra columns are left untouched.
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.normal_channel:
            point_set = point_set[:, 0:3]
        if len(self.cache) < self.cache_size:
            self.cache[index] = (point_set, cls)
        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)