Skip to content

Commit f01d698

Browse files
authored
Merge pull request #1609 from jhamman/fix/1215
fix to_netcdf append bug (GH1215)
2 parents 3061db6 + 09101d6 commit f01d698

File tree

5 files changed

+94
-87
lines changed

5 files changed

+94
-87
lines changed

doc/io.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ for dealing with datasets too big to fit into memory. Instead, xarray integrates
176176
with dask.array (see :ref:`dask`), which provides a fully featured engine for
177177
streaming computation.
178178

179+
It is possible to append or overwrite netCDF variables using the ``mode='a'``
180+
argument. When using this option, all variables in the dataset will be written
181+
to the original netCDF file, regardless of whether they exist in the original dataset.
182+
179183
.. _io.encoding:
180184

181185
Reading encoded data
@@ -390,7 +394,7 @@ over the network until we look at particular values:
390394

391395
Some servers require authentication before we can access the data. For this
392396
purpose we can explicitly create a :py:class:`~xarray.backends.PydapDataStore`
393-
and pass in a `Requests`__ session object. For example for
397+
and pass in a `Requests`__ session object. For example for
394398
HTTP Basic authentication::
395399

396400
import xarray as xr
@@ -403,7 +407,7 @@ HTTP Basic authentication::
403407
session=session)
404408
ds = xr.open_dataset(store)
405409

406-
`Pydap's cas module`__ has functions that generate custom sessions for
410+
`Pydap's cas module`__ has functions that generate custom sessions for
407411
servers that use CAS single sign-on. For example, to connect to servers
408412
that require NASA's URS authentication::
409413

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ Bug fixes
295295
the first argument was a numpy variable (:issue:`1588`).
296296
By `Guido Imperiale <https://github.com/crusaderky>`_.
297297

298+
- Fix bug in :py:meth:`~xarray.Dataset.to_netcdf` when writing in append mode
299+
(:issue:`1215`).
300+
By `Joe Hamman <https://github.com/jhamman>`_.
301+
298302
- Fix ``netCDF4`` backend to properly roundtrip the ``shuffle`` encoding option
299303
(:issue:`1606`).
300304
By `Joe Hamman <https://github.com/jhamman>`_.

xarray/backends/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,12 @@ def set_variables(self, variables, check_encoding_set,
223223
for vn, v in iteritems(variables):
224224
name = _encode_variable_name(vn)
225225
check = vn in check_encoding_set
226-
target, source = self.prepare_variable(
227-
name, v, check, unlimited_dims=unlimited_dims)
226+
if vn not in self.variables:
227+
target, source = self.prepare_variable(
228+
name, v, check, unlimited_dims=unlimited_dims)
229+
else:
230+
target, source = self.ds.variables[name], v.data
231+
228232
self.writer.add(source, target)
229233

230234
def set_necessary_dimensions(self, variable, unlimited_dims=None):

xarray/core/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,8 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
974974
default format becomes NETCDF3_64BIT).
975975
mode : {'w', 'a'}, optional
976976
Write ('w') or append ('a') mode. If mode='w', any existing file at
977-
this location will be overwritten.
977+
this location will be overwritten. If mode='a', existing variables
978+
will be overwritten.
978979
format : {'NETCDF4', 'NETCDF4_CLASSIC', 'NETCDF3_64BIT', 'NETCDF3_CLASSIC'}, optional
979980
File format for the resulting netCDF file:
980981

xarray/tests/test_backends.py

Lines changed: 76 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,42 @@ class Only32BitTypes(object):
128128

129129
class DatasetIOTestCases(object):
130130
autoclose = False
131+
engine = None
132+
file_format = None
131133

132134
def create_store(self):
133135
raise NotImplementedError
134136

135-
def roundtrip(self, data, **kwargs):
136-
raise NotImplementedError
137+
@contextlib.contextmanager
138+
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
139+
allow_cleanup_failure=False):
140+
with create_tmp_file(
141+
allow_cleanup_failure=allow_cleanup_failure) as path:
142+
self.save(data, path, **save_kwargs)
143+
with self.open(path, **open_kwargs) as ds:
144+
yield ds
145+
146+
@contextlib.contextmanager
147+
def roundtrip_append(self, data, save_kwargs={}, open_kwargs={},
148+
allow_cleanup_failure=False):
149+
with create_tmp_file(
150+
allow_cleanup_failure=allow_cleanup_failure) as path:
151+
for i, key in enumerate(data.variables):
152+
mode = 'a' if i > 0 else 'w'
153+
self.save(data[[key]], path, mode=mode, **save_kwargs)
154+
with self.open(path, **open_kwargs) as ds:
155+
yield ds
156+
157+
# The save/open methods may be overwritten below
158+
def save(self, dataset, path, **kwargs):
159+
dataset.to_netcdf(path, engine=self.engine, format=self.file_format,
160+
**kwargs)
161+
162+
@contextlib.contextmanager
163+
def open(self, path, **kwargs):
164+
with open_dataset(path, engine=self.engine, autoclose=self.autoclose,
165+
**kwargs) as ds:
166+
yield ds
137167

