diff --git a/src/train_test/distributed.py b/src/train_test/distributed.py index b3601337..f6c18513 100644 --- a/src/train_test/distributed.py +++ b/src/train_test/distributed.py @@ -56,6 +56,8 @@ def dtrain(args): print(f"-------------- HYPERPARAMETERS -----------") print(f" Learning rate: {args.learning_rate}") print(f" Dropout: {args.dropout}") + print(f" dropout_prot: {args.dropout_prot}") + print(f" pro_emb_dim: {args.pro_emb_dim}") print(f" Num epochs: {args.num_epochs}\n") print(f"----------------- DISTRIBUTED ARGS -----------------") @@ -87,7 +89,10 @@ def dtrain(args): # ==== Load model ==== # args.gpu is the local rank for this process model = Loader.init_model(model=MODEL, pro_feature=FEATURE, pro_edge=EDGEW, - dropout=args.dropout).cuda(args.gpu) + dropout=args.dropout, + dropout_prot=args.dropout_prot, + pro_emb_dim=args.pro_emb_dim).cuda(args.gpu) + cp_saver = CheckpointSaver(model=model, save_path=f'{cfg.MODEL_SAVE_DIR}/{MODEL_KEY}.model', train_all=False, patience=50, min_delta=0.2, diff --git a/src/utils/arg_parse.py b/src/utils/arg_parse.py index 7a9dcda9..5a06c732 100644 --- a/src/utils/arg_parse.py +++ b/src/utils/arg_parse.py @@ -95,6 +95,17 @@ def add_hyperparam_args(parser: argparse.ArgumentParser): action='store', type=float, default=0.4, help='Dropout rate for training (default: 0.4)' ) + parser.add_argument('-dop', + '--dropout_prot', + action='store', type=float, default=0.4, + help='Dropout rate for protein GCN branch for training (default: 0.4)' + ) + parser.add_argument('-embP', + '--pro_emb_dim', + action='stor', type=int, default=128, + help='Embedding dimension for protein GCN branch for training (default: 128)' + ) + parser.add_argument('-ne', '--num_epochs', action='store', type=int, default=2000,