Skip to content

Commit db94bb9

Browse files
committed
preserve nanosecond resolution when encoding/decoding times.
1 parent 1043a9e commit db94bb9

File tree

3 files changed

+79
-21
lines changed

3 files changed

+79
-21
lines changed

xarray/backends/netcdf3.py

+7
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ def coerce_nc3_dtype(arr):
6262
dtype = str(arr.dtype)
6363
if dtype in _nc3_dtype_coercions:
6464
new_dtype = _nc3_dtype_coercions[dtype]
65+
# check if this looks like a time with NaT
66+
# and transform to float64
67+
if np.issubdtype(dtype, np.int64):
68+
mask = arr == np.iinfo(np.int64).min
69+
if mask.any():
70+
arr = np.where(mask, np.nan, arr)
71+
return arr
6572
# TODO: raise a warning whenever casting the data-type instead?
6673
cast_arr = arr.astype(new_dtype)
6774
if not (cast_arr == arr).all():

xarray/coding/times.py

+52-18
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ def _unpack_netcdf_time_units(units: str) -> tuple[str, str]:
171171
return delta_units, ref_date
172172

173173

174+
def _unpack_delta_ref_date(units):
175+
# same as _unpack_netcdf_time_units but finalizes ref_date for
176+
# processing in encode_cf_datetime
177+
delta, _ref_date = _unpack_netcdf_time_units(units)
178+
# TODO: the strict enforcement of nanosecond precision Timestamps can be
179+
# relaxed when addressing GitHub issue #7493.
180+
ref_date = nanosecond_precision_timestamp(_ref_date)
181+
# If the ref_date Timestamp is timezone-aware, convert to UTC and
182+
# make it timezone-naive (GH 2649).
183+
if ref_date.tz is not None:
184+
ref_date = ref_date.tz_convert(None)
185+
return delta, ref_date
186+
187+
174188
def _decode_cf_datetime_dtype(
175189
data, units: str, calendar: str, use_cftime: bool | None
176190
) -> np.dtype:
@@ -251,9 +265,12 @@ def _decode_datetime_with_pandas(
251265

252266
# Cast input ordinals to integers of nanoseconds because pd.to_timedelta
253267
# works much faster when dealing with integers (GH 1399).
254-
flat_num_dates_ns_int = (flat_num_dates * _NS_PER_TIME_DELTA[delta]).astype(
255-
np.int64
256-
)
268+
# properly handle NaN/NaT to prevent casting NaN to int
269+
nan = np.isnan(flat_num_dates) | (flat_num_dates == np.iinfo(np.int64).min)
270+
flat_num_dates = flat_num_dates * _NS_PER_TIME_DELTA[delta]
271+
flat_num_dates_ns_int = np.zeros_like(flat_num_dates, dtype=np.int64)
272+
flat_num_dates_ns_int[nan] = np.iinfo(np.int64).min
273+
flat_num_dates_ns_int[~nan] = flat_num_dates[~nan].astype(np.int64)
257274

258275
# Use pd.to_timedelta to safely cast integer values to timedeltas,
259276
# and add those to a Timestamp to safely produce a DatetimeIndex. This
@@ -575,6 +592,9 @@ def _should_cftime_be_used(
575592

576593
def _cleanup_netcdf_time_units(units: str) -> str:
577594
delta, ref_date = _unpack_netcdf_time_units(units)
595+
delta = delta.lower()
596+
if not delta.endswith("s"):
597+
delta = f"{delta}s"
578598
try:
579599
units = f"{delta} since {format_timestamp(ref_date)}"
580600
except (OutOfBoundsDatetime, ValueError):
@@ -635,32 +655,41 @@ def encode_cf_datetime(
635655
"""
636656
dates = np.asarray(dates)
637657

658+
data_units = infer_datetime_units(dates)
659+
638660
if units is None:
639-
units = infer_datetime_units(dates)
661+
units = data_units
640662
else:
641663
units = _cleanup_netcdf_time_units(units)
642664

643665
if calendar is None:
644666
calendar = infer_calendar_name(dates)
645667

646-
delta, _ref_date = _unpack_netcdf_time_units(units)
647668
try:
648669
if not _is_standard_calendar(calendar) or dates.dtype.kind == "O":
649670
# parse with cftime instead
650671
raise OutOfBoundsDatetime
651672
assert dates.dtype == "datetime64[ns]"
652673

674+
delta, ref_date = _unpack_delta_ref_date(units)
653675
delta_units = _netcdf_to_numpy_timeunit(delta)
654676
time_delta = np.timedelta64(1, delta_units).astype("timedelta64[ns]")
655677

656-
# TODO: the strict enforcement of nanosecond precision Timestamps can be
657-
# relaxed when addressing GitHub issue #7493.
658-
ref_date = nanosecond_precision_timestamp(_ref_date)
659-
660-
# If the ref_date Timestamp is timezone-aware, convert to UTC and
661-
# make it timezone-naive (GH 2649).
662-
if ref_date.tz is not None:
663-
ref_date = ref_date.tz_convert(None)
678+
# check if times can be represented with given units
679+
if data_units != units:
680+
data_delta, data_ref_date = _unpack_delta_ref_date(data_units)
681+
needed_delta = _infer_time_units_from_diff(
682+
(data_ref_date - ref_date).to_timedelta64()
683+
)
684+
needed_time_delta = np.timedelta64(
685+
1, _netcdf_to_numpy_timeunit(needed_delta)
686+
).astype("timedelta64[ns]")
687+
if needed_delta != delta and time_delta > needed_time_delta:
688+
warnings.warn(
689+
f"Times can't be serialized faithfully with requested units {units!r}. "
690+
f"Resolution of {needed_delta!r} needed. "
691+
f"Serializing timeseries to floating point."
692+
)
664693

665694
# Wrap the dates in a DatetimeIndex to do the subtraction to ensure
666695
# an OverflowError is raised if the ref_date is too far away from
@@ -670,8 +699,12 @@ def encode_cf_datetime(
670699

671700
# Use floor division if time_delta evenly divides all differences
672701
# to preserve integer dtype if possible (GH 4045).
673-
if np.all(time_deltas % time_delta == np.timedelta64(0, "ns")):
674-
num = time_deltas // time_delta
702+
# NaT prevents us from using datetime64 directly, but we can safely coerce
703+
# to int64 in presence of NaT, so we just dropna before check (GH 7817).
704+
if np.all(time_deltas.dropna() % time_delta == np.timedelta64(0, "ns")):
705+
# calculate int64 floor division
706+
num = time_deltas // time_delta.astype(np.int64)
707+
num = num.astype(np.int64, copy=False)
675708
else:
676709
num = time_deltas / time_delta
677710
num = num.values.reshape(dates.shape)
@@ -704,9 +737,10 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
704737
) or contains_cftime_datetimes(variable):
705738
dims, data, attrs, encoding = unpack_for_encoding(variable)
706739

707-
(data, units, calendar) = encode_cf_datetime(
708-
data, encoding.pop("units", None), encoding.pop("calendar", None)
709-
)
740+
units = encoding.pop("units", None)
741+
calendar = encoding.pop("calendar", None)
742+
(data, units, calendar) = encode_cf_datetime(data, units, calendar)
743+
710744
safe_setitem(attrs, "units", units, name=name)
711745
safe_setitem(attrs, "calendar", calendar, name=name)
712746

xarray/coding/variables.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -236,19 +236,32 @@ def encode(self, variable: Variable, name: T_Name = None):
236236
f"Variable {name!r} has conflicting _FillValue ({fv}) and missing_value ({mv}). Cannot encode data."
237237
)
238238

239+
# special case DateTime to properly handle NaT
240+
is_date = "since" in attrs.get("units", "")
241+
239242
if fv_exists:
240243
# Ensure _FillValue is cast to same dtype as data's
241244
encoding["_FillValue"] = dtype.type(fv)
242245
fill_value = pop_to(encoding, attrs, "_FillValue", name=name)
243246
if not pd.isnull(fill_value):
244-
data = duck_array_ops.fillna(data, fill_value)
247+
if is_date:
248+
data = duck_array_ops.where(
249+
data != np.iinfo(np.int64).min, data, fill_value
250+
)
251+
else:
252+
data = duck_array_ops.fillna(data, fill_value)
245253

246254
if mv_exists:
247255
# Ensure missing_value is cast to same dtype as data's
248256
encoding["missing_value"] = dtype.type(mv)
249257
fill_value = pop_to(encoding, attrs, "missing_value", name=name)
250258
if not pd.isnull(fill_value) and not fv_exists:
251-
data = duck_array_ops.fillna(data, fill_value)
259+
if is_date:
260+
data = duck_array_ops.where(
261+
data != np.iinfo(np.int64).min, data, fill_value
262+
)
263+
else:
264+
data = duck_array_ops.fillna(data, fill_value)
252265

253266
return Variable(dims, data, attrs, encoding, fastpath=True)
254267

@@ -275,7 +288,11 @@ def decode(self, variable: Variable, name: T_Name = None):
275288
stacklevel=3,
276289
)
277290

278-
dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)
291+
# special case DateTime to properly handle NaT
292+
if "since" in str(attrs.get("units", "")):
293+
dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min
294+
else:
295+
dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)
279296

280297
if encoded_fill_values:
281298
transform = partial(

0 commit comments

Comments
 (0)