Use xarray-beam to append derived Replay variables (#24)
* this works for really small stuff

* some debugging stuff and storage options kwargs

* this worked for 1 degree dataset

* oops

* num_threads supported on ChunksToZarr, but has no effect

* set input chunks to 127, full column... finally quarter degree is possible

* submitted subsampled quarter degree append job

* note about missing fields

* geopotential verified

* used this to append 1degree and 1/4 degree static variables

* cleanup

* typo
timothyas authored Oct 16, 2024
1 parent 83f03a6 commit c5582bf
Showing 8 changed files with 689 additions and 10 deletions.
16 changes: 16 additions & 0 deletions examples/replay/README.md
@@ -30,6 +30,22 @@ python move_one_degree.py
Currently only the FV3 data is being moved, and it can be found
[in this zarr store](https://console.cloud.google.com/storage/browser/noaa-ufs-gefsv13replay/ufs-hr1/1.00-degree/03h-freq/zarr/fv3.zarr).


### Missing data

Some time stamps are missing (the 2D fields are all NaN at those time steps) for the following fields,
but only in the 1 degree dataset:
- hgtsfc
- sltyp
- weasd

Since hgtsfc and sltyp are static, this doesn't matter; we can just grab the
first time stamp, as in the sketch below. If 1 degree weasd is needed in the
future, then we'll need to move this data again.
Additionally, the field `hgtsfc_static` has been added, which rightfully does not have
the time dimension.
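
For reference, here is a minimal sketch of that workaround. The `gs://` path is inferred
from the bucket linked above, and the variable selection is only an illustration, not part
of the moving scripts:

```python
import xarray as xr

# store path inferred from the 1 degree bucket linked above (assumption)
store = "gs://noaa-ufs-gefsv13replay/ufs-hr1/1.00-degree/03h-freq/zarr/fv3.zarr"
ds = xr.open_zarr(store, storage_options={"token": "anon"})

# hgtsfc and sltyp are static, so the first time stamp is sufficient
# even though some later time stamps are all NaN
hgtsfc = ds["hgtsfc"].isel(time=0)
sltyp = ds["sltyp"].isel(time=0)

# hgtsfc_static carries the same surface height without a time dimension
hgtsfc_static = ds["hgtsfc_static"]
```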


## 1/4 Degree Data

[move_quarter_degree.py](move_quarter_degree.py)
93 changes: 93 additions & 0 deletions examples/replay/append_geopotential.py
@@ -0,0 +1,93 @@
"""Compute geopotential from Replay dataset,
append it back to the original or store it locally depending on inputs
Note: this was heavily borrowed from this xarray-beam example:
https://github.com/google/xarray-beam/blob/main/examples/era5_climatology.py
"""

from typing import Tuple

from absl import app
from absl import flags
import logging
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner
import numpy as np
import xarray as xr
import xarray_beam as xbeam

from ufs2arco import Layers2Pressure
from verify_geopotential import setup_log
from localzarr import ChunksToZarr


INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path')
OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path')
RUNNER = flags.DEFINE_string('runner', "DirectRunner", 'beam.runners.Runner')
TIME_LENGTH = flags.DEFINE_integer('time_length', None, help="Number of time slices to use for debugging")
NUM_WORKERS = flags.DEFINE_integer('num_workers', None, help="Number of workers for the runner")
NUM_THREADS = flags.DEFINE_integer('num_threads', None, help="Passed to ChunksToZarr")

def calc_geopotential(
    key: xbeam.Key,
    xds: xr.Dataset,
) -> Tuple[xbeam.Key, xr.Dataset]:
    """Return a dataset containing only the geopotential field."""

    lp = Layers2Pressure()
    xds = xds.rename({"pfull": "level"})
    prsl = lp.calc_layer_mean_pressure(xds["pressfc"], xds["tmp"], xds["spfh"], xds["delz"])

    newds = xr.Dataset()
    newds["geopotential"] = lp.calc_geopotential(xds["hgtsfc"], xds["delz"])
    newds = newds.rename({"level": "pfull"})
    return key, newds

def main(argv):

    setup_log()
    path = INPUT_PATH.value
    kwargs = {}

    if "gs://" in path or "gcs://" in path:
        kwargs["storage_options"] = {"token": "anon"}

    source_dataset, source_chunks = xbeam.open_zarr(path, **kwargs)
    source_dataset = source_dataset.drop_vars(["cftime", "ftime"])
    if TIME_LENGTH.value is not None:
        source_dataset = source_dataset.isel(time=slice(int(TIME_LENGTH.value)))

    # create template
    tds = source_dataset[["tmp"]].rename({"tmp": "geopotential"})
tds["geopotential"].attrs = {
"units": "m**2 / s**2",
"description": "Diagnosed using ufs2arco.Layers2Pressure.calc_geopotential",
"long_name": "geopotential height",
}
    # read the full vertical column (127 pfull levels) in each input chunk so the
    # geopotential calculation sees every layer at once; keep the source chunking on output
    input_chunks = {k: source_chunks[k] if k != "pfull" else 127 for k in tds["geopotential"].dims}
    output_chunks = {k: source_chunks[k] for k in tds["geopotential"].dims}

    template = xbeam.make_template(tds)
    storage_options = None
    if "gs://" in OUTPUT_PATH.value:
        storage_options = {"token": "/contrib/Tim.Smith/.gcs/replay-service-account.json"}

    pipeline_kwargs = {}
    if NUM_WORKERS.value is not None:
        pipeline_kwargs["options"] = PipelineOptions(
            direct_num_workers=NUM_WORKERS.value,
        )

    with beam.Pipeline(runner=RUNNER.value, argv=argv, **pipeline_kwargs) as root:
        (
            root
            | xbeam.DatasetToChunks(source_dataset, input_chunks, num_threads=NUM_THREADS.value)
            | beam.MapTuple(calc_geopotential)
            | ChunksToZarr(OUTPUT_PATH.value, template, output_chunks, num_threads=NUM_THREADS.value, storage_options=storage_options)
        )

    logging.info("Done")


if __name__ == "__main__":
    app.run(main)
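
For a rough spot check after the pipeline finishes, a sketch like the following could be used; the path is a placeholder, and this is only an assumption about a sanity check, not the verify_geopotential logic:

```python
import numpy as np
import xarray as xr

# placeholder path: use whatever was passed as --output_path
ds = xr.open_zarr("path/to/output.zarr")

# geopotential should share dims and chunking with the source "tmp" field,
# and a sampled time slice should contain finite values
sample = ds["geopotential"].isel(time=0).load()
assert np.isfinite(sample).any()
print(sample.dims, float(sample.min()), float(sample.max()))
```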
104 changes: 104 additions & 0 deletions examples/replay/append_static_vars.py
@@ -0,0 +1,104 @@
"""Compute static variables surface orography and land/sea mask,
append it back to the original or store it locally depending on inputs
Note: this was heavily borrowed from this xarray-beam example:
https://github.com/google/xarray-beam/blob/main/examples/era5_climatology.py
"""

from typing import Tuple

from absl import app
from absl import flags
import logging
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner
import numpy as np
import xarray as xr
import xarray_beam as xbeam

from ufs2arco import Layers2Pressure
from verify_geopotential import setup_log
from localzarr import ChunksToZarr


INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path')
OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path')
RUNNER = flags.DEFINE_string('runner', "DirectRunner", 'beam.runners.Runner')
NUM_WORKERS = flags.DEFINE_integer('num_workers', None, help="Number of workers for the runner")
NUM_THREADS = flags.DEFINE_integer('num_threads', None, help="Passed to ChunksToZarr")

def calc_static_vars(
    key: xbeam.Key,
    xds: xr.Dataset,
) -> Tuple[xbeam.Key, xr.Dataset]:
    """Return a dataset with the static surface orography and land/sea mask fields."""

    newds = xr.Dataset()
    hgtsfc = xds["hgtsfc"] if "time" not in xds["hgtsfc"] else xds["hgtsfc"].isel(time=0)
    land = xds["land"] if "time" not in xds["land"] else xds["land"].isel(time=0)

    newds["hgtsfc_static"] = hgtsfc

    newds["land_static"] = xr.where(
        land == 1,
        1,
        0,
    ).astype(np.int32)
    newds["hgtsfc_static"].attrs = xds["hgtsfc"].attrs.copy()
    newds["land_static"].attrs = {
        "long_name": "static land-sea/ice mask",
        "description": "1 = land, 0 = not land",
    }

    for k in ["time", "cftime", "ftime", "pfull"]:
        if k in newds:
            newds = newds.drop_vars(k)
    return key, newds

def main(argv):

    setup_log()
    path = INPUT_PATH.value
    kwargs = {}

    if "gs://" in path or "gcs://" in path:
        kwargs["storage_options"] = {"token": "anon"}

    source_dataset, source_chunks = xbeam.open_zarr(path, **kwargs)
    source_dataset = source_dataset[["hgtsfc", "land"]].isel(time=0)
    for key in ["time", "cftime", "ftime", "pfull"]:
        if key in source_dataset:
            source_dataset = source_dataset.drop_vars(key)
        if key in source_chunks:
            source_chunks.pop(key)

    # create template using the output structure of calc_static_vars
    _, tds = calc_static_vars(None, source_dataset)
    # the static fields have no time or pfull dimension, so drop those from the chunk spec
    output_chunks = {k: v for k, v in source_chunks.items() if k not in ("pfull", "time")}
    input_chunks = output_chunks.copy()

    template = xbeam.make_template(tds)
    storage_options = None
    if "gs://" in OUTPUT_PATH.value:
        storage_options = {"token": "/contrib/Tim.Smith/.gcs/replay-service-account.json"}

    pipeline_kwargs = {}
    if NUM_WORKERS.value is not None:
        pipeline_kwargs["options"] = PipelineOptions(
            direct_num_workers=NUM_WORKERS.value,
        )

    with beam.Pipeline(runner=RUNNER.value, argv=argv, **pipeline_kwargs) as root:
        (
            root
            | xbeam.DatasetToChunks(source_dataset, input_chunks, num_threads=NUM_THREADS.value)
            | beam.MapTuple(calc_static_vars)
            | ChunksToZarr(OUTPUT_PATH.value, template, output_chunks, num_threads=NUM_THREADS.value, storage_options=storage_options)
        )

    logging.info("Done")


if __name__ == "__main__":
    app.run(main)
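
Similarly, a small sketch (placeholder path, assumed checks) of what the appended static variables should look like:

```python
import xarray as xr

# placeholder path: use whatever was passed as --output_path
ds = xr.open_zarr("path/to/output.zarr")

# both appended fields are static, so neither carries a time dimension
assert "time" not in ds["hgtsfc_static"].dims
assert "time" not in ds["land_static"].dims

# land_static is an int32 mask: 1 = land, 0 = not land
print(ds["land_static"].dtype)
```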
5 changes: 5 additions & 0 deletions examples/replay/find_nans.sh
@@ -9,3 +9,8 @@ echo " --- 1.00 Degree ---"
awk '!/ 0 NaN/ && (/NaN/) {print FILENAME, $1, $2, $3, $4}' slurm/verify-1.00-degree/*.out
echo ""
echo ""

echo " --- 0.25 Degree Subsampled ---"
awk '!/ 0 NaN/ && (/NaN/) {print FILENAME, $1, $2, $3, $4}' slurm/verify-0.25-degree-subsampled/*.out
echo ""
echo ""