138168
def test_zero_dimensional_variable(self):
139169
expected = create_test_data()
@@ -563,6 +593,23 @@ def test_encoding_same_dtype(self):
563593
self.assertEqual(actual.x.encoding['dtype'], 'f4')
564594
self.assertEqual(ds.x.encoding, {})
565595

596+
def test_append_write(self):
597+
# regression for GH1215
598+
data = create_test_data()
599+
with self.roundtrip_append(data) as actual:
600+
assert_allclose(data, actual)
601+
602+
def test_append_overwrite_values(self):
603+
# regression for GH1215
604+
data = create_test_data()
605+
with create_tmp_file(allow_cleanup_failure=False) as tmp_file:
606+
self.save(data, tmp_file, mode='w')
607+
data['var2'][:] = -999
608+
data['var9'] = data['var2'] * 3
609+
self.save(data[['var2', 'var9']], tmp_file, mode='a')
610+
with self.open(tmp_file) as actual:
611+
assert_allclose(data, actual)
612+
566613

567614
_counter = itertools.count()
568615

@@ -592,6 +639,9 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False):
592639

593640
@requires_netCDF4
594641
class BaseNetCDF4Test(CFEncodedDataTest):
642+
643+
engine = 'netcdf4'
644+
595645
def test_open_group(self):
596646
# Create a netCDF file with a dataset stored within a group
597647
with create_tmp_file() as tmp_file:
@@ -813,16 +863,6 @@ def create_store(self):
813863
with backends.NetCDF4DataStore.open(tmp_file, mode='w') as store:
814864
yield store
815865

816-
@contextlib.contextmanager
817-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
818-
allow_cleanup_failure=False):
819-
with create_tmp_file(
820-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
821-
data.to_netcdf(tmp_file, **save_kwargs)
822-
with open_dataset(tmp_file,
823-
autoclose=self.autoclose, **open_kwargs) as ds:
824-
yield ds
825-
826866
def test_variable_order(self):
827867
# doesn't work with scipy or h5py :(
828868
ds = Dataset()
@@ -883,19 +923,13 @@ class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest):
883923

884924
@requires_scipy
885925
class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
926+
engine = 'scipy'
927+
886928
@contextlib.contextmanager
887929
def create_store(self):
888930
fobj = BytesIO()
889931
yield backends.ScipyDataStore(fobj, 'w')
890932

891-
@contextlib.contextmanager
892-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
893-
allow_cleanup_failure=False):
894-
serialized = data.to_netcdf(**save_kwargs)
895-
with open_dataset(serialized, engine='scipy',
896-
autoclose=self.autoclose, **open_kwargs) as ds:
897-
yield ds
898-
899933
def test_to_netcdf_explicit_engine(self):
900934
# regression test for GH1321
901935
Dataset({'foo': 42}).to_netcdf(engine='scipy')
@@ -915,6 +949,8 @@ class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest):
915949

916950
@requires_scipy
917951
class ScipyFileObjectTest(CFEncodedDataTest, Only32BitTypes, TestCase):
952+
engine = 'scipy'
953+
918954
@contextlib.contextmanager
919955
def create_store(self):
920956
fobj = BytesIO()
@@ -925,9 +961,9 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={},
925961
allow_cleanup_failure=False):
926962
with create_tmp_file() as tmp_file:
927963
with open(tmp_file, 'wb') as f:
928-
data.to_netcdf(f, **save_kwargs)
964+
self.save(data, f, **save_kwargs)
929965
with open(tmp_file, 'rb') as f:
930-
with open_dataset(f, engine='scipy', **open_kwargs) as ds:
966+
with self.open(f, **open_kwargs) as ds:
931967
yield ds
932968

933969
@pytest.mark.skip(reason='cannot pickle file objects')
@@ -941,22 +977,14 @@ def test_pickle_dataarray(self):
941977

942978
@requires_scipy
943979
class ScipyFilePathTest(CFEncodedDataTest, Only32BitTypes, TestCase):
980+
engine = 'scipy'
981+
944982
@contextlib.contextmanager
945983
def create_store(self):
946984
with create_tmp_file() as tmp_file:
947985
with backends.ScipyDataStore(tmp_file, mode='w') as store:
948986
yield store
949987

950-
@contextlib.contextmanager
951-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
952-
allow_cleanup_failure=False):
953-
with create_tmp_file(
954-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
955-
data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
956-
with open_dataset(tmp_file, engine='scipy',
957-
autoclose=self.autoclose, **open_kwargs) as ds:
958-
yield ds
959-
960988
def test_array_attrs(self):
961989
ds = Dataset(attrs={'foo': [[1, 2], [3, 4]]})
962990
with self.assertRaisesRegexp(ValueError, 'must be 1-dimensional'):
@@ -995,24 +1023,16 @@ class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest):
9951023

