Skip to content

Commit

Permalink
[UPDATED FEATURE-993] Migrate Image Training Endpoint to Django (#1053)
Browse files Browse the repository at this point in the history
* adding image training endpoint

* 🎨 Format Python code with psf/black

* importing image for dl_model.py

* fixing return typing and typos

* optimizing imports and adding cifar dataset

* deleting local downloaded data

* adding return typing

* 🎨 Format Python code with psf/black

* fixing comment

* adding dummy keyword arg

* 🎨 Format Python code with psf/black

* added logging

* testing endpoint

* adding maxpool2d and flatten

* updating criterion

* removing print statements

* fixing cifar typo

* fixing nits and feedback

* 🎨 Format Python code with psf/black

---------

Co-authored-by: codingwithsurya <[email protected]>
Co-authored-by: karkir0003 <[email protected]>
Co-authored-by: karkir0003 <[email protected]>
  • Loading branch information
4 people authored Nov 20, 2023
1 parent b5acba0 commit c9f07ec
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dlp-cli
2 changes: 1 addition & 1 deletion training/training/core/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CELossHandler(CriterionHandler):
def _compute_loss(self, output, labels):
    """Return mean cross-entropy loss for a classification output.

    Args:
        output: model output tensor; collapsed to (batch, num_classes)
            before the loss is applied.
        labels: target tensor of class indices; singleton dims are
            squeezed away and values cast to long.
    """
    # Keep only the batch and class axes. Use shape[-1] so the class
    # dimension is taken from the last axis (the diff shows shape[2] was
    # replaced by shape[-1] for exactly this reason).
    output = torch.reshape(
        output,
        (output.shape[0], output.shape[-1]),
    )
    # squeeze_ mutates labels in place so the target is the 1-D index
    # vector CrossEntropyLoss expects.
    labels = labels.squeeze_()
    return nn.CrossEntropyLoss(reduction="mean")(output, labels.long())
Expand Down
107 changes: 106 additions & 1 deletion training/training/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@
import logging
import os
import shutil
from enum import Enum
from typing import Callable, Optional

import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


class TrainTestDatasetCreator(ABC):
"Creator that creates train and test PyTorch datasets"
"""
Creator that creates train and test PyTorch datasets from a given dataset.
This class serves as an abstract base class for creating training and testing
datasets compatible with PyTorch's dataset structure. Implementations should
define specific methods for dataset processing and loading.
"""

@abstractmethod
def createTrainDataset(self) -> Dataset:
Expand Down Expand Up @@ -98,3 +109,97 @@ def getCategoryList(self) -> list[str]:
if self._category_list is None:
raise Exception("Category list not available")
return self._category_list


class DefaultImageDatasets(Enum):
    """Built-in image datasets available for default training runs.

    Each member's value is the exact class name inside
    ``torchvision.datasets`` that the dataset creator looks up dynamically.
    """

    MNIST = "MNIST"
    FASHION_MNIST = "FashionMNIST"
    KMNIST = "KMNIST"
    CIFAR10 = "CIFAR10"


class ImageDefaultDatasetCreator(TrainTestDatasetCreator):
    """Downloads a built-in torchvision dataset and serves train/test DataLoaders.

    The dataset is downloaded into ``self.dataset_dir`` at construction time.
    The on-disk copy is deleted again after a loader is created — the supported
    datasets appear to hold their samples in memory once constructed, so this
    looks safe, but confirm before adding a disk-backed dataset.
    """

    def __init__(
        self,
        dataset_name: str,
        train_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        batch_size: int = 32,
        shuffle: bool = True,
    ):
        """
        Args:
            dataset_name: member name of ``DefaultImageDatasets`` (e.g. "MNIST").
            train_transform: transform for training samples (default: ToTensor).
            test_transform: transform for test samples (default: ToTensor).
            batch_size: batch size used by both loaders.
            shuffle: whether both loaders reshuffle each epoch.

        Raises:
            ValueError: if ``dataset_name`` is not a supported default dataset.
        """
        if dataset_name not in DefaultImageDatasets.__members__:
            # ValueError (not bare Exception) so callers can handle bad input
            # specifically; anyone catching Exception still catches this.
            raise ValueError(
                f"The {dataset_name} file does not currently exist in our inventory. Please submit a request to the contributors of the repository"
            )

        self.dataset_dir = "./training/image_data_uploads"
        self.train_transform = train_transform or transforms.Compose(
            [transforms.ToTensor()]
        )
        self.test_transform = test_transform or transforms.Compose(
            [transforms.ToTensor()]
        )
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Ensure the download directory exists before torchvision writes to it.
        os.makedirs(self.dataset_dir, exist_ok=True)

        # Resolve through the enum *value*: member names and torchvision class
        # names differ for FASHION_MNIST -> "FashionMNIST", so indexing
        # datasets.__dict__ with the raw member name would raise KeyError.
        dataset_cls = getattr(datasets, DefaultImageDatasets[dataset_name].value)
        self.train_set = dataset_cls(
            root=self.dataset_dir,
            train=True,
            download=True,
            transform=self.train_transform,
        )
        self.test_set = dataset_cls(
            root=self.dataset_dir,
            train=False,
            download=True,
            transform=self.test_transform,
        )

    @classmethod
    def fromDefault(
        cls,
        dataset_name: str,
        train_transform=None,
        test_transform=None,
        batch_size: int = 32,
        shuffle: bool = True,
    ) -> "ImageDefaultDatasetCreator":
        """Alternate constructor; currently forwards all arguments to __init__."""
        return cls(dataset_name, train_transform, test_transform, batch_size, shuffle)

    def delete_datasets_from_directory(self) -> None:
        """Best-effort removal of the downloaded dataset directory."""
        if os.path.exists(self.dataset_dir):
            try:
                shutil.rmtree(self.dataset_dir)
                logging.getLogger(__name__).info(
                    "Successfully deleted %s", self.dataset_dir
                )
            except OSError as e:
                # Cleanup is best-effort: log and continue rather than failing
                # a training request over leftover files.
                logging.getLogger(__name__).warning(
                    "Failed to delete %s with error: %s", self.dataset_dir, e
                )

    def createTrainDataset(self) -> DataLoader:
        """Return a DataLoader over the training split, then delete raw files."""
        train_loader = DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=True,  # keeps batch shapes uniform for the trainer
        )
        self.delete_datasets_from_directory()  # raw download no longer needed
        return train_loader

    def createTestDataset(self) -> DataLoader:
        """Return a DataLoader over the test split, then delete raw files."""
        test_loader = DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=True,
        )
        self.delete_datasets_from_directory()
        return test_loader

    def getCategoryList(self) -> list[str]:
        """Class-name labels for the dataset, or [] if it exposes none."""
        return self.train_set.classes if hasattr(self.train_set, "classes") else []
