Skip to content

Commit

Permalink
fix printing
Browse files Browse the repository at this point in the history
  • Loading branch information
Samy Wu Fung committed Aug 11, 2023
1 parent 0c629ad commit aad6213
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,14 @@ def trainer_warcraft(net, train_dataset, val_dataset, test_dataset,
epoch_time=0

# print initial test loss in :7.3e format
print('initial_val_loss: ', "{:5.2e}".format(val_loss),
' | initial_val_acc: ', "{:<4f}".format(val_acc),
' | initial_val_cost_pred: ', "{:5.2e}".format(val_cost_pred),
print('INITIAL VALUES:')
print('val_loss: ', "{:5.2e}".format(val_loss),
' | val_acc: ', "{:<4.3g}".format(val_acc),
' | : ', "{:5.2e}".format(val_cost_pred),
' | true val_cost: ', "{:5.2e}".format(val_cost_true),
'initial_test_loss: ', "{:5.2e}".format(test_loss),
' | initial_test_acc: ', "{:<4f}".format(test_acc),
' | initial_test_cost_pred: ', "{:5.2e}".format(test_cost_pred),
'test_loss: ', "{:5.2e}".format(test_loss),
' | test_acc: ', "{:<4.3g}".format(test_acc),
' | test_cost_pred: ', "{:5.2e}".format(test_cost_pred),
' | true test_cost: ', "{:5.2e}".format(test_cost_true))

while epoch <= max_epochs:
Expand Down Expand Up @@ -256,7 +257,7 @@ def trainer_warcraft(net, train_dataset, val_dataset, test_dataset,

print('epoch: ', epoch, '| ave_tr_loss: ', "{:5.2e}".format(train_loss_ave),
'| val_loss: ', "{:5.2e}".format(val_loss),
'| val_acc.: ', "{:<4f}".format(val_acc),
'| val_acc.: ', "{:<4.3g}".format(val_acc),
'| val_cost_pred: ', "{:5.2e}".format(val_cost_pred),
'| lr: ', "{:5.2e}".format(optimizer.param_groups[0]['lr']),
'| time: ', "{:<15f}".format(epoch_time))
Expand All @@ -269,8 +270,8 @@ def trainer_warcraft(net, train_dataset, val_dataset, test_dataset,
print('path_pred edge = ', torch.nonzero(path_pred[2,:]))
print(edge_to_node(path_pred[2,:], edge_list, grid_size, device))
print('\n True Path \n')
print(edge_to_node(path_batch[2,:], edge_list, grid_size, device))
print('path_batch edge = ', torch.nonzero(path_batch[2,:]))
print(edge_to_node(path_batch_edge[2,:], edge_list, grid_size, device))
print('path_batch edge = ', torch.nonzero(path_batch_edge[2,:]))
print('\n ------------------------ \n')


Expand All @@ -294,7 +295,7 @@ def trainer_warcraft(net, train_dataset, val_dataset, test_dataset,
# pred_batch_edge_form=True means path_pred_edge is in edge form.
test_acc, test_cost_pred, test_cost_true = compute_accuracy(path_pred_edge, path_batch_vertex, costs_batch, edge_list, grid_size, device, pred_batch_edge_form=True)

print('final test loss is ', "{:5.2e}".format(test_loss), ' | final test acc. is ', "{:<4f}".format(test_acc), ' | final test cost pred is ', "{:5.2e}".format(test_cost_pred), ' | final test cost true is ', "{:5.2e}".format(test_cost_true))
print('final test loss is ', "{:5.2e}".format(test_loss), ' | final test acc. is ', "{:<4.3g}".format(test_acc), ' | final test cost pred is ', "{:5.2e}".format(test_cost_pred), ' | final test cost true is ', "{:5.2e}".format(test_cost_true))

return best_params, val_loss_hist, val_acc_hist, test_loss, test_acc, train_time

0 comments on commit aad6213

Please sign in to comment.