Skip to content

Commit 77efa85

Browse files
committed
Fixes #228 add ProcessBuilder array_element(data, n) through data[n] syntax
1 parent e0a58e8 commit 77efa85

File tree

4 files changed

+95
-2
lines changed

4 files changed

+95
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- Add `DataCube.dimension_labels()` (EP-4008)
1616
- Add `Connection.load_result()` (EP-4008)
1717
- Add proper support for child callbacks in `fit_curve` and `predict_curve` ([#229](https://github.com/Open-EO/openeo-python-client/issues/229))
18+
- `ProcessBuilder`: Add support for `array_element(data, n)` through `data[n]` syntax ([#228](https://github.com/Open-EO/openeo-python-client/issues/228))
1819

1920

2021
### Changed

openeo/internal/processes/generator.py

+7
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def collect_processes(sources: List[Union[Path, str]]) -> List[Process]:
107107

108108
def generate_process_py(processes: List[Process], output=sys.stdout, argv=None):
109109
oo_src = textwrap.dedent("""
110+
import builtins
110111
from openeo.internal.processes.builder import ProcessBuilderBase, UNSET
111112
112113
@@ -142,6 +143,12 @@ def __neg__(self) -> 'ProcessBuilder':
142143
def __pow__(self, other) -> 'ProcessBuilder':
143144
return self.power(other)
144145
146+
def __getitem__(self, key) -> 'ProcessBuilder':
147+
if isinstance(key, builtins.int):
148+
return self.array_element(index=key)
149+
else:
150+
return self.array_element(label=key)
151+
145152
""")
146153
fun_src = textwrap.dedent("""
147154
# Public shortcut

openeo/processes.py

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Used command line arguments:
55
# openeo/internal/processes/generator.py ../openeo-processes/ ../openeo-processes/proposals/ --output openeo/processes.py
66

7+
import builtins
78
from openeo.internal.processes.builder import ProcessBuilderBase, UNSET
89

910

@@ -39,6 +40,12 @@ def __neg__(self) -> 'ProcessBuilder':
3940
def __pow__(self, other) -> 'ProcessBuilder':
4041
return self.power(other)
4142

43+
def __getitem__(self, key) -> 'ProcessBuilder':
44+
if isinstance(key, builtins.int):
45+
return self.array_element(index=key)
46+
else:
47+
return self.array_element(label=key)
48+
4249
def absolute(self) -> 'ProcessBuilder':
4350
"""
4451
Absolute value

tests/rest/datacube/test_processbuilder.py

+80-2
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,10 @@ def test_apply_dimension_bandmath_lambda(con100):
275275

276276

277277
def test_apply_dimension_time_to_bands(con100):
278-
from openeo.processes import array_concat,quantiles,sd,mean
278+
from openeo.processes import array_concat, quantiles, sd, mean
279279
im = con100.load_collection("S2")
280280
res = im.apply_dimension(
281-
process=lambda d: array_concat(quantiles(d,[0.25,0.5,0.75]), [sd(d),mean(d)]),
281+
process=lambda d: array_concat(quantiles(d, [0.25, 0.5, 0.75]), [sd(d), mean(d)]),
282282
dimension="t",
283283
target_dimension="bands"
284284
)
@@ -382,3 +382,81 @@ def test_merge_cubes_max_lambda(con100):
382382
im2 = con100.load_collection("MASK")
383383
res = im1.merge_cubes(other=im2, overlap_resolver=lambda data: data.max())
384384
assert res.graph == load_json_resource('data/1.0.0/merge_cubes_max.json')
385+
386+
387+
def test_getitem_array_element_index(con100):
388+
im = con100.load_collection("S2")
389+
390+
def callback(data: ProcessBuilder):
391+
return data[1] + data[2]
392+
393+
res = im.reduce_dimension(reducer=callback, dimension="bands")
394+
395+
assert res.flat_graph() == {
396+
"loadcollection1": {
397+
"process_id": "load_collection",
398+
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
399+
},
400+
"reducedimension1": {
401+
"process_id": "reduce_dimension",
402+
"arguments": {
403+
"data": {"from_node": "loadcollection1"},
404+
"dimension": "bands",
405+
"reducer": {"process_graph": {
406+
"arrayelement1": {
407+
"process_id": "array_element",
408+
"arguments": {"data": {"from_parameter": "data"}, "index": 1},
409+
},
410+
"arrayelement2": {
411+
"process_id": "array_element",
412+
"arguments": {"data": {"from_parameter": "data"}, "index": 2},
413+
},
414+
"add1": {
415+
"process_id": "add",
416+
"arguments": {"x": {"from_node": "arrayelement1"}, "y": {"from_node": "arrayelement2"}},
417+
"result": True
418+
},
419+
}}
420+
},
421+
"result": True
422+
}
423+
}
424+
425+
426+
def test_getitem_array_element_label(con100):
427+
im = con100.load_collection("S2")
428+
429+
def callback(data: ProcessBuilder):
430+
return data["red"] + data["green"]
431+
432+
res = im.reduce_dimension(reducer=callback, dimension="bands")
433+
434+
assert res.flat_graph() == {
435+
"loadcollection1": {
436+
"process_id": "load_collection",
437+
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
438+
},
439+
"reducedimension1": {
440+
"process_id": "reduce_dimension",
441+
"arguments": {
442+
"data": {"from_node": "loadcollection1"},
443+
"dimension": "bands",
444+
"reducer": {"process_graph": {
445+
"arrayelement1": {
446+
"process_id": "array_element",
447+
"arguments": {"data": {"from_parameter": "data"}, "label": "red"},
448+
},
449+
"arrayelement2": {
450+
"process_id": "array_element",
451+
"arguments": {"data": {"from_parameter": "data"}, "label": "green"},
452+
},
453+
"add1": {
454+
"process_id": "add",
455+
"arguments": {"x": {"from_node": "arrayelement1"}, "y": {"from_node": "arrayelement2"}},
456+
"result": True
457+
},
458+
}}
459+
},
460+
"result": True
461+
}
462+
}

0 commit comments

Comments
 (0)