Skip to content

Commit 3a4befc

Browse files
authored
Merge pull request #548 from xylar/fix_isomip_plus_osf
Fix `isomip_plus` streamfunction computation
2 parents bba7983 + 0248a4a commit 3a4befc

File tree

1 file changed

+52
-48
lines changed

1 file changed

+52
-48
lines changed

compass/ocean/tests/isomip_plus/streamfunction.py

+52-48
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
import xarray
1+
import os
2+
23
import numpy
4+
import progressbar
35
import scipy.sparse
46
import scipy.sparse.linalg
5-
import progressbar
6-
import os
7+
import xarray
78
from mpas_tools.io import write_netcdf
89

9-
from compass.step import Step
1010
from compass.ocean.tests.isomip_plus.viz import file_complete
11+
from compass.step import Step
1112

1213

1314
class Streamfunction(Step):
@@ -117,14 +118,14 @@ def _compute_overturning_streamfunction(dsMesh, ds, out_dir, dx=2e3, dz=5.,
117118
if file_complete(ds, osfFileName):
118119
return
119120

120-
xMin = 320e3 + 0.5*dx
121-
xMax = 800e3 - 0.5*dx
122-
nx = int((xMax - xMin)/dx + 1)
121+
xMin = 320e3 + 0.5 * dx
122+
xMax = 800e3 - 0.5 * dx
123+
nx = int((xMax - xMin) / dx + 1)
123124
x = numpy.linspace(xMin, xMax, nx)
124125

125-
zMin = -720.0 + 0.5*dz
126-
zMax = 0.0 - 0.5*dz
127-
nz = int((zMax - zMin)/dz + 1)
126+
zMin = -720.0 + 0.5 * dz
127+
zMax = 0.0 - 0.5 * dz
128+
nz = int((zMax - zMin) / dz + 1)
128129
z = numpy.linspace(zMax, zMin, nz)
129130

130131
try:
@@ -173,8 +174,8 @@ def _compute_barotropic_transport(dsMesh, ds):
173174
normalVelocity = ds.timeMonthly_avg_normalVelocity[:, innerEdges, :].chunk(
174175
chunks={'Time': 1})
175176

176-
layerThicknessEdge = 0.5*(layerThickness[:, cell0, :] +
177-
layerThickness[:, cell1, :])
177+
layerThicknessEdge = 0.5 * (layerThickness[:, cell0, :] +
178+
layerThickness[:, cell1, :])
178179
transport = dsMesh.dvEdge[innerEdges] * \
179180
(layerThicknessEdge * normalVelocity).sum(dim='nVertLevels')
180181

@@ -200,28 +201,28 @@ def _compute_barotropic_streamfunction_vertex(dsMesh, ds, show_progress):
200201
nBoundaryVertices = len(boundaryVertices)
201202
nInnerEdges = len(innerEdges)
202203

203-
indices = numpy.zeros((2, 2*nInnerEdges+nBoundaryVertices), dtype=int)
204-
data = numpy.zeros(2*nInnerEdges+nBoundaryVertices, dtype=float)
204+
indices = numpy.zeros((2, 2 * nInnerEdges + nBoundaryVertices), dtype=int)
205+
data = numpy.zeros(2 * nInnerEdges + nBoundaryVertices, dtype=float)
205206

206207
# The difference between the streamfunction at vertices on an inner edge
207208
# should be equal to the transport
208209
v0 = verticesOnEdge[innerEdges, 0].values
209210
v1 = verticesOnEdge[innerEdges, 1].values
210211

211212
ind = numpy.arange(nInnerEdges)
212-
indices[0, 2*ind] = ind
213-
indices[1, 2*ind] = v1
214-
data[2*ind] = 1.
213+
indices[0, 2 * ind] = ind
214+
indices[1, 2 * ind] = v1
215+
data[2 * ind] = 1.
215216

216-
indices[0, 2*ind+1] = ind
217-
indices[1, 2*ind+1] = v0
218-
data[2*ind+1] = -1.
217+
indices[0, 2 * ind + 1] = ind
218+
indices[1, 2 * ind + 1] = v0
219+
data[2 * ind + 1] = -1.
219220

220221
# the streamfunction should be zero at all boundary vertices
221222
ind = numpy.arange(nBoundaryVertices)
222-
indices[0, 2*nInnerEdges + ind] = nInnerEdges + ind
223-
indices[1, 2*nInnerEdges + ind] = boundaryVertices
224-
data[2*nInnerEdges + ind] = 1.
223+
indices[0, 2 * nInnerEdges + ind] = nInnerEdges + ind
224+
indices[1, 2 * nInnerEdges + ind] = boundaryVertices
225+
data[2 * nInnerEdges + ind] = 1.
225226

226227
bsfVertex = xarray.DataArray(numpy.zeros((nTime, nVertices)),
227228
dims=('Time', 'nVertices'))
@@ -235,24 +236,24 @@ def _compute_barotropic_streamfunction_vertex(dsMesh, ds, show_progress):
235236
bar = None
236237

237238
for tIndex in range(nTime):
238-
rhs = numpy.zeros(nInnerEdges+nBoundaryVertices, dtype=float)
239+
rhs = numpy.zeros(nInnerEdges + nBoundaryVertices, dtype=float)
239240

240241
# convert to Sv
241242
ind = numpy.arange(nInnerEdges)
242-
rhs[ind] = 1e-6*transport.isel(Time=tIndex)
243+
rhs[ind] = 1e-6 * transport.isel(Time=tIndex)
243244

244245
ind = numpy.arange(nBoundaryVertices)
245246
rhs[nInnerEdges + ind] = 0.
246247

247248
M = scipy.sparse.csr_matrix((data, indices),
248-
shape=(nInnerEdges+nBoundaryVertices,
249+
shape=(nInnerEdges + nBoundaryVertices,
249250
nVertices))
250251

251252
solution = scipy.sparse.linalg.lsqr(M, rhs)
252253

253254
bsfVertex[tIndex, :] = -solution[0]
254255
if show_progress:
255-
bar.update(tIndex+1)
256+
bar.update(tIndex + 1)
256257
if show_progress:
257258
bar.finish()
258259

@@ -266,13 +267,13 @@ def _compute_barotropic_streamfunction_cell(dsMesh, bsfVertex):
266267
nEdgesOnCell = dsMesh.nEdgesOnCell
267268
edgesOnCell = dsMesh.edgesOnCell - 1
268269
verticesOnCell = dsMesh.verticesOnCell - 1
269-
areaEdge = dsMesh.dcEdge*dsMesh.dvEdge
270+
areaEdge = dsMesh.dcEdge * dsMesh.dvEdge
270271
prevEdgesOnCell = edgesOnCell.copy(deep=True)
271272
prevEdgesOnCell[:, 1:] = edgesOnCell[:, 0:-1]
272-
prevEdgesOnCell[:, 0] = edgesOnCell[:, nEdgesOnCell-1]
273+
prevEdgesOnCell[:, 0] = edgesOnCell[:, nEdgesOnCell - 1]
273274

274275
mask = verticesOnCell >= 0
275-
areaVert = mask*0.5*(areaEdge[edgesOnCell] + areaEdge[prevEdgesOnCell])
276+
areaVert = mask * 0.5 * (areaEdge[edgesOnCell] + areaEdge[prevEdgesOnCell])
276277

277278
bsfCell = ((areaVert * bsfVertex[:, verticesOnCell]).sum(dim='maxEdges') /
278279
areaVert.sum(dim='maxEdges'))
@@ -327,17 +328,18 @@ def _compute_horizontal_transport_mpas(ds, dsMesh, outFileName):
327328
minLevelEdgeBot = minLevelEdgeBot.chunk(chunks)
328329
maxLevelEdgeTop = maxLevelEdgeTop.chunk(chunks)
329330
dvEdge = dsMesh.dvEdge[internalEdgeIndices].chunk(chunks)
330-
bottomDepthEdge = 0.5*(bottomDepth[cell0] +
331-
bottomDepth[cell1]).chunk(chunks)
331+
bottomDepthEdge = 0.5 * (bottomDepth[cell0] +
332+
bottomDepth[cell1]).chunk(chunks)
332333

333334
chunks = {'Time': 1, 'nInternalEdges': 1024}
334335

335336
normalVelocity = ds.timeMonthly_avg_normalVelocity.isel(
336337
nEdges=internalEdgeIndices).chunk(chunks)
337338
layerThickness = ds.timeMonthly_avg_layerThickness.chunk()
338339

339-
layerThicknessEdge = 0.5*(layerThickness.isel(nCells=cell0) +
340-
layerThickness.isel(nCells=cell1)).chunk(chunks)
340+
layerThicknessEdge = (
341+
0.5 * (layerThickness.isel(nCells=cell0) +
342+
layerThickness.isel(nCells=cell1)).chunk(chunks))
341343

342344
mask = numpy.logical_and(vertIndex >= minLevelEdgeBot,
343345
vertIndex <= maxLevelEdgeTop)
@@ -356,7 +358,7 @@ def _compute_horizontal_transport_mpas(ds, dsMesh, outFileName):
356358
zInterfaceEdge.rename({'nVertLevels': 'nVertLevelsP1'})],
357359
dim='nVertLevelsP1')
358360

