diff --git a/llm/cli.py b/llm/cli.py index 01c5bffb1..a2266ffa5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -289,10 +289,44 @@ def cli(): """ +def complete_model(ctx: click.Context, param: click.Parameter, incomplete: str): + from click.shell_completion import CompletionItem + + return [CompletionItem(alias) for alias in get_model_aliases().keys() if alias.startswith(incomplete)] + +def complete_embedding_model(ctx: click.Context, param: click.Parameter, incomplete: str): + from click.shell_completion import CompletionItem + + return [CompletionItem(alias) for alias in get_embedding_model_aliases().keys() if alias.startswith(incomplete)] + +def complete_option(ctx: click.Context, param: click.Parameter, incomplete): + from click.shell_completion import CompletionItem + + # This function is actually hit for the option 'value' as well. + # But there's no way to tell without patching click. + + try: + model_id = ctx.params["model_id"] + model = get_model(model_id or get_default_model()) + except (AttributeError, UnknownModelError): + return [] + options = model.Options.model_json_schema()["properties"].keys() + return [CompletionItem(option) for option in options if option.startswith(incomplete)] + +class TemplateType(click.ParamType): + name = "template" + + def shell_complete(self, ctx, param, incomplete): + from click.shell_completion import CompletionItem + + path = template_dir() + return [CompletionItem(file.stem) for file in path.glob(incomplete + "*.yaml")] + + @cli.command(name="prompt") @click.argument("prompt", required=False) @click.option("-s", "--system", help="System prompt to use") -@click.option("model_id", "-m", "--model", help="Model to use", envvar="LLM_MODEL") +@click.option("model_id", "-m", "--model", help="Model to use", envvar="LLM_MODEL", shell_complete=complete_model) @click.option( "-d", "--database", @@ -330,6 +364,7 @@ def cli(): type=(str, str), multiple=True, help="key/value options for the model", + shell_complete=complete_option, ) @schema_option @click.option( @@ -350,7 +385,7 @@ def cli(): multiple=True, help="Fragment to add to system prompt", ) -@click.option("-t", "--template", help="Template to use") +@click.option("-t", "--template", help="Template to use", type=TemplateType()) @click.option( "-p", "--param", @@ -783,7 +818,7 @@ async def inner(): @cli.command() @click.option("-s", "--system", help="System prompt to use") -@click.option("model_id", "-m", "--model", help="Model to use", envvar="LLM_MODEL") +@click.option("model_id", "-m", "--model", help="Model to use", envvar="LLM_MODEL", shell_complete=complete_model) @click.option( "_continue", "-c", @@ -798,7 +833,7 @@ async def inner(): "--conversation", help="Continue the conversation with the given ID.", ) -@click.option("-t", "--template", help="Template to use") +@click.option("-t", "--template", help="Template to use", type=TemplateType()) @click.option( "-p", "--param", @@ -813,6 +848,7 @@ async def inner(): type=(str, str), multiple=True, help="key/value options for the model", + shell_complete=complete_option, ) @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("--key", help="API key to use") @@ -1843,7 +1879,7 @@ def templates_list(): @templates.command(name="show") -@click.argument("name") +@click.argument("name", type=TemplateType()) def templates_show(name): "Show the specified prompt template" template = load_template(name) @@ -1857,7 +1893,7 @@ def templates_show(name): @templates.command(name="edit") -@click.argument("name") +@click.argument("name", type=TemplateType()) def templates_edit(name): "Edit the specified prompt template using the default $EDITOR" # First ensure it exists @@ -2373,7 +2409,8 @@ def uninstall(packages, yes): help="File to embed", ) @click.option( - "-m", "--model", help="Embedding model to use", envvar="LLM_EMBEDDING_MODEL" + "-m", "--model", help="Embedding model to use", envvar="LLM_EMBEDDING_MODEL", + shell_complete=complete_embedding_model ) @click.option("--store", is_flag=True, help="Store the text itself in the database") @click.option( @@ -2517,7 +2554,8 @@ def get_db(): ) @click.option("--prefix", help="Prefix to add to the IDs", default="") @click.option( - "-m", "--model", help="Embedding model to use", envvar="LLM_EMBEDDING_MODEL" + "-m", "--model", help="Embedding model to use", envvar="LLM_EMBEDDING_MODEL", + shell_complete=complete_embedding_model ) @click.option( "--prepend",