Skip to content

Commit

Permalink
final
Browse files Browse the repository at this point in the history
  • Loading branch information
siddsax committed Nov 17, 2018
1 parent a7214c4 commit cdaada5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
code/saved_models
*saved_models
datasets
*.npy
*.npz
Expand Down
12 changes: 6 additions & 6 deletions code/cnn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def train(x_tr, y_tr, x_te, y_te, embedding_weights, params):

if i % int(num_mb/12) == 0:
print('Iter-{}; Loss: {:.4}; best_loss: {:.4}; max_grad: {}:'.format(i, loss.data, loss_best, max_grad))
if not os.path.exists('saved_models/' + params.model_name ):
os.makedirs('saved_models/' + params.model_name)
if not os.path.exists('../saved_models/' + params.model_name ):
os.makedirs('../saved_models/' + params.model_name)
save_model(model, optimizer, epoch, params.model_name + "/model_best_batch")
if(loss<loss_best):
loss_best = loss.data
Expand Down Expand Up @@ -101,8 +101,8 @@ def train(x_tr, y_tr, x_te, y_te, embedding_weights, params):
if(totalLoss<bestTotalLoss):

bestTotalLoss = totalLoss
if not os.path.exists('saved_models/' + params.model_name ):
os.makedirs('saved_models/' + params.model_name)
if not os.path.exists('../saved_models/' + params.model_name ):
os.makedirs('../saved_models/' + params.model_name)
save_model(model, optimizer, epoch, params.model_name + "/model_best_epoch")

print('End-of-Epoch: {} Loss: {:.4}; best_loss: {:.4};'.format(epoch, totalLoss, bestTotalLoss))
Expand All @@ -116,8 +116,8 @@ def train(x_tr, y_tr, x_te, y_te, embedding_weights, params):
best_test_loss = test_ce_loss
best_test_acc = test_prec_acc
print("This acc is better than the previous recored test acc:- {} ; while CELoss:- {}".format(best_test_acc, best_test_loss))
if not os.path.exists('saved_models/' + params.model_name ):
os.makedirs('saved_models/' + params.model_name)
if not os.path.exists('../saved_models/' + params.model_name ):
os.makedirs('../saved_models/' + params.model_name)
save_model(model, optimizer, epoch, params.model_name + "/model_best_test")

if epoch % params.save_step == 0:
Expand Down
2 changes: 1 addition & 1 deletion utils/futils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def save_model(model, optimizer, epoch, name):
'optimizer': optimizer.state_dict(),
'epoch': epoch
}
torch.save(checkpoint, "saved_models/" + name)
torch.save(checkpoint, "../saved_models/" + name)
def sample_z(mu, log_var, params, dtype_f):
eps = Variable(torch.randn(params.batch_size, params.Z_dim).type(dtype_f))
k = torch.exp(log_var / 2) * eps
Expand Down

0 comments on commit cdaada5

Please sign in to comment.