-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcounter.py
105 lines (81 loc) · 2.57 KB
/
counter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import glob
import io
import xml.etree.ElementTree as ET
import pandas as pd
import random
def remove(num_obj, num_img_removed, directory="./train/"):
    """Randomly delete image/annotation pairs that hold exactly ``num_obj`` objects.

    Every ``*.xml`` file in ``directory`` is parsed. Files whose root element
    has exactly ``num_obj`` ``<object>`` children are candidates. Then
    ``num_img_removed`` of them are chosen uniformly at random. For each chosen
    file, the sibling ``.jpg`` (same path, extension swapped) is deleted first,
    then the ``.xml`` itself.

    Args:
        num_obj: Exact number of ``<object>`` elements a file must contain to
            be eligible for removal.
        num_img_removed: How many eligible pairs to delete.
        directory: Folder to scan. Defaults to ``"./train/"``, the folder the
            original script hard-coded.

    Raises:
        ValueError: If fewer than ``num_img_removed`` files are eligible. No
            file is deleted in that case. A negative count is also rejected,
            by ``random.sample``.
        FileNotFoundError: If a chosen ``.xml`` has no ``.jpg`` partner. That
            ``.xml`` is left in place, because the image is removed first.
    """
    # Sort so the candidate order is filesystem-independent; with a fixed
    # random.seed() the same files are then removed on every machine.
    candidates = []
    for xml_file in sorted(glob.glob(os.path.join(directory, "*.xml"))):
        root = ET.parse(xml_file).getroot()
        if len(root.findall("object")) == num_obj:
            candidates.append(xml_file)

    if num_img_removed > len(candidates):
        raise ValueError(
            f"cannot remove {num_img_removed} images: only {len(candidates)} "
            f"contain exactly {num_obj} object(s)"
        )

    for xml_file in random.sample(candidates, num_img_removed):
        # Swap only the extension. str.replace("xml", "jpg") would also
        # rewrite any "xml" inside the directory name or file stem.
        img_name = os.path.splitext(xml_file)[0] + ".jpg"
        os.remove(img_name)
        os.remove(xml_file)
#
#
# remove(1, 30)
# remove(1, 35)
# Survey the training annotations. For every ./train/*.xml file, record how
# many <object> children its root holds, and tally how often each object
# <name> text occurs across all files.
counter = {}        # object count per file -> number of files with that count
class_dic = {}      # object name -> one-element list [number of occurrences]
counter_above2 = 0  # files holding more than two objects
total_img = 0       # annotation files visited
xml_list = []       # NOTE(review): never read in this script
for annotation_path in glob.glob("./train/" + '/*.xml'):
    objects = ET.parse(annotation_path).getroot().findall('object')
    n_objects = len(objects)
    total_img += 1
    if n_objects > 2:
        counter_above2 += 1
    counter[n_objects] = counter.get(n_objects, 0) + 1
    for obj in objects:
        # NOTE(review): raises AttributeError if an <object> lacks <name>.
        label = obj.find('name').text
        class_dic.setdefault(label, [0])[0] += 1
print("counter_above2", counter_above2)
print("total_img", total_img)
print(class_dic)
# # temp
# for key, value in class_dic:
#
# Turn the tallies into single-column DataFrames, with column label 0, for
# printing and plotting. Building them from int64 Series guarantees column 0
# exists even when no annotations were found. pd.DataFrame({}).T has no
# columns, so a later class_table[0] would raise KeyError. For non-empty
# input the printed tables are unchanged.
class_table = pd.Series(
    {name: counts[0] for name, counts in class_dic.items()}, dtype="int64"
).to_frame(0)
dataset_distrib = pd.Series(counter, dtype="int64").sort_index().to_frame(0)
print("class_table", class_table)
print("dataset_distrib", dataset_distrib)
import matplotlib.pyplot as plt

# Chart 1: how many images contain 1, 2, 3, ... annotated objects.
fig, ax = plt.subplots()
ax.bar(dataset_distrib.index, dataset_distrib[0])
ax.set_title("frequency of num_objects per img")
ax.set_xlabel("number of objects")
ax.set_ylabel("frequency")
fig.tight_layout()
fig.savefig("distrib.png")
plt.close(fig)  # release the figure instead of leaving global pyplot state

# Chart 2: total number of annotated objects per class. The x-axis holds
# class names, so it is labelled "class", not "number of objects".
fig, ax = plt.subplots()
ax.bar(class_table.index, class_table[0])
ax.set_title("frequency of each class")
ax.set_xlabel("class")
ax.set_ylabel("frequency")
ax.tick_params(axis="x", labelrotation=45)  # long class names would overlap
fig.tight_layout()
fig.savefig("class_distrib.png")
plt.close(fig)
#
import random
#
# Check that every file in ./train/ is part of a .jpg/.xml pair, and report
# the extension-less path of each file whose partner is missing.
def unpaired_stems(paths, partner_paths, partner_ext):
    """Return the extension-less paths in ``paths`` that have no partner file.

    A path ``p`` counts as paired when ``splitext(p)[0] + partner_ext`` is in
    ``partner_paths``. Only the extension is swapped, so stems or directories
    that contain "xml"/"jpg" are handled correctly. The input order of
    ``paths`` is preserved.
    """
    partners = set(partner_paths)  # O(1) lookups instead of a list scan each time
    stems = (os.path.splitext(p)[0] for p in paths)
    return [stem for stem in stems if stem + partner_ext not in partners]


imgs = glob.glob("./train/" + '/*.jpg')
xmls = glob.glob("./train/" + '/*.xml')
for stem in unpaired_stems(xmls, imgs, ".jpg"):
    print(stem, "is NOT MATCHING, remove it")
for stem in unpaired_stems(imgs, xmls, ".xml"):
    print(stem, "is NOT MATCHING, remove it")