5 changes: 5 additions & 0 deletions training/training/core/dl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn as nn

from training.routes.tabular.schemas import LayerParams
from training.routes.image.schemas import LayerParams


class DLModel(nn.Module):
Expand All @@ -13,6 +14,10 @@ class DLModel(nn.Module):
"SOFTMAX": nn.Softmax,
"SIGMOID": nn.Sigmoid,
"LOGSOFTMAX": nn.LogSoftmax,
"CONV2D": nn.Conv2d,
"DROPOUT": nn.Dropout,
"MAXPOOL2D": nn.MaxPool2d,
"FLATTEN": nn.Flatten,
}

def __init__(self, layer_list: list[nn.Module]):
Expand Down
Empty file.
42 changes: 42 additions & 0 deletions training/training/routes/image/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Literal, Optional
from django.http import HttpRequest
from ninja import Router, Schema
from training.core.criterion import getCriterionHandler
from training.core.dl_model import DLModel
from training.core.dataset import ImageDefaultDatasetCreator
from torch.utils.data import DataLoader
from training.core.optimizer import getOptimizer
from training.core.trainer import ClassificationTrainer
from training.routes.image.schemas import ImageParams
from training.core.authenticator import FirebaseAuth

router = Router()


@router.post("", auth=FirebaseAuth())
def imageTrain(request: HttpRequest, imageParams: ImageParams):
    """Train an image-classification model on a built-in default dataset.

    Returns:
        The trainer's AUC/ROC curve data after the final epoch.

    Raises:
        ValueError: if ``imageParams.default`` does not name a dataset.
    """
    # Fail fast: every statement below needs the dataset, so guard up front
    # instead of falling through to an unbound dataCreator reference.
    if not imageParams.default:
        raise ValueError("imageParams.default must name a built-in dataset")

    dataCreator = ImageDefaultDatasetCreator.fromDefault(imageParams.default)
    # NOTE(review): imageParams.shuffle / batch_size are not forwarded to
    # fromDefault here — confirm whether they should be.
    train_loader = dataCreator.createTrainDataset()
    test_loader = dataCreator.createTestDataset()
    model = DLModel.fromLayerParamsList(imageParams.user_arch)

    learning_rate = 0.05  # TODO(review): expose via ImageParams instead of hard-coding
    optimizer = getOptimizer(model, imageParams.optimizer_name, learning_rate)
    criterionHandler = getCriterionHandler(imageParams.criterion)

    # problem_type is Literal["CLASSIFICATION"], so this branch always runs
    # today; the guard remains for when more problem types are added.
    if imageParams.problem_type == "CLASSIFICATION":
        trainer = ClassificationTrainer(
            train_loader,
            test_loader,
            model,
            optimizer,
            criterionHandler,
            imageParams.epochs,
            dataCreator.getCategoryList(),
        )
        # Iterating the trainer drives the per-epoch training loop.
        for epoch_result in trainer:
            print(epoch_result)
        # Compute the curve once (the original printed it and then computed
        # it a second time for the return value).
        return trainer.generate_AUC_ROC_CURVE()
