-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathinference.py
109 lines (95 loc) · 3.85 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import uuid

import numpy as np
from flask import Flask, request, render_template
from tensorflow.keras.models import load_model # type: ignore
from tensorflow.keras.preprocessing import image # type: ignore
class InferenceModel:
    """
    Load a trained Keras model and serve image predictions through Flask.

    Creating an instance loads the model, builds a Flask application and
    registers a single route (``/``) that accepts an uploaded image and
    renders ``index.html`` with either a Real/Fake result or an error.
    """

    # Lower-case extensions (without the dot) accepted for upload.
    ALLOWED_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg'})

    def __init__(self, model_path):
        """
        Initialize the InferenceModel class.

        Parameters:
        -----------
        model_path : str
            Path to the saved Keras model.
        """
        self.model = load_model(model_path)
        self.app = Flask(__name__)
        self.app.config['UPLOAD_FOLDER'] = 'uploads'
        self.model_path = model_path

        # Saving an upload fails when the target directory is missing, so
        # make sure it exists before the first request arrives.
        os.makedirs(self.app.config['UPLOAD_FOLDER'], exist_ok=True)

        @self.app.route('/', methods=['GET', 'POST'])
        def upload_file():
            """
            Handle file upload and prediction requests.

            Returns:
            --------
            str
                The rendered HTML template with the result or error message.
            """
            if request.method != 'POST':
                return render_template('index.html')

            # check if the post request has the file part
            if 'file' not in request.files:
                return render_template('index.html', error='no file part')

            file = request.files['file']
            # if the user does not select a file, the browser submits an
            # empty part without a filename
            if file.filename == '':
                return render_template('index.html', error='no selected file')

            if not (file and self.allowed_file(file.filename)):
                return render_template('index.html', error='allowed file types are png, jpg, jpeg')

            return self._classify_upload(file)

    def _classify_upload(self, file):
        """
        Store an already-validated upload, classify it and render the result.

        Parameters:
        -----------
        file : werkzeug FileStorage
            The uploaded file; its filename must already have passed
            ``allowed_file``.

        Returns:
        --------
        str
            The rendered HTML template with the result or error message.
        """
        # Never reuse the client-supplied name on disk: it may contain path
        # separators ("../../x.png") and two clients may upload the same name
        # at once. Keep only the already-validated extension.
        extension = file.filename.rsplit('.', 1)[1].lower()
        stored_name = '{}.{}'.format(uuid.uuid4().hex, extension)
        upload_path = os.path.join(self.app.config['UPLOAD_FOLDER'], stored_name)

        file.save(upload_path)
        try:
            prediction, prediction_percentage = self.predict_image(upload_path)
        except OSError:
            # Pillow reports unreadable or truncated images with OSError
            # subclasses; a wrong file behind an allowed extension should be
            # a user-facing error rather than a server error.
            return render_template('index.html', error='could not read the uploaded image')
        finally:
            # clean up the uploaded file whether or not prediction succeeded
            os.remove(upload_path)

        # determine result message
        result = 'Fake' if prediction >= 0.5 else 'Real'
        return render_template('index.html', result=result, prediction_percentage=prediction_percentage)

    def allowed_file(self, filename):
        """
        Check if a file has an allowed extension.

        Parameters:
        -----------
        filename : str
            The name of the file to check.

        Returns:
        --------
        bool
            True if the text after the last dot, compared case-insensitively,
            is one of ``ALLOWED_EXTENSIONS``; False otherwise (including for
            an empty or missing filename).
        """
        if not filename or '.' not in filename:
            return False
        return filename.rsplit('.', 1)[1].lower() in self.ALLOWED_EXTENSIONS

    def predict_image(self, file_path, target_size=(128, 128)):
        """
        Predict whether an image is Real or Fake using the loaded model.

        Parameters:
        -----------
        file_path : str
            The path to the image file.
        target_size : tuple of int, optional
            Size the image is resized to when loaded. Defaults to (128, 128).

        Returns:
        --------
        tuple
            ``(prediction, prediction_percentage)`` where ``prediction`` is
            the first value of the model output for this image and
            ``prediction_percentage`` is that value multiplied by 100.
        """
        img = image.load_img(file_path, target_size=target_size)
        img_array = image.img_to_array(img)
        # the model is called on a batch, so add a leading axis of size one
        img_array = np.expand_dims(img_array, axis=0)
        result = self.model.predict(img_array)
        prediction = result[0][0]
        prediction_percentage = prediction * 100
        return prediction, prediction_percentage

    def run(self, debug=True):
        """
        Run the Flask application with the loaded model.

        Parameters:
        -----------
        debug : bool, optional
            Passed to ``Flask.run``. Defaults to True to keep the existing
            behaviour. NOTE: debug mode enables the interactive debugger and
            must not be used when the server is reachable by untrusted
            clients; pass ``debug=False`` in that case.
        """
        self.app.run(debug=debug)
if __name__ == '__main__':
    # Script entry point: load the saved detector and start serving requests.
    saved_model = 'deepfake_detector_model.keras'
    InferenceModel(saved_model).run()