diff --git a/uvtools/plot.py b/uvtools/plot.py index d5daac9b..828e785d 100644 --- a/uvtools/plot.py +++ b/uvtools/plot.py @@ -391,6 +391,7 @@ def labeled_waterfall( vmin=None, vmax=None, dynamic_range=None, + Nticks=6, fft_axis=None, freq_taper=None, freq_taper_kwargs=None, @@ -482,6 +483,11 @@ def labeled_waterfall( five orders of magnitude below the maximum. If ``mode=="phs"``, then this parameter is ignored. If both ``vmin`` and ``vmax`` are provided, then this parameter is ignored. + Nticks: int or iterable of int, optional + Number of tick marks to use on the plot axes. If a single number is passed, + then the same number of ticks are used on both axes. If an iterable is + passed, then it must be length 2 and specify the number of ticks to use on + the x- and y-axes, respectively. Default is to use 6 ticks per axis. fft_axis: int or str, optional Axis over which to perform a Fourier transform. May be specified with one of three strings ("time", "freq", "both") or one of three integers (0, 1, @@ -504,6 +510,18 @@ def labeled_waterfall( Figure containing the plot. ax: :class:`plt.Axes` instance Axes object the waterfall is drawn into. + + Notes + ----- + If you are plotting data with LSTs listed on the time axis and passing a + ndarray to the ``data`` parameter, then care should be taken when providing + the LST array. If you are pulling the LSTs from a ``pyuvdata.UVData`` object, + then you should *not* use ``np.unique`` to extract the unique LSTs from the + ``UVData.lst_array`` attribute--this will sort the LSTs, which *will* cause + problems if there is a phase wrap in the LSTs. If you find yourself faced + with this situation, a relatively simple solution can be found in the source + code for this function: see the end of the block of code following the + "# Validate parameters." comment. """ import matplotlib.pyplot as plt @@ -516,6 +534,15 @@ def labeled_waterfall( raise TypeError("array-like data must consist of complex numbers.") if data.ndim != 2 or (data.ndim == 2 and 1 in data.shape): raise ValueError("array-like data must be 2-dimensional.") + if type(Nticks) is not int: + try: + _ = iter(Nticks) + if len(Nticks) != 2 or not all(type(Ntick) is int for Ntick in Nticks): + raise TypeError + except TypeError: + raise TypeError( + "Nticks must be an integer or length-2 iterable of integers." + ) if isinstance(data, np.ndarray): if freqs is None or (times is None and lsts is None): raise ValueError( @@ -540,7 +567,8 @@ def labeled_waterfall( ) freqs = np.unique(data.freq_array) times = np.unique(data.time_array) - lsts = np.unique(data.lst_array) + lst_inds = sorted(np.unique(data.time_array, return_index=True)[1]) + lsts = data.lst_array[lst_inds] data_units = data.vis_units or data_units data = data.get_data(antpairpol) @@ -675,6 +703,16 @@ def labeled_waterfall( else: fig = ax.get_figure() + # Choose bounds for setting plot extent. + xmin, xmax = xvals.min(), xvals.max() + # Special handling for LSTs since they can wrap. + if time_or_lst == "lst" and fft_axis not in ("time", "both"): + adjust_yaxis = True + ymin, ymax = 0, len(lsts) - 1 + else: + adjust_yaxis = False + ymin, ymax = yvals.min(), yvals.max() + # Finish setup, then plot. ax.set_xlabel(xlabel, fontsize=fontsize) ax.set_ylabel(ylabel, fontsize=fontsize) @@ -683,8 +721,14 @@ def labeled_waterfall( aspect=aspect, cmap=cmap, norm=norm, - extent=(xvals.min(), xvals.max(), yvals.max(), yvals.min()), + extent=(xmin, xmax, ymax, ymin), ) + if adjust_yaxis: + # This is a bit of a hack, but it gets the job done. + ax.set_yticks(ax.get_yticks()[:-1]) + ax.set_yticklabels( + f"{yval:.2f}" for yval in yvals[ax.get_yticks().astype(int)] + ) # Optionally draw a colorbar. if draw_colorbar: