-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlibnsfw.py
181 lines (131 loc) · 5.53 KB
/
libnsfw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import math
import os

import numpy as np
import PIL.Image
import caffe
class NSFWModel(object):
    """Score images with the Caffe open_nsfw network.

    Each image is converted to RGB, resized, and cut into (possibly
    overlapping) tiles of the network's input size. Multi-frame images are
    sampled every ``frame_step`` frames. The score reported for an image is
    the maximum score over all of its tiles.
    """

    model_def_filename = "open_nsfw_model/deploy.prototxt"
    model_weights_filename = "open_nsfw_model/resnet_50_1by2_nsfw.caffemodel"

    # Only frames 0, frame_step, 2*frame_step, ... of a multi-frame image are
    # scored. Must be a positive integer.
    frame_step = 10
    # Fraction of a tile's height/width shared with its neighbour; must be in
    # [0, 1).
    tile_overlap = 0.5
    # Maximum relative aspect-ratio mismatch for which a frame is squashed
    # straight to the model input size instead of being tiled.
    max_uneven_resize = 0.2

    def __init__(self, deffile=model_def_filename, weightsfile=model_weights_filename):
        """Load the network and build the matching input transformer.

        :param deffile: path to the Caffe network definition (prototxt).
        :param weightsfile: path to the trained weights (caffemodel).
        """
        model = caffe.Net(deffile, caffe.TEST, weights=weightsfile)

        # Cache some meta-information about the model. The input shape is read
        # from the network's actual first input blob.
        self.model_inname = model.inputs[0]
        # The last blob is the one read back as the network output in eval().
        self.model_outname = next(reversed(model.blobs))
        inshape = model.blobs[self.model_inname].data.shape
        self.model_inshape = inshape[1:]  # per-sample shape, channel first: (C, H, W)
        self.model_insize = self.model_inshape[1:]  # spatial size: (H, W)

        transformer = caffe.io.Transformer({'data': inshape})
        transformer.set_transpose('data', (2, 0, 1))  # Channel first format
        transformer.set_channel_swap('data', (2, 1, 0))  # swap channels from RGB to BGR
        transformer.set_raw_scale('data', 255)  # rescale from [0, 1] to [0, 255]
        transformer.set_mean('data', np.array([104, 117, 123]))  # subtract the dataset-mean value in each channel

        self.model = model
        self.transformer = transformer

    def _resized_dims(self, fw, fh):
        """Return the (width, height) a frame of size ``fw`` x ``fh`` is resized to.

        If the frame's aspect ratio is within ``max_uneven_resize`` of the
        model's, the frame is squashed to exactly the model input size.
        Otherwise one side is set to the model's size and the other keeps the
        frame's aspect ratio (and is longer than the model's). The frame can
        then be covered by several tiles.
        """
        h, w = self.model_insize
        tol = self.max_uneven_resize
        if 1 - tol <= (fw * h) / (w * fh) <= 1 + tol:
            return w, h
        if fw * h > w * fh:
            # Relatively wider than the model input: match heights.
            return round(fw * h / fh), h
        # Relatively taller than the model input: match widths.
        return w, round(fh * w / fw)

    def _load_frames(self, pilimg):
        """Return model-sized RGB tiles cut from the sampled frames of ``pilimg``.

        Frames 0, frame_step, 2*frame_step, ... are read until the image runs
        out of frames or a read fails (e.g. a truncated file). Each frame is
        resized per ``_resized_dims`` and covered by an evenly spaced grid of
        tiles.

        :return: float32 array of shape (n_tiles, H, W, 3) with values in
            [0, 1], or an empty array if no frame could be read.
        """
        h, w = self.model_insize
        overlap = self.tile_overlap
        nonoverlap = 1 - overlap
        tiles = []
        frameno = 0
        while True:
            try:
                pilimg.seek(frameno)
            except (EOFError, OSError):
                # EOFError: no more frames. OSError: the file might be truncated.
                break
            try:
                # convert() presumably decodes pixel data just like resize()
                # does, so both are guarded against truncated-file OSErrors.
                frame = pilimg if pilimg.mode == 'RGB' else pilimg.convert("RGB")
                fw, fh = self._resized_dims(*frame.size)
                frame = frame.resize((fw, fh), PIL.Image.BILINEAR)
            except OSError:
                break
            pixels = np.array(frame).astype(np.float32) / 255.0
            # Enough tile positions that neighbouring tiles overlap by roughly
            # at least `overlap`; offsets are spread evenly from 0 to the last
            # valid start position (fh >= h and fw >= w by construction).
            nh = math.ceil((fh - h * overlap) / (h * nonoverlap))
            nw = math.ceil((fw - w * overlap) / (w * nonoverlap))
            for hoff in np.linspace(0, fh - h, nh, dtype=np.int32):
                for woff in np.linspace(0, fw - w, nw, dtype=np.int32):
                    tiles.append(pixels[hoff:hoff + h, woff:woff + w])
            frameno += self.frame_step
        return np.array(tiles)

    def preprocess_pil(self, pilimgs):
        """
        Preprocess PIL-compatible image objects. Each PIL image can result in
        several returned frames.

        Return the index array of the pilimgs preprocessed and a numpy array of
        frames. The index array stores, for each frame, the index in pilimgs of
        the PIL image that generated it.
        """
        pilidx = []
        imgs = []
        for i, pilimg in enumerate(pilimgs):
            for frame in self._load_frames(pilimg):
                imgs.append(self.transformer.preprocess('data', frame))
                pilidx.append(i)
        # np.int was removed in NumPy 1.24; intp is the native index type.
        return np.array(pilidx, dtype=np.intp), np.array(imgs)

    def preprocess_files(self, files):
        """
        Preprocess filenames or file-like objects. Each file can result in
        several returned frames; files PIL cannot open are skipped.

        Images opened from a path (str, bytes or os.PathLike) are closed before
        returning. File objects supplied by the caller are left for the caller
        to close.

        Return the index array of the files preprocessed and a numpy array of
        frames. The index array stores, for each frame, the index in files of
        the file that generated it.
        """
        imgs = []
        filesidx = []
        owned = []  # images opened from a path; we are responsible for closing them
        try:
            for i, f in enumerate(files):
                try:
                    img = PIL.Image.open(f)
                except Exception:
                    # Best effort: silently skip unreadable/unsupported files.
                    continue
                if isinstance(f, (str, bytes, os.PathLike)):
                    owned.append(img)
                imgs.append(img)
                filesidx.append(i)
            imgidx, frames = self.preprocess_pil(imgs)
        finally:
            for img in owned:
                img.close()
        # Explicit integer dtype so an empty result is still a valid index array.
        filesidx = np.array(filesidx, dtype=np.intp)
        return filesidx[imgidx], frames

    def eval(self, imgs):
        """Evaluate the NSFW score on some preprocessed images.

        :param imgs: array of N preprocessed frames, each of shape
            ``model_inshape``.
        :return: 1-D array of N scores (column 1 of the two-column output).
        """
        if imgs.shape[0] == 0:
            # Nothing to score: skip the forward pass.
            return np.zeros(0, dtype=np.float32)
        assert imgs.shape[1:] == self.model_inshape
        inname = self.model_inname
        outname = self.model_outname
        outputs = self.model.forward_all(blobs=[outname], **{inname: imgs})
        # Flatten the output into rows of two class scores.
        return outputs[outname].reshape((-1, 2))[:, 1]

    @staticmethod
    def _max_per_source(srcidx, scores):
        """Reduce per-frame scores to one score (their maximum) per source.

        :param srcidx: source index for each frame.
        :param scores: score for each frame, aligned with ``srcidx``.
        :return: (sorted unique source indices, max score for each of them).
        """
        srcidx = np.asarray(srcidx)
        uniqidx = np.unique(srcidx)
        best = np.array([scores[srcidx == i].max() for i in uniqidx])
        return uniqidx, best

    def eval_pil(self, pilimgs):
        """
        Evaluate the NSFW score on PIL-compatible image objects.

        Return the indices of the pilimgs that yielded at least one frame and
        their score (the max over their frames).
        """
        pilidx, frames = self.preprocess_pil(pilimgs)
        return self._max_per_source(pilidx, self.eval(frames))

    def eval_files(self, files):
        """
        Evaluate the NSFW score on filenames or file-like objects.

        Return the indices of the files that yielded at least one frame and
        their score (the max over their frames).
        """
        filesidx, frames = self.preprocess_files(files)
        return self._max_per_source(filesidx, self.eval(frames))