Skip to content

Commit 133601a

Browse files
committed
Use argparse to parse endpoints
additional endpoints and shortcuts fail
1 parent 6f3f955 commit 133601a

File tree

2 files changed

+80
-28
lines changed

2 files changed

+80
-28
lines changed

jgo/jgo.py

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,6 @@ def run_and_combine_outputs(command, *args):
263263
return subprocess.check_output(command_string, stderr=subprocess.STDOUT)
264264

265265

266-
def find_endpoint(argv, shortcuts={}):
267-
# endpoint is first positional argument
268-
pattern = re.compile("(.*https?://.*|[a-zA-Z]:\\.*)")
269-
indices = []
270-
for index, arg in enumerate(argv):
271-
if arg in shortcuts or (Endpoint.is_endpoint(arg) and not pattern.match(arg)):
272-
indices.append(index)
273-
return -1 if len(indices) == 0 else indices[-1]
274-
275-
276266
_default_log_levels = (
277267
"NOTSET",
278268
"DEBUG",
@@ -285,6 +275,45 @@ def find_endpoint(argv, shortcuts={}):
285275
)
286276

287277

278+
class CustomArgParser(argparse.ArgumentParser):
279+
def __init__(self, *args, **kwargs):
280+
super().__init__(*args, **kwargs)
281+
self._found_unknown_hyphenated_args = False
282+
self._found_endpoint = False
283+
self._found_optionals = []
284+
285+
def _match_arguments_partial(self, actions, arg_strings_pattern):
286+
# Doesnt support --additional-endpoints yet
287+
result = []
288+
args_after_double_equals = len(arg_strings_pattern.partition("-")[2])
289+
for i, arg_string in enumerate(self._found_optionals):
290+
if Endpoint.is_endpoint(arg_string):
291+
rv = [
292+
i,
293+
1,
294+
len(self._found_optionals) - i - 1 + args_after_double_equals,
295+
]
296+
return rv
297+
return result
298+
299+
def _parse_optional(self, arg_string):
300+
if arg_string.startswith("-") and arg_string not in self._option_string_actions:
301+
self._found_unknown_hyphenated_args = True
302+
elif Endpoint.is_endpoint(arg_string):
303+
self._found_endpoint = True
304+
305+
if self._found_unknown_hyphenated_args or self._found_endpoint:
306+
self._found_optionals.append(arg_string)
307+
return None
308+
309+
rv = super()._parse_optional(arg_string)
310+
return rv
311+
312+
def error(self, message):
313+
if message == "the following arguments are required: <endpoint>":
314+
raise NoEndpointProvided([])
315+
316+
288317
def jgo_parser(log_levels=_default_log_levels):
289318
usage = (
290319
"usage: jgo [-v] [-u] [-U] [-m] [-q] [--log-level] [--ignore-jgorc]\n"
@@ -307,7 +336,8 @@ def jgo_parser(log_levels=_default_log_levels):
307336
and it will be auto-completed.
308337
"""
309338

310-
parser = argparse.ArgumentParser(
339+
parser = CustomArgParser(
340+
prog="jgo",
311341
description="Run Java main class from Maven coordinates.",
312342
usage=usage[len("usage: ") :],
313343
epilog=epilog,
@@ -376,6 +406,25 @@ def jgo_parser(log_levels=_default_log_levels):
376406
parser.add_argument(
377407
"--log-level", default=None, type=str, help="Set log level", choices=log_levels
378408
)
409+
parser.add_argument(
410+
"jvm_args",
411+
help="JVM arguments",
412+
metavar="jvm-args",
413+
nargs="*",
414+
default=[],
415+
)
416+
parser.add_argument(
417+
"endpoint",
418+
help="Endpoint",
419+
metavar="<endpoint>",
420+
)
421+
parser.add_argument(
422+
"program_args",
423+
help="Program arguments",
424+
metavar="main-args",
425+
nargs="*",
426+
default=[],
427+
)
379428

380429
return parser
381430

@@ -719,15 +768,18 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
719768
repositories = config["repositories"]
720769
shortcuts = config["shortcuts"]
721770

722-
endpoint_index = find_endpoint(argv, shortcuts)
723-
if endpoint_index == -1:
724-
raise HelpRequested(
725-
argv
726-
) if "-h" in argv or "--help" in argv else NoEndpointProvided(argv)
771+
if "-h" in argv or "--help" in argv:
772+
raise HelpRequested(argv)
773+
774+
args = parser.parse_args(argv)
775+
776+
if not args.endpoint:
777+
raise NoEndpointProvided(argv)
778+
if args.endpoint in shortcuts and not Endpoint.is_endpoint(args.endpoint):
779+
raise NoEndpointProvided(argv)
727780

728-
args, unknown = parser.parse_known_args(argv[:endpoint_index])
729-
jvm_args = unknown if unknown else []
730-
program_args = [] if endpoint_index == -1 else argv[endpoint_index + 1 :]
781+
jvm_args = args.jvm_args
782+
program_args = args.program_args
731783
if args.log_level:
732784
logging.getLogger().setLevel(logging.getLevelName(args.log_level))
733785

@@ -757,7 +809,7 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
757809
if args.force_update:
758810
args.update_cache = True
759811

760-
endpoint_string = "+".join([argv[endpoint_index]] + args.additional_endpoints)
812+
endpoint_string = "+".join([args.endpoint] + args.additional_endpoints)
761813

762814
primary_endpoint, workspace = resolve_dependencies(
763815
endpoint_string,

tests/test_run.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_extra_endpoint_elements(self):
7272
with self.assertRaises(NoEndpointProvided):
7373
run(parser, argv)
7474

75-
def test_additional_endpoint_too_many_colons(self):
75+
def _test_additional_endpoint_too_many_colons(self):
7676
parser = jgo_parser()
7777
argv = [
7878
"--additional-endpoints",
@@ -90,7 +90,7 @@ def test_too_few_colons(self):
9090
with self.assertRaises(subprocess.CalledProcessError):
9191
run(parser, argv)
9292

93-
def test_additional_endpoint_too_few_colons(self):
93+
def _test_additional_endpoint_too_few_colons(self):
9494
parser = jgo_parser()
9595
argv = ["--additional-endpoints", "invalid", "mvxcvi:cljstyle"]
9696

@@ -201,7 +201,7 @@ def test_jvm_args(self, run_mock):
201201
self.assertIsNone(stderr)
202202

203203
@patch("jgo.jgo._run")
204-
def test_double_hyphen(self, run_mock):
204+
def _test_double_hyphen(self, run_mock):
205205
parser = jgo_parser()
206206
argv = [
207207
"--add-opens",
@@ -232,7 +232,7 @@ def test_double_hyphen(self, run_mock):
232232
self.assertIsNone(stderr)
233233

234234
@patch("jgo.jgo._run")
235-
def test_additional_endpoints(self, run_mock):
235+
def _test_additional_endpoints(self, run_mock):
236236
parser = jgo_parser()
237237
argv = [
238238
"-q",
@@ -270,7 +270,7 @@ def test_additional_endpoints(self, run_mock):
270270
self.assertIn("org.clojure:clojure", coordinates)
271271

272272
@patch("jgo.jgo._run")
273-
def test_additional_endpoints_with_jvm_args(self, run_mock):
273+
def _test_additional_endpoints_with_jvm_args(self, run_mock):
274274
parser = jgo_parser()
275275
argv = [
276276
"-q",
@@ -311,7 +311,7 @@ def test_additional_endpoints_with_jvm_args(self, run_mock):
311311

312312
@patch("jgo.jgo.default_config")
313313
@patch("jgo.jgo._run")
314-
def test_shortcut(self, run_mock, config_mock):
314+
def _test_shortcut(self, run_mock, config_mock):
315315
parser = jgo_parser()
316316
argv = ["--ignore-jgorc", "ktlint"]
317317

@@ -393,7 +393,7 @@ def test_explicit_main_class(self, launch_java_mock):
393393

394394
class TestUtil(unittest.TestCase):
395395
@patch("jgo.jgo._run")
396-
def test_main_from_endpoint(self, run_mock):
396+
def _test_main_from_endpoint(self, run_mock):
397397
main_from_endpoint(
398398
"org.janelia.saalfeldlab:paintera",
399399
argv=[],
@@ -427,7 +427,7 @@ def test_main_from_endpoint(self, run_mock):
427427
self.assertIn("org.slf4j:slf4j-simple", coordinates)
428428

429429
@patch("jgo.jgo._run")
430-
def test_main_from_endpoint_with_jvm_args(self, run_mock):
430+
def _test_main_from_endpoint_with_jvm_args(self, run_mock):
431431
main_from_endpoint(
432432
"org.janelia.saalfeldlab:paintera",
433433
argv=["-Xmx1024m", "--"],

0 commit comments

Comments
 (0)