Skip to content

Commit

Permalink
add default value in args
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-KOUAGOU committed Jan 13, 2025
1 parent 4b5fa54 commit 261e436
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions examples/train_nces.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def start(args):
synthesizer = ROCES(knowledge_base_path=args.kb, auto_train=False, k=5, max_length=48, proj_dim=128,
drop_prob=0.1, num_heads=4, num_seeds=1, m=32, load_pretrained=args.load_pretrained, verbose=True)
synthesizer.train(training_data, epochs=args.epochs, max_num_lps=args.max_num_lps, refinement_expressivity=args.refinement_expressivity)
print(synthesizer)

if __name__ == '__main__':
set_seed(42)
Expand All @@ -55,7 +54,7 @@ def start(args):
parser.add_argument('--refinement_expressivity', type=float, default=0.9, help='The expressivity of the refinement operator during training data generation')
parser.add_argument('--max_num_lps', type=int, default=20000, help='Maximum number of learning problems to generate if no training data is provided')
parser.add_argument('--path_of_embeddings', type=str, default=None, help='Path to a csv file containing embeddings for the KB.')
parser.add_argument('--path_train_data', type=str, help='Path to training data')
parser.add_argument('--path_train_data', type=str, default=None, help='Path to training data')
parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
parser.add_argument('--load_pretrained', type=str2bool, default=False, help='Whether to load the pretrained model')
start(parser.parse_args())

0 comments on commit 261e436

Please sign in to comment.