Skip to content

Commit

Permalink
Merge pull request #363 from dcherian/grib-fixes
Browse files Browse the repository at this point in the history
Improvements to scan_grib
  • Loading branch information
martindurant authored Sep 30, 2023
2 parents 51987d9 + 34610aa commit 3fb9147
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 25 deletions.
9 changes: 4 additions & 5 deletions kerchunk/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,13 @@ def decode(self, buf, out=None):
mid = eccodes.codes_new_from_message(bytes(buf))
try:
data = eccodes.codes_get_array(mid, var)
if var == "values" and eccodes.codes_get_string(mid, "missingValue"):
data[
data == float(eccodes.codes_get_string(mid, "missingValue"))
] = np.nan
missingValue = eccodes.codes_get_string(mid, "missingValue")
if var == "values" and missingValue:
data[data == float(missingValue)] = np.nan
if out is not None:
return numcodecs.compat.ndarray_copy(data, out)
else:
return data.astype(dt)
return data.astype(dt, copy=False)

finally:
eccodes.codes_release(mid)
Expand Down
88 changes: 77 additions & 11 deletions kerchunk/grib2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
from kerchunk.utils import class_factory, _encode_for_JSON
from kerchunk.codecs import GRIBCodec


# cfgrib copies over certain GRIB attributes
# but renames them to CF-compliant values
ATTRS_TO_COPY_OVER = {
"long_name": "GRIB_name",
"units": "GRIB_units",
"standard_name": "GRIB_cfName",
}

logger = logging.getLogger("grib2-to-zarr")


