Skip to content

Commit

Permalink
add prefix CLI list (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik authored Dec 8, 2024
1 parent 3864a98 commit 82c1ff7
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions minari/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,22 @@ def _show_dataset_table(datasets: Dict[str, Dict[str, Any]], table_title: str):

queue = deque(table_tree.sub_nodes.values())
section_sentinel = object()
aggregate = False
while queue:
table_node = queue.popleft()
if table_node is section_sentinel:
table.add_section()
elif len(table_node.sub_nodes) == 0:
continue

if len(table_node.sub_nodes) == 0:
table.add_row(*table_node.to_row())
elif len(table_node.sub_nodes) <= MAX_ROWS_PER_GROUP:
queue.extend(table_node.sub_nodes.values())
queue.append(section_sentinel)
else:
elif aggregate and len(table_node.sub_nodes) > MAX_ROWS_PER_GROUP:
table.add_row(*table_node.to_row())
table.add_section()
else:
queue.extend(table_node.sub_nodes.values())
queue.append(section_sentinel)
aggregate = aggregate or len(table_node.sub_nodes) > 1

print(table)

Expand Down Expand Up @@ -178,11 +182,14 @@ def list_cmd(
all: Annotated[
bool, typer.Option("--all", "-a", help="Show all dataset versions.")
] = False,
prefix: Annotated[
Optional[str], typer.Option("--prefix", "-p", help="Filter datasets by prefix.")
] = None,
):
"""List Minari datasets in local or remote storage."""
if path == "local":
datasets = local.list_local_datasets(
latest_version=True, compatible_minari_version=not all
latest_version=True, compatible_minari_version=not all, prefix=prefix
)
dataset_dir = os.environ.get(
"MINARI_DATASETS_PATH",
Expand All @@ -204,8 +211,10 @@ def list_cmd(
raise typer.Abort()

remote_type, remote_path = remote_path.split("://", maxsplit=1)
remote_path, prefix, *_ = remote_path.split("/", maxsplit=1) + [None]
remote_path, path_prefix, *_ = remote_path.split("/", maxsplit=1) + [None]
remote_path = f"{remote_type}://{remote_path}"
if path_prefix or prefix:
prefix = os.path.join(path_prefix or "", prefix or "")

datasets = hosting.list_remote_datasets(
remote_path=remote_path,
Expand Down

0 comments on commit 82c1ff7

Please sign in to comment.