|
| 1 | +# Author: Firas Moosvi, Jake Bobowski, others |
| 2 | +# Date: 2023-10-31 |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +from matplotlib.figure import Figure |
| 9 | +from scipy import stats |
| 10 | + |
| 11 | + |
| 12 | +def shaded_normal_density( |
| 13 | + q: float | tuple[float, float], |
| 14 | + /, |
| 15 | + mean: float = 0, |
| 16 | + sd: float = 1, |
| 17 | + rsd: float = 4, |
| 18 | + lower_tail: bool = True, |
| 19 | + add_prob: bool = True, |
| 20 | + add_q: bool = True, |
| 21 | + add_legend: bool = False, |
| 22 | + figsize: tuple[float, float] | None = (8, 6), |
| 23 | + color: str = "xkcd:sky blue", |
| 24 | + x_label: str = "x", |
| 25 | + y_label: str = "f(x; μ,σ)", |
| 26 | + legend_text: str | None = None, |
| 27 | + **kwargs, |
| 28 | +) -> Figure: |
| 29 | + """ |
| 30 | + Generate a normal distribution plot with optional listed probability calculation. |
| 31 | +
|
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | + q : float or tuple of 2 floats |
| 35 | + If a float, the upper or lower bound of the shaded area. If a tuple of floats, the lower and upper bounds of the shaded area. |
| 36 | + mean : float, default: 0 |
| 37 | + The mean of the normal distribution. |
| 38 | + sd : float, default: 1 |
| 39 | + The standard deviation of the normal distribution. |
| 40 | + rsd : float, default: 4 |
| 41 | + The number of standard deviations to plot on either side of the mean=. |
| 42 | + lower_tail : bool, default: True |
| 43 | + Whether the shaded area should represent the lower tail probability P(X <= x) (True) or the upper tail probability P(X > x) (False). |
| 44 | + add_prob : bool, default: True |
| 45 | + Whether to show the probability of the shaded area will be displayed on the plot. |
| 46 | + add_q : bool, default: True |
| 47 | + Whether the value(s) of `q` should be displayed on the x-axis of the plot. |
| 48 | + add_legend : bool, default: False |
| 49 | + Whether a legend with the mean and standard deviation values will be displayed on the plot. |
| 50 | + figsize : tuple of 2 floats or None, default: (8, 6) |
| 51 | + The size of the plot in inches. If None, the default matplotlib figure size will be used as this is passed to `matplotlib.pyplot.figure`. |
| 52 | + color : color, default: 'xkcd:sky blue' |
| 53 | + The color of the shaded area as a valid `matplotlib color <https://matplotlib.org/stable/users/explain/colors/colors.html>`__. |
| 54 | + x_label : str, default: 'x' |
| 55 | + The label for the x-axis. |
| 56 | + y_label : str, default: 'f(x; μ,σ)' |
| 57 | + The label for the y-axis. |
| 58 | + legend_text : str or None, optional |
| 59 | + The text to display in the legend if add_legend is set to true. By default (None), the legend will display the mean and standard deviation values. |
| 60 | + **kwargs |
| 61 | + Additional keyword arguments to pass to `matplotlib.pyplot.figure`. |
| 62 | +
|
| 63 | + Returns |
| 64 | + ------- |
| 65 | + matplotlib.figure.Figure |
| 66 | + The generated matplotlib Figure object. |
| 67 | +
|
| 68 | + Raises |
| 69 | + ------ |
| 70 | + TypeError |
| 71 | + If the input parameters are not of the expected type. |
| 72 | + ValueError |
| 73 | + If the input values are out of the expected range. |
| 74 | +
|
| 75 | + References |
| 76 | + ---------- |
| 77 | + Based off of an R function written by Dr. Irene Vrbick for making `shaded normal density curves <https://irene.vrbik.ok.ubc.ca/blog/2021-11-04-shading-under-the-normal-curve/>`__. |
| 78 | +
|
| 79 | + The R function by Dr. Irene Vrbick was adapted from `here <http://rstudio-pubs-static.s3.amazonaws.com/78857_86c2403ca9c146ba8fcdcda79c3f4738.html>`__. |
| 80 | + """ |
| 81 | + if not isinstance(mean, (float, int)): |
| 82 | + raise TypeError(f"mean must be a number, not a {mean.__class__.__name__!r}") |
| 83 | + if not isinstance(sd, (float, int)): |
| 84 | + raise TypeError(f"sd must be a number, not a {sd.__class__.__name__!r}") |
| 85 | + if not isinstance(rsd, (float, int)): |
| 86 | + raise TypeError(f"rsd must be a number, not a {rsd.__class__.__name__!r}") |
| 87 | + if ( |
| 88 | + isinstance(q, tuple) |
| 89 | + and len(q) == 2 |
| 90 | + and isinstance(q[0], (float, int)) |
| 91 | + and isinstance(q[1], (float, int)) |
| 92 | + ): |
| 93 | + q_lower, q_upper = sorted(q) |
| 94 | + xx = np.linspace(mean - rsd * sd, mean + rsd * sd, 200) |
| 95 | + yy = stats.norm.pdf(xx, mean, sd) |
| 96 | + fig = plt.figure(figsize=figsize, **kwargs) |
| 97 | + ax = fig.gca() |
| 98 | + ax.plot(xx, yy) |
| 99 | + ax.set_xlabel(x_label) |
| 100 | + ax.set_ylabel(y_label) |
| 101 | + x = np.linspace(q_lower, q_upper, 200) |
| 102 | + y = stats.norm.pdf(x, mean, sd) |
| 103 | + # fmt: off |
| 104 | + filled, *_ = ax.fill( # Fill returns a list of polygons, but we're only making one |
| 105 | + np.concatenate([[q_lower], x, [q_upper]]), |
| 106 | + np.concatenate([[0], y, [0]]), |
| 107 | + color, |
| 108 | + ) |
| 109 | + # fmt: on |
| 110 | + if add_prob: |
| 111 | + height = max(y) / 4 |
| 112 | + rv = stats.norm(mean, sd) |
| 113 | + prob: float = rv.cdf(q_upper) - rv.cdf(q_lower) |
| 114 | + ax.text((sum(q) / 2), height, f"{prob:.3f}", ha="center") |
| 115 | + if add_q: |
| 116 | + ax.set_xticks( |
| 117 | + [q_lower, q_upper], |
| 118 | + labels=[ |
| 119 | + str(round(q_lower, 4)), |
| 120 | + str(round(q_upper, 4)), |
| 121 | + ], |
| 122 | + minor=True, |
| 123 | + color=color, |
| 124 | + y=-0.05, |
| 125 | + ) |
| 126 | + if q_lower in ax.get_xticks(): |
| 127 | + ax.get_xticklabels()[ |
| 128 | + np.where(ax.get_xticks() == q_lower)[0][0] |
| 129 | + ].set_color(color) |
| 130 | + if q_upper in ax.get_xticks(): |
| 131 | + ax.get_xticklabels()[ |
| 132 | + np.where(ax.get_xticks() == q_upper)[0][0] |
| 133 | + ].set_color(color) |
| 134 | + |
| 135 | + elif isinstance(q, (float, int)): |
| 136 | + if not isinstance(lower_tail, bool): |
| 137 | + raise TypeError( |
| 138 | + f"lower_tail must be a bool, not a {lower_tail.__class__.__name__!r}" |
| 139 | + ) |
| 140 | + |
| 141 | + xx = np.linspace(mean - rsd * sd, mean + rsd * sd, 200) |
| 142 | + yy = stats.norm.pdf(xx, mean, sd) |
| 143 | + fig = plt.figure(figsize=figsize, **kwargs) |
| 144 | + ax = fig.gca() |
| 145 | + ax.plot(xx, yy) |
| 146 | + ax.set_xlabel(x_label) |
| 147 | + ax.set_ylabel(y_label) |
| 148 | + |
| 149 | + if lower_tail is True: |
| 150 | + x = np.linspace(xx[0], q, 100) |
| 151 | + y = stats.norm.pdf(x, mean, sd) |
| 152 | + # fmt: off |
| 153 | + filled, *_ = ax.fill( # Fill returns a list of polygons, but we're only making one |
| 154 | + np.concatenate([[xx[0]], x, [q]]), |
| 155 | + np.concatenate([[0], y, [0]]), |
| 156 | + color, |
| 157 | + ) |
| 158 | + # fmt: on |
| 159 | + if add_prob: |
| 160 | + height: float = stats.norm.pdf(q, mean, sd) / 4 # type: ignore |
| 161 | + prob: float = stats.norm.cdf(q, mean, sd) # type: ignore |
| 162 | + ax.text((q - 0.5 * sd), height, f"{prob:.3f}", ha="center") |
| 163 | + else: |
| 164 | + x = np.linspace(q, xx[-1], 100) |
| 165 | + y = stats.norm.pdf(x, mean, sd) |
| 166 | + # fmt: off |
| 167 | + filled, *_ = ax.fill( # Fill returns a list of polygons, but we're only making one |
| 168 | + np.concatenate([[q], x, [xx[-1]]]), |
| 169 | + np.concatenate([[0], y, [0]]), |
| 170 | + color, |
| 171 | + ) |
| 172 | + # fmt: on |
| 173 | + if add_prob: |
| 174 | + height: float = stats.norm.pdf(q, mean, sd) / 4 # type: ignore |
| 175 | + prob: float = stats.norm.sf(q, mean, sd) # type: ignore |
| 176 | + ax.text((q + 0.5 * sd), height, f"{prob:.3f}", ha="center") |
| 177 | + |
| 178 | + if add_q: |
| 179 | + if q in ax.get_xticks(): |
| 180 | + ax.get_xticklabels()[np.where(ax.get_xticks() == q)[0][0]].set_color( |
| 181 | + color |
| 182 | + ) |
| 183 | + else: |
| 184 | + ax.set_xticks( |
| 185 | + [q], |
| 186 | + labels=[ |
| 187 | + str(round(q, 4)), |
| 188 | + ], |
| 189 | + minor=True, |
| 190 | + color=color, |
| 191 | + y=-0.05, |
| 192 | + ) |
| 193 | + |
| 194 | + else: |
| 195 | + error_base = "q must be a tuple of two numbers, or a single number" |
| 196 | + if isinstance(q, tuple): |
| 197 | + if len(q) != 2: |
| 198 | + raise ValueError(f"{error_base}, not a {len(q)}-tuple") |
| 199 | + raise TypeError( |
| 200 | + f"{error_base}, not a 2-tuple containing a {q[0].__class__.__name__!r} and a {q[1].__class__.__name__!r}" |
| 201 | + ) |
| 202 | + else: |
| 203 | + raise TypeError(f"{error_base}, not a {q.__class__.__name__!r}") |
| 204 | + |
| 205 | + if add_legend: |
| 206 | + ax.set_title(legend_text or f"μ = {mean}, σ = {sd}") |
| 207 | + |
| 208 | + return fig |
0 commit comments