concept drifting katchow
enricobu96 committed Apr 28, 2024
1 parent 83c7a18 commit 5937256
Showing 6 changed files with 102 additions and 27 deletions.
32 changes: 30 additions & 2 deletions mACHINE-LEARNINGS/classes/MemeDataset.py
@@ -5,18 +5,46 @@
from torch.utils.data import Dataset
import torch
import imghdr
from PIL import Image
import imagehash
import pickle
import random

supported_ext = ['jpg']

class MemeDataset(Dataset):
def __init__(self, data_dir, transform=None):
def __init__(self, data_dir, retrain=False, grace=0, transform=None):
self.transform = transform
self.img_path = []

seen_images = []
if retrain:
with open(data_dir+'already_trained.pkl', 'rb') as f:
seen_images = pickle.load(f)

# cyclomatic complexity goes brrr but whatevs
for e in supported_ext:
path = glob.glob(os.path.join(data_dir, '*', f'*.{e}'))
for p in path:
if imghdr.what(p) == 'jpeg':
self.img_path.append(p)
if not retrain:
self.img_path.append(p)
seen_images.append(imagehash.average_hash(Image.open(p)))
else:
h_img = imagehash.average_hash(Image.open(p))
similar = False
for si in seen_images:
if abs(h_img - si) < 30: # small hamming distance = already seen; 30 is a totally random threshold
similar = True
break

if not similar or random.uniform(0, 1) > 1-grace: # let probability do its job
self.img_path.append(p)
seen_images.append(h_img)

with open(data_dir+'already_trained.pkl', 'wb') as f:
pickle.dump(seen_images, f)

classes = set()
for path in self.img_path:
classes.add(os.path.basename(os.path.dirname(path)))
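A minimal, hedged sketch (not part of the commit; file names are placeholders) of the dedup idea behind the retrain branch above: imagehash.average_hash reduces an image to a 64-bit perceptual hash, and subtracting two hashes gives their Hamming distance, so a small distance flags a near-duplicate of an already-seen meme.

from PIL import Image
import imagehash

# hashes persisted from a previous run (placeholder paths)
seen = [imagehash.average_hash(Image.open('data/dank/old_meme.jpg'))]

candidate = imagehash.average_hash(Image.open('data/dank/new_meme.jpg'))
# a small Hamming distance means the candidate is a near-duplicate of a seen meme
already_seen = any(abs(candidate - h) < 10 for h in seen)
print('skip' if already_seen else 'keep for retraining')
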
57 changes: 57 additions & 0 deletions mACHINE-LEARNINGS/danknessificator.py
@@ -0,0 +1,57 @@
from train import train
from eval import eval
import argparse


def main(args):
if args.subcommand == 'train':
train(
train_set_size=args.trainsetsize[0] if args.trainsetsize[0]!=0 else 0.8,
batch_size=args.batchsize[0] if args.batchsize else 1000,
lr=args.learningrate[0] if args.learningrate else 0.0001,
epochs=args.epochs[0] if args.epochs else 10,
is_verbose=args.verbose[0] if args.verbose else True,
weight_decay=args.weightdecay[0] if args.weightdecay else 0.9,
retrain=False
)
elif args.subcommand == 'retrain':
train(
train_set_size=0.8,
batch_size=10,
lr=0.0001,
epochs=1,
is_verbose=args.verbose[0] if args.verbose else True,
weight_decay=0.8,
retrain=True,
grace=args.grace[0] if args.grace else 0
)
elif args.subcommand == 'eval':
eval(
image=args.image[0]
)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='subcommand')

parser_train = subparsers.add_parser('train')
parser_retrain = subparsers.add_parser('retrain')
parser_eval = subparsers.add_parser('eval')

parser_train.add_argument('-ts', '--trainsetsize', nargs=1, type=float, choices=[x/10 for x in range(0, 10)], help='Train set split size', required=True)
parser_train.add_argument('-bs', '--batchsize', nargs=1, type=int, help='Size of the training batch', required=False)
parser_train.add_argument('-lr', '--learningrate', nargs=1, type=float, help='Learning rate', required=False)
parser_train.add_argument('-e', '--epochs', nargs=1, type=int, help='Number of epochs', required=False)
parser_train.add_argument('-v', '--verbose', nargs=1, type=bool, help='Verbose mode on/off', required=False)
parser_train.add_argument('-wd', '--weightdecay', nargs=1, type=float, help='Weight decay (L2 regularization)', required=False)

parser_retrain.add_argument('-g', '--grace', nargs=1, type=float, choices=[x/10 for x in range(0, 11)], help='Proportion of old memes to consider again', required=False)
parser_retrain.add_argument('-v', '--verbose', nargs=1, type=bool, help='Verbose mode on/off', required=False)

parser_eval.add_argument('-i', '--image', nargs=1, help='Image to evaluate', required=True)

args = parser.parse_args()

main(args)
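
A hedged usage sketch for the new CLI (not part of the commit; the flag values and the image path are placeholders, only the subcommands and flags come from the parsers above):

import subprocess

# full training run: 80/20 split, batches of 64, 5 epochs
subprocess.run(['python', 'danknessificator.py', 'train', '-ts', '0.8', '-bs', '64', '-e', '5'])

# incremental retrain that re-admits roughly 20% of already-seen memes
subprocess.run(['python', 'danknessificator.py', 'retrain', '-g', '0.2'])

# classify a single image
subprocess.run(['python', 'danknessificator.py', 'eval', '-i', 'data/dank/some_meme.jpg'])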


Binary file added mACHINE-LEARNINGS/data/already_trained.pkl
Binary file not shown.
1 change: 0 additions & 1 deletion mACHINE-LEARNINGS/eval.py
@@ -4,7 +4,6 @@
from torchvision.io import read_image, ImageReadMode
from classes.MemeDataset import MemeDataset
from models.DankCNN import DankCNN
import argparse
from torchvision import transforms

cls = {0: 'dank', 1: 'normie'}
2 changes: 2 additions & 0 deletions mACHINE-LEARNINGS/requirements.txt
@@ -1,5 +1,6 @@
filelock==3.13.4
fsspec==2024.3.1
ImageHash==4.3.1
Jinja2==3.1.3
joblib==1.4.0
MarkupSafe==2.1.5
@@ -11,6 +12,7 @@ pillow==10.3.0
pyarrow==15.0.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyWavelets==1.6.0
scikit-learn==1.4.2
scipy==1.13.0
six==1.16.0
37 changes: 13 additions & 24 deletions mACHINE-LEARNINGS/train.py
@@ -4,10 +4,9 @@
from classes.MemeDataset import MemeDataset
from utils.performance_measure import precision_recall_f1
from models.DankCNN import DankCNN
import argparse
from torchvision import transforms

def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
def train(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay, retrain=False, grace=1):
"""
HYPERPARAMETERS AND CONSTANTS
- BATCH_SIZE_TRAIN: size of the batches for training phase
@@ -45,8 +44,11 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
dataset = MemeDataset('data/', transform=data_transforms)
dataset = MemeDataset('data/', retrain=retrain, grace=grace, transform=data_transforms)
train_size = int(TRAIN_SET_SIZE*len(dataset))
if retrain and len(dataset)==0:
print('all maymays already seen, baby')
return 69
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])
dataloaders = {
@@ -60,10 +62,15 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
- loss function: binary cross entropy loss
"""
model = DankCNN()
if retrain:
try:
model.load_state_dict(torch.load('models/trained/model.pt', map_location=torch.device(device)))
except FileNotFoundError:
print('ciccio there\'s no model to retrain')
return 69

model = model.to(device)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=.001, alpha=.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)
# loss_function = torch.nn.HingeEmbeddingLoss()
loss_function = torch.nn.BCELoss()

"""
@@ -124,22 +131,4 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):

print('Saving model...')
torch.save(model.state_dict(), 'models/trained/model.pt')
print('Saved 🌭')

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-ts', '--trainsetsize', nargs=1, type=float, choices=[x/10 for x in range(0, 10)], help='Train set split size', required=True)
parser.add_argument('-bs', '--batchsize', nargs=1, type=int, help='Size of the training batch', required=False)
parser.add_argument('-lr', '--learningrate', nargs=1, type=float, help='Learning rate', required=False)
parser.add_argument('-e', '--epochs', nargs=1, type=int, help='Number of epochs', required=False)
parser.add_argument('-v', '--verbose', nargs=1, type=bool, help='Verbose mode on/off', required=False)
parser.add_argument('-wd', '--weightdecay', nargs=1, type=float, help='Weight decay (L2 regularization)', required=False)
args = parser.parse_args()
execute(
train_set_size=args.trainsetsize[0] if args.trainsetsize[0]!=0 else 0.8,
batch_size=args.batchsize[0] if args.batchsize else 1000,
lr=args.learningrate[0] if args.learningrate else 0.5,
epochs=args.epochs[0] if args.epochs else 10,
is_verbose=args.verbose if args.verbose else True,
weight_decay=args.weightdecay[0] if args.weightdecay else 0.9
)
print('Saved 🌭')
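
A hedged sketch of driving the retrain path from Python instead of the CLI (not part of the commit; the grace value is a placeholder). With retrain=True the function loads models/trained/model.pt and MemeDataset is meant to keep only memes whose hashes are not already in already_trained.pkl, while grace is the probability of sampling an already-seen meme back in.

from train import train

train(
    train_set_size=0.8,
    batch_size=10,
    lr=1e-4,
    epochs=1,
    is_verbose=True,
    weight_decay=0.8,
    retrain=True,   # continue from the previously saved checkpoint
    grace=0.2,      # ~20% chance of re-admitting an already-seen meme
)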