22 changes: 22 additions & 0 deletions training/training/routes/image/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Literal, Optional
from ninja import Schema


class LayerParams(Schema):
    """Request schema for one model layer: its type code and its arguments."""

    value: str  # layer key, e.g. "CONV2D" — matches DLModel's layer registry
    parameters: list[Any]  # presumably ctor args for the layer — confirm in DLModel.fromLayerParamsList


class ImageParams(Schema):
    """Request body for the image-training endpoint (routes/image/image.py)."""

    # target: str
    # features: list[str]
    name: str
    problem_type: Literal["CLASSIFICATION"]  # only classification is handled by imageTrain
    default: Optional[str]  # DefaultImageDatasets member name, e.g. "MNIST"
    criterion: str  # resolved via getCriterionHandler
    optimizer_name: str  # resolved via getOptimizer
    shuffle: bool  # NOTE(review): not forwarded by imageTrain as far as visible — confirm
    epochs: int
    test_size: float  # NOTE(review): unused by imageTrain as far as visible — confirm
    batch_size: int  # NOTE(review): not forwarded by imageTrain as far as visible — confirm
    user_arch: list[LayerParams]  # ordered layer specs passed to DLModel.fromLayerParamsList
4 changes: 3 additions & 1 deletion training/training/urls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
URL configuration for training project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Expand All @@ -19,8 +18,10 @@
from django.http import HttpRequest
from django.urls import path
from ninja import NinjaAPI, Schema

from training.routes.datasets.default.columns import router as default_dataset_router
from training.routes.tabular.tabular import router as tabular_router
from training.routes.image.image import router as image_router

api = NinjaAPI()

Expand All @@ -32,6 +33,7 @@ def test(request: HttpRequest):

api.add_router("/datasets/default/", default_dataset_router)
api.add_router("/tabular", tabular_router)
api.add_router("/image", image_router)

urlpatterns = [
path("admin/", admin.site.urls),
Expand Down

0 comments on commit c9f07ec

Please sign in to comment.