Skip to content

Commit

Permalink
🎨 Format Python code with psf/black
Browse files Browse the repository at this point in the history
  • Loading branch information
karkir0003 authored and github-actions committed Nov 20, 2023
1 parent 08ff344 commit a6448a8
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 16 deletions.
12 changes: 6 additions & 6 deletions training/training/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
self.train_transform = train_transform or transforms.Compose(
[transforms.ToTensor()]
)

self.test_transform = test_transform or transforms.Compose(
[transforms.ToTensor()]
)
Expand All @@ -139,10 +139,10 @@ def __init__(

# Ensure the directory exists
os.makedirs(self.dataset_dir, exist_ok=True)
print(f'train transform: {train_transform}')
print(f'test transform: {test_transform}')
print(f"train transform: {train_transform}")
print(f"test transform: {test_transform}")
# Load the datasets

self.train_set = datasets.__dict__[dataset_name](
root=self.dataset_dir,
train=True,
Expand Down Expand Up @@ -182,7 +182,7 @@ def createTrainDataset(self) -> DataLoader:
shuffle=self.shuffle,
drop_last=True,
)
self.delete_datasets_from_directory() # Delete datasets after loading
self.delete_datasets_from_directory() # Delete datasets after loading
return train_loader

def createTestDataset(self) -> DataLoader:
Expand All @@ -192,7 +192,7 @@ def createTestDataset(self) -> DataLoader:
shuffle=self.shuffle,
drop_last=True,
)
self.delete_datasets_from_directory() # Delete datasets after loading
self.delete_datasets_from_directory() # Delete datasets after loading
return test_loader

def getCategoryList(self) -> list[str]:
Expand Down
2 changes: 1 addition & 1 deletion training/training/core/dl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DLModel(nn.Module):
"CONV2D": nn.Conv2d,
"DROPOUT": nn.Dropout,
"MAXPOOL2D": nn.MaxPool2d,
"FLATTEN": nn.Flatten
"FLATTEN": nn.Flatten,
}

def __init__(self, layer_list: list[nn.Module]):
Expand Down
8 changes: 4 additions & 4 deletions training/training/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def _train_step(self, inputs: torch.Tensor, labels: torch.Tensor):
self.optimizer.zero_grad() # zero out gradient for each batch
self.model.forward(inputs) # make prediction on input
self._outputs: torch.Tensor = self.model(inputs) # make prediction on input
print('MODEL FORWARD PASS DONE!!!!')
print(f'output dim: {self._outputs.shape}')
print(f'label dim: {labels.shape}')
print(f'loss function used: {self.criterionHandler}')
print("MODEL FORWARD PASS DONE!!!!")
print(f"output dim: {self._outputs.shape}")
print(f"label dim: {labels.shape}")
print(f"loss function used: {self.criterionHandler}")
loss = self.criterionHandler.compute_loss(self._outputs, labels)
loss.backward() # backpropagation
self.optimizer.step() # adjust optimizer weights
Expand Down
6 changes: 2 additions & 4 deletions training/training/routes/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
@router.post("", auth=FirebaseAuth())
def imageTrain(request: HttpRequest, imageParams: ImageParams):
if imageParams.default:
dataCreator = ImageDefaultDatasetCreator.fromDefault(
imageParams.default
)
dataCreator = ImageDefaultDatasetCreator.fromDefault(imageParams.default)
print(vars(dataCreator))
train_loader = dataCreator.createTrainDataset()
test_loader = dataCreator.createTestDataset()
Expand All @@ -37,7 +35,7 @@ def imageTrain(request: HttpRequest, imageParams: ImageParams):
# )

model = DLModel.fromLayerParamsList(imageParams.user_arch)
print(f'model is: {model}')
print(f"model is: {model}")
optimizer = getOptimizer(model, imageParams.optimizer_name, 0.05)
criterionHandler = getCriterionHandler(imageParams.criterion)
if imageParams.problem_type == "CLASSIFICATION":
Expand Down
3 changes: 2 additions & 1 deletion training/training/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from training.routes.datasets.default.columns import router as default_dataset_router
from training.routes.tabular.tabular import router as tabular_router
from training.routes.image.image import router as image_router

# from training.routes.datasets.default import get_default_datasets_router
# from training.routes.tabular import get_tabular_router
# from training.routes.image import image_router
Expand All @@ -41,4 +42,4 @@ def test(request: HttpRequest):
urlpatterns = [
path("admin/", admin.site.urls),
path("api/", api.urls), # type: ignore
]
]

0 comments on commit a6448a8

Please sign in to comment.