From 82c1ff7943aba1da2af36b159c5a073c9a8623d5 Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Mon, 9 Dec 2024 00:25:07 +0100 Subject: [PATCH] add prefix CLI list (#268) --- minari/cli.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/minari/cli.py b/minari/cli.py index 36b60419..adb38a50 100644 --- a/minari/cli.py +++ b/minari/cli.py @@ -138,18 +138,22 @@ def _show_dataset_table(datasets: Dict[str, Dict[str, Any]], table_title: str): queue = deque(table_tree.sub_nodes.values()) section_sentinel = object() + aggregate = False while queue: table_node = queue.popleft() if table_node is section_sentinel: table.add_section() - elif len(table_node.sub_nodes) == 0: + continue + + if len(table_node.sub_nodes) == 0: table.add_row(*table_node.to_row()) - elif len(table_node.sub_nodes) <= MAX_ROWS_PER_GROUP: - queue.extend(table_node.sub_nodes.values()) - queue.append(section_sentinel) - else: + elif aggregate and len(table_node.sub_nodes) > MAX_ROWS_PER_GROUP: table.add_row(*table_node.to_row()) table.add_section() + else: + queue.extend(table_node.sub_nodes.values()) + queue.append(section_sentinel) + aggregate = aggregate or len(table_node.sub_nodes) > 1 print(table) @@ -178,11 +182,14 @@ def list_cmd( all: Annotated[ bool, typer.Option("--all", "-a", help="Show all dataset versions.") ] = False, + prefix: Annotated[ + Optional[str], typer.Option("--prefix", "-p", help="Filter datasets by prefix.") + ] = None, ): """List Minari datasets in local or remote storage.""" if path == "local": datasets = local.list_local_datasets( - latest_version=True, compatible_minari_version=not all + latest_version=True, compatible_minari_version=not all, prefix=prefix ) dataset_dir = os.environ.get( "MINARI_DATASETS_PATH", @@ -204,8 +211,10 @@ def list_cmd( raise typer.Abort() remote_type, remote_path = remote_path.split("://", maxsplit=1) - remote_path, prefix, *_ = remote_path.split("/", maxsplit=1) + [None] + remote_path, path_prefix, *_ = remote_path.split("/", maxsplit=1) + [None] remote_path = f"{remote_type}://{remote_path}" + if path_prefix or prefix: + prefix = os.path.join(path_prefix or "", prefix or "") datasets = hosting.list_remote_datasets( remote_path=remote_path,