diff --git a/data_configs/buggy_data.yaml b/data_configs/buggy_data.yaml new file mode 100644 index 00000000..085dba83 --- /dev/null +++ b/data_configs/buggy_data.yaml @@ -0,0 +1,41 @@ +# Image and label directories should be relative to train.py +TRAIN_DIR_IMAGES: '../input/coco2017/train2017' +TRAIN_DIR_LABELS: '../input/coco2017/train_xml' +VALID_DIR_IMAGES: '../input/coco2017/val2017' +VALID_DIR_LABELS: '../input/coco2017/val_xml' + +# Class names. +CLASSES: [ + '__background__', + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush' +] + +COCO_91_CLASSES: [ + '__background__', + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 
def read_and_clean(self):
    """Drop images that cannot be trained on, together with their labels.

    An image is removed from ``self.all_images`` when either

    * no ``<stem>.xml`` path for it is present in ``self.all_annot_paths``
      (the expected path is built from ``self.labels_path``), or
    * its annotation contains an ``<object>`` whose bounding box is
      unusable: the ``<bndbox>`` element or one of its ``xmin``/``xmax``/
      ``ymin``/``ymax`` children is missing or non-numeric, or the box is
      degenerate (``xmin >= xmax`` or ``ymin >= ymax``).

    Annotation paths whose file stem belongs to a removed image are pruned
    from ``self.all_annot_paths``. Both attributes are rebound to new
    lists; the relative order of the surviving entries is preserved.

    Returns:
        None. Progress and a removal summary are printed to stdout.
    """
    print('Checking Labels and images...')

    def _has_invalid_box(member):
        # True when this <object> has no usable, non-degenerate box.
        bndbox = member.find('bndbox')
        if bndbox is None:
            return True
        try:
            # findtext() yields None for a missing tag, so float() raises
            # TypeError there and ValueError for empty/non-numeric text.
            xmin, xmax, ymin, ymax = (
                float(bndbox.findtext(tag))
                for tag in ('xmin', 'xmax', 'ymin', 'ymax')
            )
        except (TypeError, ValueError):
            return True
        return xmin >= xmax or ymin >= ymax

    images_to_remove = []
    problematic_images = []
    # Built once: a per-image scan of the annotation list would make this
    # pass quadratic in the dataset size.
    known_annot_paths = set(self.all_annot_paths)

    for image_name in tqdm(self.all_images, total=len(self.all_images)):
        possible_xml_name = os.path.join(
            self.labels_path, os.path.splitext(image_name)[0] + '.xml'
        )
        if possible_xml_name not in known_annot_paths:
            print(f"⚠️ {possible_xml_name} not found... Removing {image_name}")
            images_to_remove.append(image_name)
            continue

        # Check for invalid bounding boxes; any() stops at the first one.
        root = et.parse(possible_xml_name).getroot()
        if any(_has_invalid_box(member) for member in root.findall('object')):
            problematic_images.append(image_name)
            images_to_remove.append(image_name)

    # Remove problematic images and their annotations. Annotations are
    # matched by file stem so the result does not depend on how
    # self.image_file_types spells its entries (bare extensions vs. glob
    # patterns).
    removed_images = set(images_to_remove)
    removed_stems = {
        os.path.splitext(os.path.basename(name))[0] for name in removed_images
    }
    self.all_images = [
        img for img in self.all_images if img not in removed_images
    ]
    self.all_annot_paths = [
        path for path in self.all_annot_paths
        if os.path.splitext(os.path.basename(path))[0] not in removed_stems
    ]

    # Print warnings for problematic images
    if problematic_images:
        print("\n⚠️ The following images have invalid bounding boxes and will be removed:")
        for img in problematic_images:
            print(f"⚠️ {img}")

    print(f"Removed {len(images_to_remove)} problematic images and annotations.")