Expand Down Expand Up @@ -127,6 +136,11 @@ def scan_grib(
storage_options = storage_options or {}
logger.debug(f"Open {url}")

# This is hardcoded a lot in cfgrib!
# valid_time is added if "time" and "step" are present in time_dims
# These are present by default
# TIME_DIMS = ["step", "time", "valid_time"]

out = []
with fsspec.open(url, "rb", **storage_options) as f:
logger.debug(f"File {url}")
Expand All @@ -135,6 +149,28 @@ def scan_grib(
mid = eccodes.codes_new_from_message(data)
m = cfgrib.cfmessage.CfMessage(mid)

# It would be nice to just have a list of valid keys
# There does not seem to be a nice API for this
# 1. message_grib_keys returns keys coded in the message
# 2. There exist "computed" keys, that are functions applied on the data
# 3. There are also aliases!
# e.g. "number" is an alias of "perturbationNumber", and cfgrib uses this alias
# So we stick to checking membership in 'm', which ends up doing
# a lot of reads.
message_keys = set(m.message_grib_keys())
# The choices here copy cfgrib :(
# message_keys.update(cfgrib.dataset.INDEX_KEYS)
# message_keys.update(TIME_DIMS)
# print("totalNumber" in cfgrib.dataset.INDEX_KEYS)
# Adding computed keys adds a lot that isn't added by cfgrib
# message_keys.extend(m.computed_keys)

shape = (m["Ny"], m["Nx"])
# thank you, gribscan
native_type = eccodes.codes_get_native_type(m.codes_id, "values")
data_size = eccodes.codes_get_size(m.codes_id, "values")
coordinates = []

good = True
for k, v in (filter or {}).items():
if k not in m:
Expand All @@ -149,33 +185,53 @@ def scan_grib(

z = zarr.open_group(store)
global_attrs = {
k: m[k] for k in cfgrib.dataset.GLOBAL_ATTRIBUTES_KEYS if k in m
f"GRIB_{k}": m[k]
for k in cfgrib.dataset.GLOBAL_ATTRIBUTES_KEYS
if k in m
}
if "GRIB_centreDescription" in global_attrs:
# follow CF compliant renaming from cfgrib
global_attrs["institution"] = global_attrs["GRIB_centreDescription"]
z.attrs.update(global_attrs)

vals = m["values"].reshape((m["Ny"], m["Nx"]))
if data_size < inline_threshold:
# read the data
vals = m["values"].reshape(shape)
else:
# dummy array to match the required interface
vals = np.empty(shape, dtype=native_type)
assert vals.size == data_size

attrs = {
k: m[k]
# Follow cfgrib convention and rename key
f"GRIB_{k}": m[k]
for k in cfgrib.dataset.DATA_ATTRIBUTES_KEYS
+ cfgrib.dataset.DATA_TIME_KEYS
+ cfgrib.dataset.EXTRA_DATA_ATTRIBUTES_KEYS
+ cfgrib.dataset.GRID_TYPE_MAP.get(m["gridType"], [])
if k in m
}
for k, v in ATTRS_TO_COPY_OVER.items():
if v in attrs:
attrs[k] = attrs[v]

# try to use cfVarName if available,
# otherwise use the grib shortName
varName = m["cfVarName"]
if varName in ("undef", "unknown"):
varName = m["shortName"]
_store_array(store, z, vals, varName, inline_threshold, offset, size, attrs)
if "typeOfLevel" in m and "level" in m:
if "typeOfLevel" in message_keys and "level" in message_keys:
name = m["typeOfLevel"]
data = np.array([m["level"]])
coordinates.append(name)
# convert to numpy scalar, so that .tobytes can be used for inlining
# dtype=float is hardcoded in cfgrib
data = np.array(m["level"], dtype=float)[()]
try:
attrs = cfgrib.dataset.COORD_ATTRS[name]
except KeyError:
logger.debug(f"Couldn't find coord {name} in dataset")
attrs = {}
attrs["_ARRAY_DIMENSIONS"] = [name]
attrs["_ARRAY_DIMENSIONS"] = []
_store_array(
store, z, data, name, inline_threshold, offset, size, attrs
)
Expand All @@ -190,12 +246,12 @@ def scan_grib(
coord2 = {"latitude": "latitudes", "longitude": "longitudes"}.get(
coord, coord
)
if coord2 in m:
x = m[coord2]
else:
x = m.get(coord2)
if x is None:
continue
coordinates.append(coord)
inline_extra = 0
if isinstance(x, np.ndarray) and x.size == vals.size:
if isinstance(x, np.ndarray) and x.size == data_size:
if (
m["gridType"]
in cfgrib.dataset.GRID_TYPES_2D_NON_DIMENSION_COORDS
Expand All @@ -208,7 +264,15 @@ def scan_grib(
x = x.reshape(vals.shape)[:, 0].copy()
elif coord == "longitude":
x = x.reshape(vals.shape)[0].copy()
# force inlining of x/y/latitude/longitude coordinates.
# since these are derived from analytic formulae
# and are not stored in the message
inline_extra = x.nbytes + 1
elif np.isscalar(x):
# convert python scalars to numpy scalar
# so that .tobytes can be used for inlining
x = np.array(x)[()]
dims = []
else:
x = np.array([x])
dims = [coord]
Expand All @@ -224,6 +288,8 @@ def scan_grib(
attrs,
)
z[coord].attrs["_ARRAY_DIMENSIONS"] = dims
if coordinates:
z.attrs["coordinates"] = " ".join(coordinates)

out.append(
{
Expand Down
19 changes: 10 additions & 9 deletions kerchunk/tests/test_grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_one():
backend_kwargs={"consolidated": False, "storage_options": {"fo": out[0]}},
)

assert ds.attrs["centre"] == "cwao"
assert ds.attrs["GRIB_centre"] == "cwao"
ds2 = xr.open_dataset(fn, engine="cfgrib", backend_kwargs={"indexpath": ""})

for var in ["latitude", "longitude", "unknown", "isobaricInhPa", "time"]:
Expand All @@ -43,12 +43,13 @@ def _fetch_first(url):
[
"s3://noaa-hrrr-bdp-pds/hrrr.20140730/conus/hrrr.t23z.wrfsubhf08.grib2",
"s3://noaa-gefs-pds/gefs.20221011/00/atmos/pgrb2ap5/gep01.t00z.pgrb2a.0p50.f570",
"s3://noaa-gefs-retrospective/GEFSv12/reforecast/2000/2000010100/c00/Days:10-16/acpcp_sfc_2000010100_c00.grib2",
],
)
def test_archives(tmpdir, url):
grib = GribToZarr(url, storage_options={"anon": True}, skip=1)
out = grib.translate()[0]
ds = xr.open_dataset(
ours = xr.open_dataset(
"reference://",
engine="zarr",
backend_kwargs={
Expand All @@ -66,11 +67,11 @@ def test_archives(tmpdir, url):
with open(fn, "wb") as f:
f.write(data)

ds2 = cfgrib.open_dataset(fn)
dims = list(ds.dims)
theirs = cfgrib.open_dataset(fn)
if "hrrr" in url:
assert (ds.refc == ds2.refc).all()
assert dims.index("y") < dims.index("x")
else:
assert np.allclose(ds.gh, ds2.gh)
assert dims[0] == "latitude"
# for some reason, cfgrib reads `step` as 7.25 hours
# while grib_ls and kerchunk reads `step` as 425 hours.
ours = ours.drop_vars("step")
theirs = theirs.drop_vars("step")

xr.testing.assert_allclose(ours, theirs)

0 comments on commit 3fb9147

Please sign in to comment.