-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
139 lines (123 loc) · 4.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
class ChestXRayModel3(nn.Module):
    """
    Model replicates the architecture of TinyVGG adding another conv block layer.

    Layout: three conv blocks, each
    Conv2d(3x3, stride 1, pad 1) -> ReLU -> Conv2d(3x3, stride 1, pad 1) -> ReLU -> MaxPool2d(2),
    followed by Flatten -> Linear producing ``output_shape`` logits.

    Args:
        input_shape: Number of input channels.
        hidden_units: Number of channels used by every conv layer.
        output_shape: Number of output logits.
        image_size: Height/width of the square input images. The 3x3/pad-1
            convolutions keep the spatial size and each of the three
            MaxPool2d(2) layers halves it (rounding down), so the classifier
            sees ``image_size // 8`` pixels per side. The default of 224
            reproduces the original ``hidden_units * 28 * 28`` linear layer.

    Raises:
        ValueError: If ``image_size`` is smaller than 8 (nothing would be left
            for the classifier after three poolings).
    """

    def __init__(self, input_shape: int, hidden_units: int, output_shape: int, image_size: int = 224):
        super().__init__()
        if image_size < 8:
            raise ValueError(f"image_size must be at least 8, got {image_size}")
        # NOTE: "conv_block_l" (lowercase L) is intentionally kept. Attribute
        # names become state_dict keys, so renaming it to "conv_block_1" would
        # break loading checkpoints saved from the original class.
        self.conv_block_l = self._conv_block(input_shape, hidden_units)
        self.conv_block_2 = self._conv_block(hidden_units, hidden_units)
        self.conv_block_3 = self._conv_block(hidden_units, hidden_units)
        reduced = image_size // 8  # spatial side after three 2x2 max-pools
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_units * reduced * reduced,
                      out_features=output_shape),
        )

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one Conv-ReLU-Conv-ReLU-MaxPool block.

        The layer order (indices 0-4) matches the original inline definition,
        so state_dict keys such as ``conv_block_2.0.weight`` are unchanged.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),  # values we can set our self are hyperparameters
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three conv blocks and the classifier; returns raw logits."""
        x = self.conv_block_l(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        return self.classifier(x)
# Load your trained model (make sure to provide the path to your model)
@st.cache_resource
def load_model(checkpoint_path: str = 'ChestXRay-PNEUMONIADetection.pth'):
    """Build ChestXRayModel3 and load trained weights from disk onto the CPU.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    model object instead of re-reading the checkpoint on every interaction.

    Args:
        checkpoint_path: Path to the saved ``state_dict``. Defaults to the
            checkpoint file the app was written for.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    model = ChestXRayModel3(input_shape=1,
                            hidden_units=10,
                            output_shape=2)  # must match the architecture used in training
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict contains. It stops a tampered .pth file from
    # executing code on load, and it matches the default used by PyTorch 2.6+.
    # (The keyword requires PyTorch >= 1.13.)
    state_dict = torch.load(checkpoint_path,
                            map_location=torch.device('cpu'),
                            weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Preprocess the image
def preprocess_image(image):
    """Turn a PIL image into a model-ready batch containing one image.

    The pipeline converts to a single grayscale channel, resizes to exactly
    224x224 and converts to a float tensor via ``ToTensor``. A leading batch
    dimension is then added, so the result has shape (1, 1, 224, 224).
    """
    steps = [
        transforms.Grayscale(num_output_channels=1),  # single channel for the model's input_shape=1
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    single_image = pipeline(image)
    return single_image.unsqueeze(0)
# Predict pneumonia
def predict(image, model):
    """Run ``model`` on a preprocessed batch and pick the highest-scoring class.

    Gradients are disabled because this is inference only.

    Args:
        image: Input batch passed unchanged to ``model``.
        model: Callable returning logits; assumed to be 2-D
            (batch, num_classes), since softmax and max are taken over dim 1.

    Returns:
        A tuple ``(probabilities, predicted_class)``:
        - ``probabilities``: the softmax of the logits over dim 1.
        - ``predicted_class``: the index of the largest probability as a
          Python int. ``.item()`` only succeeds when that index tensor has a
          single element, i.e. for a batch of one.
    """
    with torch.no_grad():
        logits = model(image)
        probs = F.softmax(logits, dim=1)
        best = probs.max(dim=1)
    return probs, best.indices.item()
# Streamlit app
st.markdown("<h1 style='text-align: center; color: #003366;'>Pneumonia Detection from Chest X-Ray</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center; color: #666666; font-size: 18px;'>Upload a chest X-ray image, and the model will analyze it to predict whether it indicates pneumonia.</p>", unsafe_allow_html=True)

# Upload image. "jpeg" is listed next to "jpg" so .jpeg files are accepted
# even by Streamlit versions that do not treat the two extensions as aliases.
uploaded_file = st.file_uploader("**Choose a chest X-ray image:**", type=["jpg", "jpeg"])

image = None
if uploaded_file is not None:
    # Image.open only reads the header; load() forces the full decode. A
    # corrupt, truncated or mislabelled upload therefore fails here with
    # OSError (PIL's UnidentifiedImageError is a subclass). It is reported to
    # the user instead of crashing the script with a traceback further down.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:
        image = None
        st.error("The uploaded file could not be read as an image. Please upload a valid chest X-ray JPEG.")

if image is not None:
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept here for compatibility with older versions.
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Load model (cached across reruns by the st.cache_resource decorator on load_model)
    model = load_model()

    # Preprocess and predict
    image_tensor = preprocess_image(image)
    probabilities, predicted_class = predict(image_tensor, model)

    # Display result
    class_names = ['Normal', 'Pneumonia']  # Adjust based on your model's classes
    prediction_label = class_names[predicted_class]
    # Softmax score of the predicted class, expressed as a percentage.
    confidence_score = probabilities[0][predicted_class].item() * 100

    # Format the output nicely
    st.subheader("Prediction Results:")
    if prediction_label == 'Pneumonia':
        st.markdown(f"<h3 style='color: red;'>⚠️ {prediction_label}</h3>", unsafe_allow_html=True)
    else:
        st.markdown(f"<h3 style='color: green;'>✔️ {prediction_label}</h3>", unsafe_allow_html=True)
    st.write(f"**Confidence Level:** {confidence_score:.2f}%")
    # A raw softmax score is not a calibrated probability of being correct,
    # so the explanation must not present it as one.
    st.write("This confidence score is the model's softmax output for the predicted class; it is not a calibrated probability that the prediction is correct.")

# Additional note for users
st.info("**Note:** This tool is for informational purposes only and should not be used as a substitute for professional medical advice, diagnosis, or treatment.")