diff --git a/data/processes/augment_data.py b/data/processes/augment_data.py index b06f575..55ce8de 100644 --- a/data/processes/augment_data.py +++ b/data/processes/augment_data.py @@ -63,13 +63,20 @@ def may_augment_annotation(self, aug: imgaug.augmenters.Augmenter, data, shape): line_polys = [] keypoints = [] texts = [] + new_polys = [] + for line in data['lines']: texts.append(line['text']) + new_poly = [] for p in line['poly']: + new_poly.append((p[0], p[1])) keypoints.append(imgaug.Keypoint(p[0], p[1])) + new_polys.append(new_poly) - keypoints = aug.augment_keypoints([imgaug.KeypointsOnImage(keypoints=keypoints, shape=shape)])[0].keypoints - new_polys = np.array([p.x, p.y] for p in keypoints).reshape([-1, 4, 2]) + if not self.only_resize: + keypoints = aug.augment_keypoints([imgaug.KeypointsOnImage(keypoints=keypoints, shape=shape)])[0].keypoints + new_polys = np.array([[p.x, p.y] for p in keypoints]).reshape((-1, 4, 2)) + for i in range(len(texts)): poly = new_polys[i] line_polys.append({ @@ -79,4 +86,3 @@ def may_augment_annotation(self, aug: imgaug.augmenters.Augmenter, data, shape): }) data['polys'] = line_polys -