@@ -263,16 +263,6 @@ def run_and_combine_outputs(command, *args):
263
263
return subprocess .check_output (command_string , stderr = subprocess .STDOUT )
264
264
265
265
266
- def find_endpoint (argv , shortcuts = {}):
267
- # endpoint is first positional argument
268
- pattern = re .compile ("(.*https?://.*|[a-zA-Z]:\\ .*)" )
269
- indices = []
270
- for index , arg in enumerate (argv ):
271
- if arg in shortcuts or (Endpoint .is_endpoint (arg ) and not pattern .match (arg )):
272
- indices .append (index )
273
- return - 1 if len (indices ) == 0 else indices [- 1 ]
274
-
275
-
276
266
_default_log_levels = (
277
267
"NOTSET" ,
278
268
"DEBUG" ,
@@ -285,6 +275,45 @@ def find_endpoint(argv, shortcuts={}):
285
275
)
286
276
287
277
278
+ class CustomArgParser (argparse .ArgumentParser ):
279
+ def __init__ (self , * args , ** kwargs ):
280
+ super ().__init__ (* args , ** kwargs )
281
+ self ._found_unknown_hyphenated_args = False
282
+ self ._found_endpoint = False
283
+ self ._found_optionals = []
284
+
285
+ def _match_arguments_partial (self , actions , arg_strings_pattern ):
286
+ # Doesnt support --additional-endpoints yet
287
+ result = []
288
+ args_after_double_equals = len (arg_strings_pattern .partition ("-" )[2 ])
289
+ for i , arg_string in enumerate (self ._found_optionals ):
290
+ if Endpoint .is_endpoint (arg_string ):
291
+ rv = [
292
+ i ,
293
+ 1 ,
294
+ len (self ._found_optionals ) - i - 1 + args_after_double_equals ,
295
+ ]
296
+ return rv
297
+ return result
298
+
299
+ def _parse_optional (self , arg_string ):
300
+ if arg_string .startswith ("-" ) and arg_string not in self ._option_string_actions :
301
+ self ._found_unknown_hyphenated_args = True
302
+ elif Endpoint .is_endpoint (arg_string ):
303
+ self ._found_endpoint = True
304
+
305
+ if self ._found_unknown_hyphenated_args or self ._found_endpoint :
306
+ self ._found_optionals .append (arg_string )
307
+ return None
308
+
309
+ rv = super ()._parse_optional (arg_string )
310
+ return rv
311
+
312
+ def error (self , message ):
313
+ if message == "the following arguments are required: <endpoint>" :
314
+ raise NoEndpointProvided ([])
315
+
316
+
288
317
def jgo_parser (log_levels = _default_log_levels ):
289
318
usage = (
290
319
"usage: jgo [-v] [-u] [-U] [-m] [-q] [--log-level] [--ignore-jgorc]\n "
@@ -307,7 +336,8 @@ def jgo_parser(log_levels=_default_log_levels):
307
336
and it will be auto-completed.
308
337
"""
309
338
310
- parser = argparse .ArgumentParser (
339
+ parser = CustomArgParser (
340
+ prog = "jgo" ,
311
341
description = "Run Java main class from Maven coordinates." ,
312
342
usage = usage [len ("usage: " ) :],
313
343
epilog = epilog ,
@@ -376,6 +406,25 @@ def jgo_parser(log_levels=_default_log_levels):
376
406
parser .add_argument (
377
407
"--log-level" , default = None , type = str , help = "Set log level" , choices = log_levels
378
408
)
409
+ parser .add_argument (
410
+ "jvm_args" ,
411
+ help = "JVM arguments" ,
412
+ metavar = "jvm-args" ,
413
+ nargs = "*" ,
414
+ default = [],
415
+ )
416
+ parser .add_argument (
417
+ "endpoint" ,
418
+ help = "Endpoint" ,
419
+ metavar = "<endpoint>" ,
420
+ )
421
+ parser .add_argument (
422
+ "program_args" ,
423
+ help = "Program arguments" ,
424
+ metavar = "main-args" ,
425
+ nargs = "*" ,
426
+ default = [],
427
+ )
379
428
380
429
return parser
381
430
@@ -719,15 +768,18 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
719
768
repositories = config ["repositories" ]
720
769
shortcuts = config ["shortcuts" ]
721
770
722
- endpoint_index = find_endpoint (argv , shortcuts )
723
- if endpoint_index == - 1 :
724
- raise HelpRequested (
725
- argv
726
- ) if "-h" in argv or "--help" in argv else NoEndpointProvided (argv )
771
+ if "-h" in argv or "--help" in argv :
772
+ raise HelpRequested (argv )
773
+
774
+ args = parser .parse_args (argv )
775
+
776
+ if not args .endpoint :
777
+ raise NoEndpointProvided (argv )
778
+ if args .endpoint in shortcuts and not Endpoint .is_endpoint (args .endpoint ):
779
+ raise NoEndpointProvided (argv )
727
780
728
- args , unknown = parser .parse_known_args (argv [:endpoint_index ])
729
- jvm_args = unknown if unknown else []
730
- program_args = [] if endpoint_index == - 1 else argv [endpoint_index + 1 :]
781
+ jvm_args = args .jvm_args
782
+ program_args = args .program_args
731
783
if args .log_level :
732
784
logging .getLogger ().setLevel (logging .getLevelName (args .log_level ))
733
785
@@ -757,7 +809,7 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
757
809
if args .force_update :
758
810
args .update_cache = True
759
811
760
- endpoint_string = "+" .join ([argv [ endpoint_index ] ] + args .additional_endpoints )
812
+ endpoint_string = "+" .join ([args . endpoint ] + args .additional_endpoints )
761
813
762
814
primary_endpoint , workspace = resolve_dependencies (
763
815
endpoint_string ,
0 commit comments