#! /usr/bin/env python
"""Flask front end that serves digit predictions from the saved PyTorch model."""

import io
from flask import Flask, render_template, request
from model import load_saved_model, predict
from PIL import Image

app = Flask(__name__)
# Load the trained weights once at import time so every request reuses the
# same network instead of re-reading the weight file.
net = load_saved_model()


@app.route("/", methods=["GET"])
def index():
    """Render the ``index.html`` template."""
    return render_template("index.html")


# NOTE: the name shadows the builtin ``eval``; it is kept because Flask uses
# the function name as the endpoint name by default.
@app.route("/eval", methods=["POST"])
def eval():
    """Classify the image sent as the raw request body.

    Returns:
        ``{"result": <label>}`` on success, or ``({"error": <message>}, 400)``
        when the body cannot be decoded as an image.
    """
    try:
        # Convert to single-channel greyscale ("L") because the network's
        # first convolution takes one input channel.
        image = Image.open(io.BytesIO(request.data)).convert("L")
    except OSError:
        # Pillow signals unreadable/truncated data with OSError (which
        # UnidentifiedImageError subclasses); that is a client error, not a
        # server fault, so answer 400 instead of letting it become a 500.
        return {"error": "Request body is not a readable image"}, 400

    result = predict(net, image)

    # Dictionary return values are implicitly converted to JSON,
    # which the client can parse to get the result.
    return {"result": result}


if __name__ == "__main__":
    app.run(host="0.0.0.0")
#! /usr/bin/env python
#
# Heavily based off of
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

import sys
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from PIL import Image

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Weights are stored next to this file so loading does not depend on the
# current working directory.
PATH = pathlib.Path(__file__).parent.joinpath("weight.pt")
CLASSES = [str(x) for x in range(0, 10)]
BATCH_SIZE = 4

# Shared by training, testing and prediction so that inference inputs are
# preprocessed exactly like the training data.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


class Net(nn.Module):
    """Small convolutional classifier for single-channel digit images.

    Two 5x5 convolution + 2x2 max-pool stages feed three fully connected
    layers ending in 10 outputs, one per entry of ``CLASSES``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels of 4x4 features: the spatial size left after both
        # conv/pool stages for a 28x28 input (28 -> 24 -> 12 -> 8 -> 4).
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train(epochs: int = 2, log_interval: int = 2000) -> Net:
    """Train a fresh ``Net`` on the MNIST training split and return it.

    Args:
        epochs: Number of full passes over the training set.
        log_interval: Print the mean loss every this many mini-batches.

    Returns:
        The trained network, located on ``DEVICE``.
    """
    # Fixed seed so weight initialisation and shuffling are repeatable.
    torch.manual_seed(0)

    net = Net()
    net.train()
    net.to(DEVICE)

    trainset = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=0.001)

    for epoch in range(1, epochs + 1):

        running_loss = 0.0
        for i, data in enumerate(trainloader, 1):

            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if i % log_interval == 0:
                print(f"[{epoch}, {i:5d}] loss: {running_loss / log_interval:.3f}")
                running_loss = 0.0

    return net


def test(net: Net):
    """Print the per-class accuracy of ``net`` on the MNIST test split."""
    net.eval()

    testset = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4
    )

    correct_pred = {classname: 0 for classname in CLASSES}
    total_pred = {classname: 0 for classname in CLASSES}

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)

            # Convert to plain ints once per batch so the comparison and the
            # list indexing below do not go through tensor operations.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                if label == prediction:
                    correct_pred[CLASSES[label]] += 1
                total_pred[CLASSES[label]] += 1

    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")


def predict(net: Net, image: Image.Image) -> str:
    """Return the ``CLASSES`` label ``net`` assigns to a single image.

    Args:
        net: Network to run; it is used as-is (no mode or device change).
        image: Single-channel image whose size leaves a 4x4 feature map after
            the two conv/pool stages (e.g. a 28x28 digit).

    Returns:
        The label with the highest score.
    """
    with torch.no_grad():
        # Add a leading batch dimension: the network expects a batch.
        batch = transform(image).to(DEVICE).unsqueeze(0)
        outputs = net(batch)
        _, predictions = torch.max(outputs, 1)
    return CLASSES[predictions[0].item()]


def load_saved_model() -> Net:
    """Build a ``Net``, load the weights stored at ``PATH``, and return it in eval mode."""
    net = Net()
    # map_location lets weights saved from a CUDA run load on a CPU-only
    # host (and vice versa) instead of failing on the missing device.
    net.load_state_dict(torch.load(PATH, map_location=DEVICE))
    net.to(DEVICE)
    net.eval()

    return net


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Invalid arguments, must be one of 'train', 'test'")
        # Non-zero status so calling scripts can detect the usage error.
        sys.exit(1)
    elif sys.argv[1] == "train":
        net = train()
        torch.save(net.state_dict(), PATH)
        print(f"Finished training model, saved to {PATH}")
    elif sys.argv[1] == "test":
        net = load_saved_model()
        test(net)
    else:
        print("Invalid argument, must be one of 'train', 'test'")
        sys.exit(1)
b/templates/index.html @@ -0,0 +1,37 @@ + +
+ +