diff --git a/cartopy_plots.py b/cartopy_plots.py index a1c0835..621b38e 100644 --- a/cartopy_plots.py +++ b/cartopy_plots.py @@ -446,7 +446,7 @@ def ShowArea(lon_mask, lat_mask, field_mask, coords=[-7,15,40,60], **kwargs): -def multiple_field_plot(lon, lat, f, significance=None, projections=ccrs.Orthographic(central_latitude=90), extents=None, cmaps='RdBu_r', figsize=(9,6), fig_num=None, one_fig_layout=False, +def multiple_field_plot(lon, lat, f, significance=None, projections=ccrs.Orthographic(central_latitude=90), extents=None, cmaps='RdBu_r', fig=None, figsize=(9,6), fig_num=None, one_fig_layout=False, colorbar='individual', mx=None, titles=None, apply_tight_layout=True, significance_hatches=('//', None), **kwargs): ''' Plots several fields @@ -582,9 +582,9 @@ def multiple_field_plot(lon, lat, f, significance=None, projections=ccrs.Orthogr if np.prod(one_fig_layout) < n_fields: raise ValueError(f'Cannot accomodate {n_fields} subplots in a {one_fig_layout[0]} by {one_fig_layout[1]} grid!') - - plt.close(fig_num) - fig = plt.figure(num=fig_num, figsize=figsize) + if fig is None: + plt.close(fig_num) + fig = plt.figure(num=fig_num, figsize=figsize) if not isinstance(mx, list): mx = [mx]*n_fields @@ -601,11 +601,21 @@ def multiple_field_plot(lon, lat, f, significance=None, projections=ccrs.Orthogr levels[i] = np.linspace(-_mx,_mx, levels[i]) if one_fig_layout: - if isinstance(one_fig_layout, int): + if isinstance(one_fig_layout, list): + assert len(one_fig_layout) == n_fields + ofl = one_fig_layout[i] + if isinstance(ofl, int): + m = fig.add_subplot(ofl, projection=projections[i]) + else: + assert len(ofl) == 3 + m = plt.subplot2grid(ofl[:2], ofl[-1], projection=projections[i]) + elif isinstance(one_fig_layout, int): m = fig.add_subplot(one_fig_layout + i + 1, projection=projections[i]) else: m = plt.subplot2grid(one_fig_layout, np.unravel_index(i, one_fig_layout), projection=projections[i]) else: + if fig is not None: + raise ValueError('Cannot provide fig if not using one_fig_layout') if fig_num is not None: plt.close(fig_num + i) fig = plt.figure(figsize=figsize, num=fig_num + i)