9961024
@requires_netCDF4
9971025
class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
1026+
engine = 'netcdf4'
1027+
file_format = 'NETCDF3_CLASSIC'
1028+
9981029
@contextlib.contextmanager
9991030
def create_store(self):
10001031
with create_tmp_file() as tmp_file:
10011032
with backends.NetCDF4DataStore.open(
10021033
tmp_file, mode='w', format='NETCDF3_CLASSIC') as store:
10031034
yield store
10041035

1005-
@contextlib.contextmanager
1006-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
1007-
allow_cleanup_failure=False):
1008-
with create_tmp_file(
1009-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
1010-
data.to_netcdf(tmp_file, format='NETCDF3_CLASSIC',
1011-
engine='netcdf4', **save_kwargs)
1012-
with open_dataset(tmp_file, engine='netcdf4',
1013-
autoclose=self.autoclose, **open_kwargs) as ds:
1014-
yield ds
1015-
10161036

10171037
class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
10181038
autoclose = True
@@ -1021,24 +1041,16 @@ class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
10211041
@requires_netCDF4
10221042
class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes,
10231043
TestCase):
1044+
engine = 'netcdf4'
1045+
file_format = 'NETCDF4_CLASSIC'
1046+
10241047
@contextlib.contextmanager
10251048
def create_store(self):
10261049
with create_tmp_file() as tmp_file:
10271050
with backends.NetCDF4DataStore.open(
10281051
tmp_file, mode='w', format='NETCDF4_CLASSIC') as store:
10291052
yield store
10301053

1031-
@contextlib.contextmanager
1032-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
1033-
allow_cleanup_failure=False):
1034-
with create_tmp_file(
1035-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
1036-
data.to_netcdf(tmp_file, format='NETCDF4_CLASSIC',
1037-
engine='netcdf4', **save_kwargs)
1038-
with open_dataset(tmp_file, engine='netcdf4',
1039-
autoclose=self.autoclose, **open_kwargs) as ds:
1040-
yield ds
1041-
10421054

10431055
class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
10441056
NetCDF4ClassicViaNetCDF4DataTest):
@@ -1049,21 +1061,12 @@ class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
10491061
class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
10501062
# verify that we can read and write netCDF3 files as long as we have scipy
10511063
# or netCDF4-python installed
1064+
file_format = 'netcdf3_64bit'
10521065

10531066
def test_write_store(self):
10541067
# there's no specific store to test here
10551068
pass
10561069

1057-
@contextlib.contextmanager
1058-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
1059-
allow_cleanup_failure=False):
1060-
with create_tmp_file(
1061-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
1062-
data.to_netcdf(tmp_file, format='netcdf3_64bit', **save_kwargs)
1063-
with open_dataset(tmp_file,
1064-
autoclose=self.autoclose, **open_kwargs) as ds:
1065-
yield ds
1066-
10671070
def test_engine(self):
10681071
data = create_test_data()
10691072
with self.assertRaisesRegexp(ValueError, 'unrecognized engine'):
@@ -1122,21 +1125,13 @@ class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest):
11221125
@requires_h5netcdf
11231126
@requires_netCDF4
11241127
class H5NetCDFDataTest(BaseNetCDF4Test, TestCase):
1128+
engine = 'h5netcdf'
1129+
11251130
@contextlib.contextmanager
11261131
def create_store(self):
11271132
with create_tmp_file() as tmp_file:
11281133
yield backends.H5NetCDFStore(tmp_file, 'w')
11291134

1130-
@contextlib.contextmanager
1131-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
1132-
allow_cleanup_failure=False):
1133-
with create_tmp_file(
1134-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
1135-
data.to_netcdf(tmp_file, engine='h5netcdf', **save_kwargs)
1136-
with open_dataset(tmp_file, engine='h5netcdf',
1137-
autoclose=self.autoclose, **open_kwargs) as ds:
1138-
yield ds
1139-
11401135
def test_orthogonal_indexing(self):
11411136
# doesn't work for h5py (without using dask as an intermediate layer)
11421137
pass
@@ -1646,14 +1641,13 @@ def test_orthogonal_indexing(self):
16461641
pass
16471642

16481643
@contextlib.contextmanager
1649-
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
1650-
allow_cleanup_failure=False):
1651-
with create_tmp_file(
1652-
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
1653-
data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
1654-
with open_dataset(tmp_file, engine='pynio',
1655-
autoclose=self.autoclose, **open_kwargs) as ds:
1656-
yield ds
1644+
def open(self, path, **kwargs):
1645+
with open_dataset(path, engine='pynio', autoclose=self.autoclose,
1646+
**kwargs) as ds:
1647+
yield ds
1648+
1649+
def save(self, dataset, path, **kwargs):
1650+
dataset.to_netcdf(path, engine='scipy', **kwargs)
16571651

16581652
def test_weakrefs(self):
16591653
example = Dataset({'foo': ('x', np.arange(5.0))})

0 commit comments

Comments
 (0)