Skip to content

Commit

Permalink
Issue #720/#402/#725 eliminate _ensure_save_result
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Feb 12, 2025
1 parent 923f4d9 commit 3302d87
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 97 deletions.
39 changes: 0 additions & 39 deletions openeo/rest/_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,42 +324,3 @@ def build_child_callback(
raise ValueError(process)

return PGNode.to_process_graph_argument(pg)


def _ensure_save_result(
    cube: _ProcessGraphAbstraction,
    *,
    format: Optional[str] = None,
    options: Optional[dict] = None,
    weak_format: Optional[str] = None,
    default_format: str,
    method: str,
) -> Union[SaveResult, StacResource]:
    """
    Make sure the process graph of ``cube`` ends with a `save_result` node.

    :param cube: process graph abstraction to check/extend
    :param format: (optional) desired `save_result` file format
    :param options: (optional) desired `save_result` file format parameters
    :param weak_format: (optional) weak format indicator guessed from file name
    :param default_format: default format for data type to use when no format is specified by user
    :param method: name of the calling method (only used in error messages)
    :return: ``cube`` itself when it already contains a `save_result` node,
        otherwise a new object with a `save_result` node appended.
    """
    # TODO #278 instead of standalone helper function, move this to common base class for raster cubes, vector cubes, ...
    has_save_result = any(n.process_id == "save_result" for n in cube.result_node().walk_nodes())

    if has_save_result:
        # Combining pre-existing `save_result` node(s) with explicit format/options is ambiguous: refuse.
        if format or options:
            raise OpenEoClientException(
                f"{method} with explicit output {'format' if format else 'options'} {format or options!r},"
                f" but the process graph already has `save_result` node(s)"
                f" which is ambiguous and should not be combined."
            )
        return cube

    # No `save_result` node yet: automatically add it.
    if not isinstance(cube, (openeo.DataCube, openeo.VectorCube)):
        raise OpenEoClientException(f"No support to add `save_result` on {cube!r}.")
    return cube.save_result(format=format or weak_format or default_format, options=options)
45 changes: 17 additions & 28 deletions openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from openeo.rest._datacube import (
THIS,
UDF,
_ensure_save_result,
_ProcessGraphAbstraction,
build_child_callback,
)
Expand All @@ -65,7 +64,6 @@
from openeo.rest.mlmodel import MlModel
from openeo.rest.result import SaveResult
from openeo.rest.service import Service
from openeo.rest.stac_resource import StacResource
from openeo.rest.udp import RESTUserDefinedProcess
from openeo.rest.vectorcube import VectorCube
from openeo.util import dict_no_none, guess_format, load_json, normalize_crs, rfc3339
Expand Down Expand Up @@ -2352,6 +2350,17 @@ def save_result(
)
return SaveResult(pg, connection=self._connection)

def _auto_save_result(
    self,
    format: Optional[str] = None,
    outputfile: Optional[Union[str, pathlib.Path]] = None,
    options: Optional[dict] = None,
) -> SaveResult:
    """
    Append a `save_result` node, resolving the output format from (in order of precedence):
    the explicit ``format`` argument, a guess from the ``outputfile`` extension,
    or the default raster format.

    :param format: (optional) explicit output format
    :param outputfile: (optional) output file path, only used to guess the format when not given explicitly
    :param options: (optional) output format options
    :return: new :py:class:`SaveResult` instance
    """
    # Only fall back to guessing from the file name when no explicit format was given.
    effective_format = format
    if not effective_format and outputfile:
        effective_format = guess_format(outputfile)
    if not effective_format:
        effective_format = self._DEFAULT_RASTER_FORMAT
    return self.save_result(format=effective_format, options=options)

def download(
self,
outputfile: Optional[Union[str, pathlib.Path]] = None,
Expand All @@ -2374,7 +2383,7 @@ def download(
:param options: Optional, file format options
:param validate: Optional toggle to enable/prevent validation of the process graphs before execution
(overruling the connection's ``auto_validate`` setting).
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph if there is none yet.
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph.
:param additional: additional (top-level) properties to set in the request body
:param job_options: dictionary of job options to pass to the backend
(under top-level property "job_options")
Expand All @@ -2389,14 +2398,7 @@ def download(
"""
# TODO #278 centralize download/create_job/execute_job logic in DataCube, VectorCube, MlModel, ...
if auto_add_save_result:
res = _ensure_save_result(
cube=self,
format=format,
options=options,
weak_format=guess_format(outputfile) if outputfile else None,
default_format=self._DEFAULT_RASTER_FORMAT,
method="DataCube.download()",
)
res = self._auto_save_result(format=format, outputfile=outputfile, options=options)
else:
res = self
return self._connection.download(
Expand Down Expand Up @@ -2543,7 +2545,7 @@ def execute_batch(
(under top-level property "job_options")
:param validate: Optional toggle to enable/prevent validation of the process graphs before execution
(overruling the connection's ``auto_validate`` setting).
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph if there is none yet.
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph.
:param show_error_logs: whether to automatically print error logs when the batch job failed.
:param log_level: Optional minimum severity level for log entries that the back-end should keep track of.
One of "error" (highest severity), "warning", "info", and "debug" (lowest severity).
Expand All @@ -2568,14 +2570,7 @@ def execute_batch(

# TODO #278 centralize download/create_job/execute_job logic in DataCube, VectorCube, MlModel, ...
if auto_add_save_result:
res = _ensure_save_result(
cube=self,
format=out_format,
options=format_options,
weak_format=guess_format(outputfile) if outputfile else None,
default_format=self._DEFAULT_RASTER_FORMAT,
method="DataCube.execute_batch()",
)
res = self._auto_save_result(format=out_format, outputfile=outputfile, options=format_options)
create_kwargs = {}
else:
res = self
Expand Down Expand Up @@ -2637,7 +2632,7 @@ def create_job(
(under top-level property "job_options")
:param validate: Optional toggle to enable/prevent validation of the process graphs before execution
(overruling the connection's ``auto_validate`` setting).
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph if there is none yet.
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph.
:param log_level: Optional minimum severity level for log entries that the back-end should keep track of.
One of "error" (highest severity), "warning", "info", and "debug" (lowest severity).
Expand All @@ -2656,13 +2651,7 @@ def create_job(
# TODO: avoid using all kwargs as format_options
# TODO #278 centralize download/create_job/execute_job logic in DataCube, VectorCube, MlModel, ...
if auto_add_save_result:
res = _ensure_save_result(
cube=self,
format=out_format,
options=format_options or None,
default_format=self._DEFAULT_RASTER_FORMAT,
method="DataCube.create_job()",
)
res = self._auto_save_result(format=out_format, options=format_options)
else:
res = self

Expand Down
50 changes: 20 additions & 30 deletions openeo/rest/vectorcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from openeo.rest._datacube import (
THIS,
UDF,
_ensure_save_result,
_ProcessGraphAbstraction,
build_child_callback,
)
Expand Down Expand Up @@ -194,7 +193,6 @@ def run_udf(

@openeo_process
def save_result(self, format: Union[str, None] = "GeoJSON", options: dict = None) -> SaveResult:
# TODO #401: guard against duplicate save_result nodes?
pg = self._build_pgnode(
process_id="save_result",
arguments={
Expand All @@ -206,6 +204,17 @@ def save_result(self, format: Union[str, None] = "GeoJSON", options: dict = None
)
return SaveResult(pg, connection=self._connection)

def _auto_save_result(
    self,
    format: Optional[str] = None,
    outputfile: Optional[Union[str, pathlib.Path]] = None,
    options: Optional[dict] = None,
) -> SaveResult:
    """
    Append a `save_result` node, resolving the output format from (in order of precedence):
    the explicit ``format`` argument, a guess from the ``outputfile`` extension,
    or the default vector format.

    :param format: (optional) explicit output format
    :param outputfile: (optional) output file path, only used to guess the format when not given explicitly
    :param options: (optional) output format options
    :return: new :py:class:`SaveResult` instance
    """
    # Weak format guess from file name; only consulted when `format` is not set.
    guessed = guess_format(outputfile) if outputfile else None
    return self.save_result(format=format or guessed or self._DEFAULT_VECTOR_FORMAT, options=options)

def execute(self, *, validate: Optional[bool] = None) -> dict:
"""Executes the process graph."""
return self._connection.execute(self.flat_graph(), validate=validate)
Expand All @@ -230,7 +239,7 @@ def download(
:param options: (optional) additional output format options.
:param validate: Optional toggle to enable/prevent validation of the process graphs before execution
(overruling the connection's ``auto_validate`` setting).
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph if there is none yet.
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph.
.. versionchanged:: 0.21.0
When not specified explicitly, output format is guessed from output file extension.
Expand All @@ -240,14 +249,7 @@ def download(
"""
# TODO #278 centralize download/create_job/execute_job logic in DataCube, VectorCube, MlModel, ...
if auto_add_save_result:
res = _ensure_save_result(
cube=self,
format=format,
options=options,
weak_format=guess_format(outputfile) if outputfile else None,
default_format=self._DEFAULT_VECTOR_FORMAT,
method="VectorCube.download()",
)
res = self._auto_save_result(format=format, outputfile=outputfile, options=options)
else:
res = self
return self._connection.download(res.flat_graph(), outputfile=outputfile, validate=validate)
Expand Down Expand Up @@ -287,7 +289,7 @@ def execute_batch(
:param format_options: (optional) additional output format options
:param validate: Optional toggle to enable/prevent validation of the process graphs before execution
(overruling the connection's ``auto_validate`` setting).
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph if there is none yet.
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph.
:param show_error_logs: whether to automatically print error logs when the batch job failed.
:param log_level: Optional minimum severity level for log entries that the back-end should keep track of.
One of "error" (highest severity), "warning", "info", and "debug" (lowest severity).
Expand All @@ -310,14 +312,7 @@ def execute_batch(
Added argument ``log_level``.
"""
if auto_add_save_result:
res = _ensure_save_result(
cube=self,
format=out_format,
options=format_options,
weak_format=guess_format(outputfile) if outputfile else None,
default_format=self._DEFAULT_VECTOR_FORMAT,
method="VectorCube.execute_batch()",
)
res = self._auto_save_result(format=out_format, outputfile=outputfile, options=format_options)
create_kwargs = {}
else:
res = self
Expand Down Expand Up @@ -373,7 +368,7 @@ def create_job(
:param format_options: String Parameters for the job result format
:param validate: Optional toggle to enable/prevent validation of the process graphs before execution
(overruling the connection's ``auto_validate`` setting).
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph if there is none yet.
:param auto_add_save_result: Automatically add a ``save_result`` node to the process graph.
:param log_level: Optional minimum severity level for log entries that the back-end should keep track of.
One of "error" (highest severity), "warning", "info", and "debug" (lowest severity).
Expand All @@ -387,17 +382,12 @@ def create_job(
"""
# TODO: avoid using all kwargs as format_options
# TODO #278 centralize download/create_job/execute_job logic in DataCube, VectorCube, MlModel, ...
cube = self
if auto_add_save_result:
cube = _ensure_save_result(
cube=cube,
format=out_format,
options=format_options or None,
default_format=self._DEFAULT_VECTOR_FORMAT,
method="VectorCube.create_job()",
)
res = self._auto_save_result(format=out_format, options=format_options or None)
else:
res = self
return self._connection.create_job(
process_graph=cube.flat_graph(),
process_graph=res.flat_graph(),
title=title,
description=description,
plan=plan,
Expand Down

0 comments on commit 3302d87

Please sign in to comment.