-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit eaad043
Showing
17 changed files
with
250 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
__pycache__/ | ||
data/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# PyTorch with Flask Demo | ||
|
||
A simple, bare-bones example on how to serve an ML model through a web application, using | ||
[PyTorch](https://pytorch.org/) and [Flask](https://flask.palletsprojects.com/en/2.2.x/). | ||
Based off the official [classifier tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). | ||
|
||
Images for testing can be found in the `images/` directory. | ||
|
||
**Installing dependencies** | ||
|
||
``` | ||
$ pip install -r requirements.txt | ||
``` | ||
|
||
**Training the model** | ||
|
||
``` | ||
$ python ./model.py train | ||
``` | ||
|
||
**Testing the model** | ||
|
||
``` | ||
$ python ./model.py test | ||
``` | ||
|
||
**Running the server** | ||
|
||
``` | ||
$ python ./app.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#! /usr/bin/env python | ||
|
||
import io | ||
from flask import Flask, render_template, request | ||
from model import load_saved_model, predict | ||
from PIL import Image | ||
|
||
app = Flask(__name__) | ||
net = load_saved_model() | ||
|
||
|
||
@app.route("/", methods=["GET"]) | ||
def index(): | ||
return render_template("index.html") | ||
|
||
|
||
@app.route("/eval", methods=["POST"]) | ||
def eval(): | ||
image = Image.open(io.BytesIO(request.data)).convert("L") | ||
result = predict(net, image) | ||
|
||
# Dictionary return values are implicitly converted to JSON, | ||
# which the client can parse to get the result. | ||
return {"result": result} | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(host="0.0.0.0") |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
#! /usr/bin/env python | ||
# | ||
# Heavily based off of | ||
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html | ||
|
||
import sys | ||
import pathlib | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import torch.utils.data | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
from PIL import Image | ||
|
||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
PATH = pathlib.Path(__file__).parent.joinpath("weight.pt") | ||
CLASSES = [str(x) for x in range(0, 10)] | ||
BATCH_SIZE = 4 | ||
|
||
transform = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5,), (0.5,)), | ||
] | ||
) | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.conv1 = nn.Conv2d(1, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 4 * 4, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x): | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = torch.flatten(x, 1) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = self.fc3(x) | ||
return x | ||
|
||
|
||
def train() -> Net: | ||
torch.manual_seed(0) | ||
|
||
net = Net() | ||
net.train() | ||
net.to(DEVICE) | ||
|
||
trainset = torchvision.datasets.MNIST( | ||
root="./data", train=True, download=True, transform=transform | ||
) | ||
trainloader = torch.utils.data.DataLoader( | ||
trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4 | ||
) | ||
|
||
criterion = nn.CrossEntropyLoss() | ||
optimizer = optim.AdamW(net.parameters(), lr=0.001) | ||
|
||
for epoch in range(1, 3): | ||
|
||
running_loss = 0.0 | ||
for i, data in enumerate(trainloader, 1): | ||
|
||
# get the inputs; data is a list of [inputs, labels] | ||
inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) | ||
|
||
optimizer.zero_grad() | ||
|
||
outputs = net(inputs) | ||
loss = criterion(outputs, labels) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
running_loss += loss.item() | ||
|
||
if i % 2000 == 0: | ||
print(f"[{epoch}, {i:5d}] loss: {running_loss / 2000:.3f}") | ||
running_loss = 0.0 | ||
|
||
return net | ||
|
||
|
||
def test(net: Net): | ||
net.eval() | ||
|
||
testset = torchvision.datasets.MNIST( | ||
root="./data", train=False, download=True, transform=transform | ||
) | ||
testloader = torch.utils.data.DataLoader( | ||
testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4 | ||
) | ||
|
||
correct_pred = {classname: 0 for classname in CLASSES} | ||
total_pred = {classname: 0 for classname in CLASSES} | ||
|
||
with torch.no_grad(): | ||
for data in testloader: | ||
images, labels = data[0].to(DEVICE), data[1].to(DEVICE) | ||
outputs = net(images) | ||
_, predictions = torch.max(outputs, 1) | ||
|
||
for label, prediction in zip(labels, predictions): | ||
if label == prediction: | ||
correct_pred[CLASSES[label]] += 1 | ||
total_pred[CLASSES[label]] += 1 | ||
|
||
for classname, correct_count in correct_pred.items(): | ||
accuracy = 100 * float(correct_count) / total_pred[classname] | ||
print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %") | ||
|
||
|
||
def predict(net: Net, image: Image) -> str: | ||
with torch.no_grad(): | ||
image = transform(image).to(DEVICE).unsqueeze(0) | ||
outputs = net(image) | ||
_, predictions = torch.max(outputs, 1) | ||
return CLASSES[predictions[0]] | ||
|
||
|
||
def load_saved_model() -> Net: | ||
net = Net() | ||
net.load_state_dict(torch.load(PATH)) | ||
net.to(DEVICE) | ||
net.eval() | ||
|
||
return net | ||
|
||
|
||
if __name__ == "__main__": | ||
if len(sys.argv) != 2: | ||
print("Invalid arguments, must be one of 'train', 'test'") | ||
elif sys.argv[1] == "train": | ||
net = train() | ||
torch.save(net.state_dict(), PATH) | ||
print(f"Finished training model, saved to {PATH}") | ||
elif sys.argv[1] == "test": | ||
net = load_saved_model() | ||
test(net) | ||
else: | ||
print("Invalid argument, must be one of 'train', 'test'") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
flask | ||
torch | ||
torchvision | ||
Pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
<html> | ||
<head> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> | ||
<title>My Model</title> | ||
|
||
<script> | ||
async function submit() { | ||
// Get the HTML elements. | ||
let prediction = document.getElementById("prediction"); | ||
let file = document.getElementById("file"); | ||
|
||
// Ensure that there is a file uploaded, otherwise do nothing. | ||
if (file.files.length === 0) { | ||
return; | ||
} | ||
|
||
// Call the server's "/eval" endpoint with the file. | ||
let response = await fetch("/eval", { | ||
method: "POST", | ||
body: file.files[0], | ||
}); | ||
|
||
// Parse the JSON response. | ||
let json = await response.json(); | ||
|
||
// Set the prediction and clear the input. | ||
prediction.innerText = `Prediction: ${json.result}`; | ||
file.value = ""; | ||
} | ||
</script> | ||
</head> | ||
<body> | ||
<input id="file" type="file" /> | ||
<button onclick="submit();">Submit</button> | ||
<p id="prediction"></p> | ||
</body> | ||
</html> |
Binary file not shown.