359-
transportPerDepth = dvEdge*normalVelocity
361+
transportPerDepth = dvEdge * normalVelocity
360362

361363
dsOut = xarray.Dataset()
362364
dsOut['xtime_startMonthly'] = ds.xtime_startMonthly
@@ -365,7 +367,7 @@ def _compute_horizontal_transport_mpas(ds, dsMesh, outFileName):
365367
dsOut['layerThicknessEdge'] = layerThicknessEdge
366368
dsOut['transportPerDepth'] = transportPerDepth
367369
dsOut['transportVertSum'] = \
368-
(transportPerDepth*layerThicknessEdge).sum(dim='nVertLevels')
370+
(transportPerDepth * layerThicknessEdge).sum(dim='nVertLevels')
369371

370372
dsOut = dsOut.transpose('Time', 'nInternalEdges', 'nVertLevels',
371373
'nVertLevelsP1')
@@ -401,7 +403,7 @@ def _interpolate_horizontal_transport_zlevel(ds, z, outFileName,
401403
nVertLevels = ds.sizes['nVertLevels']
402404

403405
if show_progress:
404-
widgets = ['interpolating tansport on z-level grid: ',
406+
widgets = ['interpolating transport on z-level grid: ',
405407
progressbar.Percentage(), ' ', progressbar.Bar(), ' ',
406408
progressbar.ETA()]
407409
bar = progressbar.ProgressBar(widgets=widgets, maxval=nTime).start()
@@ -417,25 +419,26 @@ def _interpolate_horizontal_transport_zlevel(ds, z, outFileName,
417419
continue
418420

419421
outTransport = xarray.DataArray(
420-
numpy.zeros((nInternalEdges, nz-1)),
422+
numpy.zeros((nInternalEdges, nz - 1)),
421423
dims=('nInternalEdges', 'nzM1'))
422424

423425
dzSum = xarray.DataArray(
424-
numpy.zeros((nInternalEdges, nz-1)),
426+
numpy.zeros((nInternalEdges, nz - 1)),
425427
dims=('nInternalEdges', 'nzM1'))
426428

427429
dsIn = ds.isel(Time=tIndex)
428430
for inZIndex in range(nVertLevels):
429431
zTop = dsIn.zInterfaceEdge.isel(nVertLevelsP1=inZIndex)
430-
zBot = dsIn.zInterfaceEdge.isel(nVertLevelsP1=inZIndex+1)
432+
zBot = dsIn.zInterfaceEdge.isel(nVertLevelsP1=inZIndex + 1)
431433
inTransportPerDepth = \
432434
dsIn.transportPerDepth.isel(nVertLevels=inZIndex)
435+
inTransportPerDepth = inTransportPerDepth.fillna(value=0.)
433436

434437
zt = numpy.minimum(zTop, z0)
435438
zb = numpy.maximum(zBot, z1)
436439
dz = numpy.maximum(zt - zb, 0.)
437440

438-
outTransport = outTransport + dz*inTransportPerDepth
441+
outTransport = outTransport + dz * inTransportPerDepth
439442

440443
dzSum = dzSum + dz
441444

@@ -453,7 +456,7 @@ def _interpolate_horizontal_transport_zlevel(ds, z, outFileName,
453456

454457
write_netcdf(dsOut, fileName)
455458

456-
assert(numpy.abs(dsOut.transportVertSumCheck).max().values < 1e-9)
459+
assert numpy.abs(dsOut.transportVertSumCheck).max().values < 1e-9
457460

458461
if show_progress:
459462
bar.update(tIndex + 1)
@@ -503,7 +506,7 @@ def _vertical_cumsum_horizontal_transport(ds, outFileName):
503506
# with either the output layer above or the one below
504507
mask = ds.mask.rename({'nzM1': 'nz'})
505508
maskTop = mask.isel(nz=0)
506-
maskBot = mask.isel(nz=nz-2)
509+
maskBot = mask.isel(nz=nz - 2)
507510

508511
outMask = xarray.concat([maskTop,
509512
numpy.logical_or(mask[:, 0:-1, :],
@@ -602,12 +605,13 @@ def _horizontally_bin_overturning_streamfunction(ds, dsMesh, x, osfFileName,
602605

603606
if len(edgeIndices) == 0:
604607

605-
localOSF = numpy.nan*xarray.DataArray(numpy.ones((nTime, nz)),
606-
dims=('Time', 'nz'))
608+
localOSF = numpy.nan * xarray.DataArray(numpy.ones((nTime, nz)),
609+
dims=('Time', 'nz'))
607610
else:
608611
# convert to Sv
609-
transportSum = 1e-6 * \
610-
edgeSigns*ds.transportSum.isel(nInternalEdges=edgeIndices)
612+
transportSum = (
613+
1e-6 * edgeSigns *
614+
ds.transportSum.isel(nInternalEdges=edgeIndices))
611615

612616
localOSF = transportSum.sum(dim='nInternalEdges')
613617

@@ -620,7 +624,7 @@ def _horizontally_bin_overturning_streamfunction(ds, dsMesh, x, osfFileName,
620624
write_netcdf(dsOSF, fileName)
621625

622626
if showProgress:
623-
bar.update(xIndex+1)
627+
bar.update(xIndex + 1)
624628

625629
if showProgress:
626630
bar.finish()

0 commit comments

Comments
 (0)