From ec42795f8273cab2a52315fbd0c776bbb5347346 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Mon, 10 Jun 2024 18:24:32 +0200 Subject: [PATCH] Add json option for all resources --- giza/cli/commands/agents.py | 16 ++++++++++++++++ giza/cli/commands/endpoints.py | 19 +++++++++++++++++++ giza/cli/commands/models.py | 12 ++++++++++-- giza/cli/commands/versions.py | 13 +++++++++++++ giza/cli/frameworks/cairo.py | 5 ++++- giza/cli/options.py | 2 +- giza/cli/utils/echo.py | 5 +++-- 7 files changed, 66 insertions(+), 6 deletions(-) diff --git a/giza/cli/commands/agents.py b/giza/cli/commands/agents.py index 5f002f7..da7a4c6 100644 --- a/giza/cli/commands/agents.py +++ b/giza/cli/commands/agents.py @@ -13,6 +13,7 @@ DEBUG_OPTION, DESCRIPTION_OPTION, ENDPOINT_OPTION, + JSON_OPTION, MODEL_OPTION, NAME_OPTION, VERSION_OPTION, @@ -43,8 +44,13 @@ def create( endpoint_id: int = ENDPOINT_OPTION, name: Optional[str] = NAME_OPTION, description: Optional[str] = DESCRIPTION_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + + if json: + echo.set_log_file() + echo("Creating agent ✅ ") if not model_id and not version_id and not endpoint_id: @@ -122,8 +128,12 @@ def list( parameters: Optional[List[str]] = typer.Option( None, "--parameters", "-p", help="The parameters of the agent" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() + echo("Listing agents ✅ ") with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) @@ -148,8 +158,11 @@ def list( ) def get( agent_id: int = AGENT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting agent {agent_id} ✅ ") with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) @@ -195,8 +208,11 @@ def update( parameters: Optional[List[str]] = typer.Option( None, "--parameters", "-p", help="The parameters of the agent" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Updating agent {agent_id} ✅ ") with ExceptionHandler(debug=debug): client = AgentsClient(API_HOST) diff --git a/giza/cli/commands/endpoints.py b/giza/cli/commands/endpoints.py index 6d03dcc..fc5797b 100644 --- a/giza/cli/commands/endpoints.py +++ b/giza/cli/commands/endpoints.py @@ -12,6 +12,7 @@ DEBUG_OPTION, ENDPOINT_OPTION, FRAMEWORK_OPTION, + JSON_OPTION, MODEL_OPTION, VERSION_OPTION, ) @@ -76,8 +77,11 @@ def list( only_active: bool = typer.Option( False, "--only-active", "-a", help="Only list active endpoints" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo("Listing endpoints ✅ ") params = {} try: @@ -126,8 +130,11 @@ def list( ) def get( endpoint_id: int = ENDPOINT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting endpoint {endpoint_id} ✅ ") try: client = EndpointsClient(API_HOST) @@ -190,8 +197,11 @@ def delete_endpoint( ) def list_proofs( endpoint_id: int = ENDPOINT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting proofs from endpoint {endpoint_id} ✅ ") try: client = EndpointsClient(API_HOST) @@ -236,8 +246,11 @@ def get_proof( proof_id: str = typer.Option( None, "--proof-id", "-p", help="The ID or request id of the proof" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting proof from endpoint {endpoint_id} ✅ ") try: client = EndpointsClient(API_HOST) @@ -333,8 +346,11 @@ def download_proof( ) def list_jobs( endpoint_id: int = ENDPOINT_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Getting jobs from endpoint {endpoint_id} ✅ ") with ExceptionHandler(debug=debug): client = EndpointsClient(API_HOST) @@ -356,8 +372,11 @@ def verify( proof_id: str = typer.Option( None, "--proof-id", "-p", help="The ID or request id of the proof" ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() echo(f"Verifying proof from endpoint {endpoint_id} ✅ ") with ExceptionHandler(debug=debug): client = EndpointsClient(API_HOST) diff --git a/giza/cli/commands/models.py b/giza/cli/commands/models.py index 088c680..90d622f 100644 --- a/giza/cli/commands/models.py +++ b/giza/cli/commands/models.py @@ -7,7 +7,7 @@ from giza.cli import API_HOST from giza.cli.client import ModelsClient -from giza.cli.options import DEBUG_OPTION, DESCRIPTION_OPTION, MODEL_OPTION +from giza.cli.options import DEBUG_OPTION, DESCRIPTION_OPTION, JSON_OPTION, MODEL_OPTION from giza.cli.schemas.models import ModelCreate from giza.cli.utils import echo, get_response_info @@ -27,6 +27,7 @@ ) def get( model_id: int = MODEL_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: """ @@ -40,6 +41,8 @@ def get( ValidationError: input fields are validated, if these are not suitable the exception is raised HTTPError: request error to the API, 4XX or 5XX """ + if json: + echo.set_log_file() echo("Retrieving model information ✅ ") try: client = ModelsClient(API_HOST) @@ -81,6 +84,7 @@ def get( """, ) def list( + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: """ @@ -93,7 +97,8 @@ def list( ValidationError: input fields are validated, if these are not suitable the exception is raised HTTPError: request error to the API, 4XX or 5XX """ - + if json: + echo.set_log_file() echo("Listing models ✅ ") try: client = ModelsClient(API_HOST) @@ -138,6 +143,7 @@ def create( ..., "--name", "-n", help="Name of the model to be created" ), description: str = DESCRIPTION_OPTION, + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: """ @@ -151,6 +157,8 @@ def create( ValidationError: input fields are validated, if these are not suitable the exception is raised HTTPError: request error to the API, 4XX or 5XX """ + if json: + echo.set_log_file() if name is None or name == "": echo.error("Name is required") sys.exit(1) diff --git a/giza/cli/commands/versions.py b/giza/cli/commands/versions.py index 6bb8d7b..cbee757 100644 --- a/giza/cli/commands/versions.py +++ b/giza/cli/commands/versions.py @@ -15,6 +15,7 @@ DESCRIPTION_OPTION, FRAMEWORK_OPTION, INPUT_OPTION, + JSON_OPTION, MODEL_OPTION, OUTPUT_PATH_OPTION, VERSION_OPTION, @@ -49,8 +50,11 @@ def update_sierra(model_id: int, version_id: int, model_path: str): def get( model_id: int = MODEL_OPTION, version_id: int = VERSION_OPTION, + json: Optional[bool] = JSON_OPTION, debug: bool = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() if any([model_id is None, version_id is None]): echo.error("⛔️Model ID and version ID are required⛔️") sys.exit(1) @@ -79,6 +83,7 @@ def transpile( "--download-sierra", help="Download the siera file is the modle is fully compatible. CAIRO only.", ), + json: Optional[bool] = JSON_OPTION, debug: Optional[bool] = DEBUG_OPTION, ) -> None: if framework == Framework.CAIRO: @@ -90,6 +95,7 @@ def transpile( output_path=output_path, download_model=download_model, download_sierra=download_sierra, + json=json, debug=debug, ) elif framework == Framework.EZKL: @@ -145,8 +151,12 @@ def update( model_path: str = typer.Option( None, "--model-path", "-M", help="Path of the model to update" ), + json: Optional[bool] = JSON_OPTION, debug: bool = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() + if any([model_id is None, version_id is None]): echo.error("⛔️Model ID and version ID are required to update the version⛔️") sys.exit(1) @@ -183,8 +193,11 @@ def update( ) def list( model_id: int = MODEL_OPTION, + json: Optional[bool] = JSON_OPTION, debug: bool = DEBUG_OPTION, ) -> None: + if json: + echo.set_log_file() if model_id is None: echo.error("⛔️Model ID is required⛔️") sys.exit(1) diff --git a/giza/cli/frameworks/cairo.py b/giza/cli/frameworks/cairo.py index 48ece5e..fb764c8 100644 --- a/giza/cli/frameworks/cairo.py +++ b/giza/cli/frameworks/cairo.py @@ -228,6 +228,7 @@ def transpile( output_path: str, download_model: bool, download_sierra: bool, + json: Optional[bool], debug: Optional[bool], ) -> None: """ @@ -256,7 +257,7 @@ def transpile( ValidationError: If there is a validation error with the model or version. HTTPError: If there is an HTTP error while communicating with the server. """ - echo = Echo(debug=debug) + echo = Echo(debug=debug, output_json=json) if model_path is None: echo.error("No model name provided, please provide a model path ⛔️") sys.exit(1) @@ -400,6 +401,8 @@ def transpile( if debug: raise zip_error sys.exit(1) + echo.print_model(model, title="Model") + echo.print_model(version, title="Version") def verify( diff --git a/giza/cli/options.py b/giza/cli/options.py index 1f329bf..817ea98 100644 --- a/giza/cli/options.py +++ b/giza/cli/options.py @@ -41,7 +41,7 @@ ) NAME_OPTION = typer.Option(None, "--name", "-n", help="The name of the resource") JSON_OPTION = typer.Option( - None, + False, "--json", "-j", help="Whether to print the output as JSON. This will make that the only ouput is the json and the logs will be saved to `giza.log`", diff --git a/giza/cli/utils/echo.py b/giza/cli/utils/echo.py index 0ab8f29..2115630 100644 --- a/giza/cli/utils/echo.py +++ b/giza/cli/utils/echo.py @@ -23,7 +23,7 @@ class Echo: LOG_FILE: str = "giza.log" def __init__( - self, debug: Optional[bool] = False, output_json: bool = False + self, debug: Optional[bool] = False, output_json: bool | None = False ) -> None: self._debug = debug self._json = output_json @@ -205,8 +205,9 @@ def print_model(self, model: Union[BaseModel, RootModel], title=""): model (Union[BaseModel, RootModel]): The model or list of models to print title (str, optional): Title of the table. Defaults to "". """ - if self._json: + if self._json and self._file is not None: print_json(model.model_dump_json()) + self._file.write(model.model_dump_json(indent=4)) return table = Table(title=title)