-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5127caa
commit 65b691b
Showing
118 changed files
with
387 additions
and
10,347 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,55 @@ | ||
## @file cnn.py | ||
# @author Wilson Castello Branco Neto (<mailto:[email protected]>) / Robson Costa (<mailto:[email protected]>) | ||
# @brief CNN class. | ||
# @version 0.1.0 | ||
# @since 06/12/2024 | ||
# @date 09/12/2024 | ||
# @copyright Copyright © since 2024 <a href="https://agrotechlab.lages.ifsc.edu.br" target="_blank">AgroTechLab</a>.\n | ||
# ![LICENSE license](../figs/license.png)<br> | ||
# Licensed under the CC BY-NC-SA (<i>Creative Commons Attribution-NonCommercial-ShareAlike</i>) 4.0 International Unported License (the <em>"License"</em>). You may not | ||
# use this file except in compliance with the License. You may obtain a copy of the License <a href="https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode" target="_blank">here</a>. | ||
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an <em>"as is" basis, | ||
# without warranties or conditions of any kind</em>, either express or implied. See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
import time | ||
import numpy as np | ||
import torch | ||
from torch import nn, optim | ||
from torch.utils import data | ||
from torchvision import models | ||
|
||
## CNN class. | ||
# @brief Train CCN models. | ||
class CNN: | ||
## @fn __init__ | ||
# @brief The CNN class initializer | ||
# @param train_data Training data | ||
# @param validation_data Validation data | ||
# @param test_data Test data | ||
# @param batch_size Batch size | ||
"""CNN Trainer class. | ||
This class is responsible for training a CNN model. | ||
Parameters: | ||
train_data (torchvision.datasets.ImageFolder): Training data. | ||
validation_data (torchvision.datasets.ImageFolder): Validation data. | ||
test_data (torchvision.datasets.ImageFolder): Test data. | ||
batch_size (int): Batch size. | ||
""" | ||
|
||
def __init__(self, train_data, validation_data, test_data, batch_size): | ||
"""The CNN Trainer class constructor.""" | ||
|
||
## Train data loader | ||
# Train data loader | ||
self.train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True) | ||
|
||
## Validation data loader | ||
# Validation data loader | ||
self.validation_loader = data.DataLoader(validation_data, batch_size=batch_size, shuffle=False) | ||
|
||
## Test data loader | ||
# Test data loader | ||
self.test_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=False) | ||
|
||
## Trainer device type | ||
# Trainer device type | ||
self.device = torch.device("cpu") | ||
|
||
## @fn create_and_train_cnn | ||
# @brief Create and train a CNN model | ||
# @param model_name Model name | ||
# @param num_epochs Number of epochs | ||
# @param learning_rate Learning rate | ||
# @param weight_decay Weight decay | ||
# @param replications Number of replications | ||
# @return Result name, average accuracy, maximum accuracy, iteration of maximum accuracy, duration | ||
|
||
def create_and_train_cnn(self, model_name, num_epochs, learning_rate, weight_decay, replications): | ||
"""Create and train a CNN model. | ||
Parameters: | ||
model_name (str): Model name to be trained. | ||
num_epochs (int): Number of epochs to be trained. | ||
learning_rate (float): Learning rate to be used at train. | ||
weight_decay (float): Weight decay to be used at train. | ||
replications (int): Number of replications used at each trained model. | ||
Returns: | ||
(dict): A dict mapping keys to the: | ||
* 'result_name': (str) Result name. | ||
* 'acc_avg': (float) Average accuracy. | ||
* 'iter_acc_max': (int) Iteration of maximum accuracy. | ||
* 'duration': (float) Duration of training. | ||
""" | ||
begin = time.time() | ||
sum = 0 | ||
acc_max = 0 | ||
|
@@ -69,11 +69,19 @@ def create_and_train_cnn(self, model_name, num_epochs, learning_rate, weight_dec | |
result_name = f"{model_name}-{num_epochs}-{learning_rate}-{weight_decay}" | ||
return result_name, acc_avg, iter_acc_max, duration | ||
|
||
## @fn create_model | ||
# @brief Create a model | ||
# @param model_name Model name | ||
# @return Model | ||
|
||
def create_model(self, model_name): | ||
"""Create a function to a CNN model to be trained. | ||
Note: | ||
At moment, the models available are: [VGG11, Alexnet, MobilenetV3Large]. | ||
Parameters: | ||
model_name (str): CNN model name. | ||
Returns: | ||
(function): Function to CNN model selected. | ||
""" | ||
if (model_name=='VGG11'): | ||
model = models.vgg11(weights='DEFAULT') | ||
for param in model.parameters(): | ||
|
@@ -91,41 +99,60 @@ def create_model(self, model_name): | |
for param in model.parameters(): | ||
param.requires_grad = False | ||
model.classifier[3] = nn.Linear(model.classifier[3].in_features,2) | ||
return model | ||
return model | ||
|
||
## @fn create_optimizer | ||
# @brief Create an optimizer | ||
# @param model Model | ||
# @param learning_rate Learning rate | ||
# @param weight_decay Weight decay | ||
# @return Optimizer | ||
|
||
def create_optimizer(self, model, learning_rate, weight_decay): | ||
"""Create an optimizer. | ||
Parameters: | ||
model (function): CNN function. | ||
learning_rate (float): Learning rate | ||
weight_decay (float): Weight decay | ||
Returns: | ||
(object): Optimizer object. | ||
""" | ||
update = [] | ||
for name,param in model.named_parameters(): | ||
if param.requires_grad == True: | ||
update.append(param) | ||
optimizerSGD = optim.SGD(update, lr=learning_rate, weight_decay=weight_decay) | ||
return optimizerSGD | ||
|
||
## @fn create_criterion | ||
# @brief Create a loss criterion | ||
# @return criterionCEL | ||
|
||
def create_criterion(self): | ||
"""Create a loss criterion. | ||
Parameters: | ||
None | ||
Returns: | ||
(object): Cross entropy loss object. | ||
""" | ||
criterionCEL = nn.CrossEntropyLoss() | ||
return criterionCEL | ||
|
||
## @fn train_model | ||
# @brief Train a model | ||
# @param model Model | ||
# @param train_loader Training data loader | ||
# @param optimizer Optimizer | ||
# @param criterion Loss criterion | ||
# @param model_name Model name | ||
# @param num_epochs Number of epochs | ||
# @param learning_rate Learning rate | ||
# @param weight_decay Weight decay | ||
# @param replication Replication | ||
def train_model(self, model, train_loader, optimizer, criterion, model_name, num_epochs, learning_rate, weight_decay, replication): | ||
|
||
def train_model(self, model, train_loader, optimizer, criterion, model_name, num_epochs, learning_rate, weight_decay, replication): | ||
"""Train a CNN model. | ||
Train a CNN model and save it (PTH file) at 'models' directory. | ||
Parameters: | ||
model (function): Model function. | ||
train_loader (DataLoader): Training data loader | ||
optimizer (object): Optimizer object. | ||
criterion (object): CEL object. | ||
model_name (str): Model name. | ||
num_epochs (int): Number of epochs. | ||
learning_rate (float): Learning rate. | ||
weight_decay (float): Weight decay. | ||
replication (int): Replication. | ||
Returns: | ||
None | ||
""" | ||
model.to(self.device) | ||
min_loss = 100 | ||
e_measures = [] | ||
|
@@ -136,14 +163,19 @@ def train_model(self, model, train_loader, optimizer, criterion, model_name, num | |
nome_arquivo = f"./models/{model_name}_{num_epochs}_{learning_rate}_{weight_decay}_{replication}.pth" | ||
torch.save(model.state_dict(), nome_arquivo) | ||
|
||
## @fn train_epoch | ||
# @brief Train an epoch | ||
# @param model Model | ||
# @param trainLoader Training data loader | ||
# @param optimizer Optimizer | ||
# @param criterion Loss criterion | ||
# @return Mean of losses | ||
|
||
def train_epoch(self, model, trainLoader, optimizer, criterion): | ||
"""Train an epoch. | ||
Parameters: | ||
model (function): Model function. | ||
trainLoader (DataLoader): Training data loader. | ||
optimizer (object): Optimizer object. | ||
criterion (object): CEL object. | ||
Returns: | ||
(float): Mean of losses. | ||
""" | ||
model.train() | ||
losses = [] | ||
for X, y in trainLoader: | ||
|
@@ -164,6 +196,15 @@ def train_epoch(self, model, trainLoader, optimizer, criterion): | |
# @param loader Data loader | ||
# @return Accuracy | ||
def evaluate_model(self, model, loader): | ||
"""Evaluate a model. | ||
Parameters: | ||
model (function): Model function. | ||
loader (DataLoader): Data loader | ||
Returns: | ||
(float): Model (trained) accuracy. | ||
""" | ||
total = 0 | ||
correct = 0 | ||
for X, y in loader: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
cnn_local: | ||
local: | ||
cpu_used: 4 # Number of CPUs to use | ||
train_path: "./data/summarized/train/" # Path to training data | ||
test_path: "./data/summarized/test/" # Path to testing data | ||
val_path: "./data/summarized/validation/" # Path to validation data | ||
models_path: "./models/" # Path to save models | ||
transforms_height: 224 # Height of the image | ||
transforms_width: 224 # Width of the image | ||
replications: 5 # Number of models replications | ||
replications: 2 # Number of models replications | ||
batch_size: 8 # Number of samples per batch to be loaded | ||
model_names: ["Alexnet"] # Model names supported (Alexnet, MobilenetV3Large, VGG11) | ||
epochs: [5, 10, 20] # Number of epochs | ||
learning_rates: [0.001, 0.0001, 0.00001] # Learning rates | ||
epochs: [5] # Number of epochs | ||
learning_rates: [0.001, 0.0001] # Learning rates | ||
weight_decays: [0, 0.001] # Weight decays |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,29 @@ | ||
## @file cnn_local.py | ||
# @author Wilson Castello Branco Neto (<mailto:[email protected]>) / Robson Costa (<mailto:[email protected]>) | ||
# @brief CNN Trainer. | ||
# @version 0.1.0 | ||
# @since 06/12/2024 | ||
# @date 09/12/2024 | ||
# @copyright Copyright © since 2024 <a href="https://agrotechlab.lages.ifsc.edu.br" target="_blank">AgroTechLab</a>.\n | ||
# ![LICENSE license](../figs/license.png)<br> | ||
# Licensed under the CC BY-NC-SA (<i>Creative Commons Attribution-NonCommercial-ShareAlike</i>) 4.0 International Unported License (the <em>"License"</em>). You may not | ||
# use this file except in compliance with the License. You may obtain a copy of the License <a href="https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode" target="_blank">here</a>. | ||
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an <em>"as is" basis, | ||
# without warranties or conditions of any kind</em>, either express or implied. See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
import os | ||
import logging | ||
import multiprocessing | ||
import time | ||
from concurrent.futures import ProcessPoolExecutor, as_completed | ||
import config | ||
import config as config | ||
from cnn import CNN | ||
import torch | ||
from torchvision import datasets | ||
from torchvision.transforms import v2 | ||
|
||
## Configuration file name | ||
CONFIG_FILE = 'cnn_local.yml' | ||
CONFIG_FILE = 'cnn_trainer.yml' | ||
LOG_FILE = 'cnn_trainer.log' | ||
|
||
## @fn define_transforms | ||
# @brief Define transforms for the images | ||
# @param height Height | ||
# @param width Width | ||
# @return Data transforms | ||
def define_transforms(height, width): | ||
""" | ||
Define transforms for the images. | ||
Parameters: | ||
height (int): Images height. | ||
width (int): Images width. | ||
Returns: | ||
(dict): Data transforms. | ||
""" | ||
logging.info("Defining transforms") | ||
print("\tDefining transforms...", flush=True) | ||
data_transforms = { | ||
|
@@ -51,14 +45,22 @@ def define_transforms(height, width): | |
} | ||
return data_transforms | ||
|
||
## @fn read_images | ||
# @brief Read images | ||
# @param data_transforms Data transforms | ||
# @param train_path Train path | ||
# @param val_path Validation path | ||
# @param test_path Test path | ||
# @return Train data, validation data and test data | ||
def read_images(data_transforms, train_path, val_path, test_path): | ||
""" | ||
Read images (train, validation and test) from their respective directories. | ||
Parameters: | ||
data_transforms (dict): Tranforms to be applied to the images. | ||
train_path (str): Path to train images directory. | ||
val_path (str): Path to validation images directory. | ||
test_path (str): Path to test images directory. | ||
Returns: | ||
(dict): A dict mapping keys to the: | ||
* 'train_data': (datasets.ImageFolder) Train data. | ||
* 'validation_data': (datasets.ImageFolder) Validation data. | ||
* 'test_data': (datasets.ImageFolder) Test data. | ||
""" | ||
logging.info("Reading images") | ||
print("\tReading images...", flush=True) | ||
|
||
|
@@ -73,9 +75,8 @@ def read_images(data_transforms, train_path, val_path, test_path): | |
|
||
return train_data, validation_data, test_data | ||
|
||
## @fn main | ||
# @brief Main function | ||
def main(): | ||
"""Main function.""" | ||
logging.info("Starting CNN Trainer (Local)") | ||
print("Starting CNN Trainer (Local)", flush=True) | ||
|
||
|
@@ -87,7 +88,7 @@ def main(): | |
# Read configuration file | ||
logging.info(f"Reading configuration file {CONFIG_FILE}") | ||
print("\tReading configuration file... ", end='', flush=True) | ||
cfgObj = config.Config(CONFIG_FILE) | ||
cfgObj = config.Config(cfgFile=CONFIG_FILE) | ||
print("[OK]", flush=True) | ||
|
||
# Validate CPU used | ||
|
@@ -130,14 +131,10 @@ def main(): | |
logging.info(f"Total duration: {totalDuration:.2f} seconds") | ||
print(f"\tTotal duration: {totalDuration:.2f} seconds", flush=True) | ||
|
||
## Main function | ||
if __name__ == '__main__': | ||
## LOG file name | ||
filename = "cnn_local.log" | ||
## LOG format | ||
format = '%(asctime)s %(levelname)s - %(message)s' | ||
## LOG level | ||
level = logging.INFO | ||
|
||
logging.basicConfig(filename=filename, format=format, level=level) | ||
"""Main function (entry point).""" | ||
# Setup logging | ||
logging.basicConfig(filename=LOG_FILE, format='%(asctime)s %(levelname)s - %(message)s', encoding='utf-8', level=logging.INFO) | ||
|
||
# Call main function | ||
main() |
Oops, something went wrong.