From a34d5a34704eb908885eb11b92c6c98bfe27dbe9 Mon Sep 17 00:00:00 2001 From: Owyii Date: Sun, 14 Apr 2024 20:37:17 +0200 Subject: [PATCH] fix model --- mACHINE-LEARNINGS/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mACHINE-LEARNINGS/train.py b/mACHINE-LEARNINGS/train.py index 1501e25..00d06de 100644 --- a/mACHINE-LEARNINGS/train.py +++ b/mACHINE-LEARNINGS/train.py @@ -61,8 +61,10 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay): """ model = DankCNN() model = model.to(device) - optimizer = torch.optim.RMSprop(model.parameters(), lr=.1, alpha=.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) - loss_function = torch.nn.HingeEmbeddingLoss() + # optimizer = torch.optim.RMSprop(model.parameters(), lr=.001, alpha=.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + # loss_function = torch.nn.HingeEmbeddingLoss() + loss_function = torch.nn.BCELoss() """ TRAINING PHASE