-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from beeb89gang/feat/danknessificator
Feat/danknessificator
- Loading branch information
Showing
13 changed files
with
312 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,6 @@ | |
*.png | ||
*.gif | ||
.env* | ||
*__pycache__* | ||
*__pycache__* | ||
mACHINE-LEARNINGS/datasets/*.csv | ||
mACHINE-LEARNINGS/models/trained/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
stupidvenv/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
import glob | ||
from torchvision.transforms import Resize | ||
from torchvision.io import read_image, ImageReadMode | ||
from torch.utils.data import Dataset | ||
import torch | ||
import imghdr | ||
|
||
supported_ext = ['jpg'] | ||
|
||
class MemeDataset(Dataset): | ||
def __init__(self, data_dir, transform=None): | ||
self.transform = transform | ||
self.img_path = [] | ||
for e in supported_ext: | ||
path = glob.glob(os.path.join(data_dir, '*', f'*.{e}')) | ||
for p in path: | ||
if imghdr.what(p) == 'jpeg': | ||
self.img_path.append(p) | ||
classes = set() | ||
for path in self.img_path: | ||
classes.add(os.path.basename(os.path.dirname(path))) | ||
self.labels = {cls: i for i, cls in enumerate(sorted(list(classes)))} | ||
|
||
def __len__(self): | ||
return len(self.img_path) | ||
|
||
def __getitem__(self, idx): | ||
img = read_image(self.img_path[idx], ImageReadMode.RGB).float() | ||
cls = os.path.basename(os.path.dirname(self.img_path[idx])) | ||
label = self.labels[cls] | ||
|
||
if self.transform: | ||
return self.transform['train'](img), torch.tensor(label) | ||
|
||
return img, torch.tensor(label) |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import warnings; warnings.filterwarnings('ignore') | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from torchvision.io import read_image, ImageReadMode | ||
from classes.MemeDataset import MemeDataset | ||
from models.DankCNN import DankCNN | ||
import argparse | ||
from torchvision import transforms | ||
|
||
cls = {0: 'dank', 1: 'normie'} | ||
|
||
def eval(image): | ||
""" | ||
SETUP | ||
""" | ||
if torch.cuda.is_available(): | ||
device = torch.device('cuda') | ||
else: | ||
device = torch.device('cpu') | ||
|
||
""" | ||
DATA LOADING | ||
- Load all data: train, test, validation | ||
""" | ||
image = read_image(image, ImageReadMode.RGB).float() | ||
data_transforms = { | ||
'test': transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.Resize(320), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
} | ||
image = data_transforms['test'](image) | ||
""" | ||
MODEL INITIALIZATION | ||
- optimizer: Adam with weight decay as regularization technique | ||
- loss function: binary cross entropy loss | ||
""" | ||
model = DankCNN() | ||
model.load_state_dict(torch.load('models/trained/model.pt', map_location=torch.device('cpu'))) | ||
model = model.to(device) | ||
|
||
model.eval() | ||
with torch.no_grad(): | ||
image = image.unsqueeze(0) | ||
output = model(image.to(device)) | ||
prediction = output.data | ||
print('This meme is', cls[0 if prediction.item() < .5 else 1]) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-i', '--image', nargs=1, help='Image to evaluate', required=True) | ||
args = parser.parse_args() | ||
eval( | ||
image=args.image[0] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torchvision import models | ||
|
||
class DankCNN(nn.Module): | ||
|
||
def __init__(self, dropout=False): | ||
super(DankCNN, self).__init__() | ||
# self.pretrained_model = models.resnet18(weights='IMAGENET1K_V1') | ||
self.pretrained_model = models.resnet50(pretrained=True) | ||
self.n_pretrained_features = self.pretrained_model.fc.in_features | ||
self.pretrained_model.fc = nn.Linear(self.pretrained_model.fc.in_features, 1) | ||
|
||
def forward(self, x): | ||
x = self.pretrained_model(x) | ||
return F.sigmoid(x) |
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
filelock==3.13.4 | ||
fsspec==2024.3.1 | ||
Jinja2==3.1.3 | ||
joblib==1.4.0 | ||
MarkupSafe==2.1.5 | ||
mpmath==1.3.0 | ||
networkx==3.3 | ||
numpy==1.26.4 | ||
pandas==2.2.2 | ||
pillow==10.3.0 | ||
pyarrow==15.0.2 | ||
python-dateutil==2.9.0.post0 | ||
pytz==2024.1 | ||
scikit-learn==1.4.2 | ||
scipy==1.13.0 | ||
six==1.16.0 | ||
sympy==1.12 | ||
threadpoolctl==3.4.0 | ||
torch==2.2.2+cpu | ||
torchaudio==2.2.2+cpu | ||
torchvision==0.17.2+cpu | ||
tqdm==4.66.2 | ||
typing_extensions==4.11.0 | ||
tzdata==2024.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import warnings; warnings.filterwarnings('ignore') | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from classes.MemeDataset import MemeDataset | ||
from utils.performance_measure import precision_recall_f1 | ||
from models.DankCNN import DankCNN | ||
import argparse | ||
from torchvision import transforms | ||
|
||
def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay): | ||
""" | ||
HYPERPARAMETERS AND CONSTANTS | ||
- BATCH_SIZE_TRAIN: size of the batches for training phase | ||
- LR: learning rate | ||
- N_EPOCHS: number of epochs to execute | ||
- IS_VERBOSE: to avoid too much output | ||
- WEIGHT_DECAY: the weight decay for the regularization in Adam optimizer | ||
""" | ||
TRAIN_SET_SIZE = train_set_size | ||
TEST_SET_SIZE = 1 - train_set_size | ||
BATCH_SIZE = batch_size | ||
LR = lr | ||
N_EPOCHS = epochs | ||
IS_VERBOSE = is_verbose | ||
WEIGHT_DECAY = weight_decay | ||
""" | ||
SETUP | ||
""" | ||
if torch.cuda.is_available(): | ||
device = torch.device('cuda') | ||
else: | ||
device = torch.device('cpu') | ||
|
||
""" | ||
DATA LOADING | ||
- Load all data: train, test, validation | ||
""" | ||
data_transforms = { | ||
'train': transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.Resize(320), | ||
transforms.RandomResizedCrop(224), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
} | ||
dataset = MemeDataset('data/', transform=data_transforms) | ||
train_size = int(TRAIN_SET_SIZE*len(dataset)) | ||
test_size = len(dataset) - train_size | ||
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size]) | ||
dataloaders = { | ||
"train": DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True), | ||
"test": DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False) | ||
} | ||
|
||
""" | ||
MODEL INITIALIZATION | ||
- optimizer: Adam with weight decay as regularization technique | ||
- loss function: binary cross entropy loss | ||
""" | ||
model = DankCNN() | ||
model = model.to(device) | ||
# optimizer = torch.optim.RMSprop(model.parameters(), lr=.001, alpha=.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) | ||
# loss_function = torch.nn.HingeEmbeddingLoss() | ||
loss_function = torch.nn.BCELoss() | ||
|
||
""" | ||
TRAINING PHASE | ||
""" | ||
for epoch in range(N_EPOCHS): | ||
train_loss = 0 | ||
acc = 0 | ||
|
||
for batch_num, (image, label) in enumerate(dataloaders["train"]): | ||
optimizer.zero_grad() | ||
output = model(image.to(device)) | ||
|
||
label = label.unsqueeze(1).to(device) | ||
|
||
loss = loss_function(output, label.float()) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
train_loss += loss.item() | ||
|
||
predictions = torch.where(output > .5, 1, 0).to(device) | ||
acc += (label == predictions).sum()/len(label) | ||
|
||
if IS_VERBOSE: | ||
print('Training: Epoch %d - Batch %d/%d 🌭 Loss: %.4f' % | ||
(epoch+1, batch_num+1, len(dataloaders["train"]), train_loss / (batch_num + 1))) | ||
|
||
n_batch = batch_num | ||
|
||
print('EPOCH', epoch+1, 'ACCURACY:', (acc.item() / (n_batch+1))) | ||
|
||
""" | ||
TEST PHASE | ||
""" | ||
test_loss = 0 | ||
acc = 0 | ||
model.eval() | ||
with torch.no_grad(): | ||
for batch_num, (image, label) in enumerate(dataloaders["test"]): | ||
output = model(image.to(device)) | ||
|
||
label = label.unsqueeze(1).to(device) | ||
loss = loss_function(output, label.float()) | ||
|
||
test_loss += loss.item() | ||
|
||
predictions = torch.where(output > .5, 1, 0).to(device) | ||
acc += (label == predictions).sum()/len(label) | ||
|
||
if IS_VERBOSE: | ||
print('Evaluating: Batch %d/%d: Loss: %.4f' % | ||
(batch_num, len(dataloaders["test"]), test_loss / (batch_num + 1))) | ||
|
||
n_batch = batch_num | ||
|
||
print('TEST ACCURACY:', (acc.item() / (n_batch+1))) | ||
|
||
print('Saving model...') | ||
torch.save(model.state_dict(), 'models/trained/model.pt') | ||
print('Saved 🌭') | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-ts', '--trainsetsize', nargs=1, type=float, choices=[x/10 for x in range(0, 10)], help='Train set split size', required=True) | ||
parser.add_argument('-bs', '--batchsize', nargs=1, type=int, help='Size of the training batch', required=False) | ||
parser.add_argument('-lr', '--learningrate', nargs=1, type=float, help='Learning rate', required=False) | ||
parser.add_argument('-e', '--epochs', nargs=1, type=int, help='Number of epochs', required=False) | ||
parser.add_argument('-v', '--verbose', nargs=1, type=bool, help='Verbose mode on/off', required=False) | ||
parser.add_argument('-wd', '--weightdecay', nargs=1, type=float, help='Weight decay (L2 regularization)', required=False) | ||
args = parser.parse_args() | ||
execute( | ||
train_set_size=args.trainsetsize[0] if args.trainsetsize[0]!=0 else 0.8, | ||
batch_size=args.batchsize[0] if args.batchsize else 1000, | ||
lr=args.learningrate[0] if args.learningrate else 0.5, | ||
epochs=args.epochs[0] if args.epochs else 10, | ||
is_verbose=args.verbose if args.verbose else True, | ||
weight_decay=args.weightdecay[0] if args.weightdecay else 0.9 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from threading import local | ||
import numpy as np | ||
|
||
""" | ||
precision_recall_f1(predictions: tensor, target: tensor) -> prec: float, rec: float, f1: float | ||
Input: | ||
- predictions: predicted labels | ||
- target: gold data | ||
Output: | ||
- prec: average precision over the classes | ||
- rec: average recall over the classes | ||
- f1: average f1 over the classes | ||
""" | ||
def precision_recall_f1(predictions, target): | ||
predictions = predictions.numpy().tolist() | ||
predictions = [tuple(x) for x in predictions] | ||
target = target.numpy().tolist() | ||
target = [tuple(x) for x in target] | ||
|
||
fp = set(predictions) - set(target) | ||
fn = set(target) - set(predictions) | ||
tp = set(predictions) - set(fp) | ||
|
||
prec = len(tp) / (len(tp) + len(fp)) if (len(tp)+len(fp)) > 0 else 0 | ||
rec = len(tp) / (len(tp) + len(fn)) if (len(tp)+len(fn)) > 0 else 0 | ||
f1 = (2*prec*rec)/(prec+rec) if (prec+rec) > 0 else 0 | ||
|
||
return prec, rec, f1 |