Skip to content

Commit

Permalink
Improve script for plotting binom. dist. by adding softmax with tempe…
Browse files Browse the repository at this point in the history
…rature
  • Loading branch information
nathanpainchaud committed Dec 19, 2023
1 parent f653f78 commit 949c663
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion didactic/scripts/binomial_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ def main():
import numpy as np
import scipy.stats as stats
import seaborn.objects as so
from scipy.special import softmax

parser = argparse.ArgumentParser(description="Plot a binomial distribution.")
parser.add_argument("--n", type=int, default=6, help="Number of trials.")
parser.add_argument("--x_title", type=str, default="k", help="Title of the x-axis.")
parser.add_argument("--x_labels", type=str, nargs="+", help="Tick labels of the x-axis.")
parser.add_argument("--y_title", type=str, default="B(k,p)", help="Title of the y-axis.")
parser.add_argument("--p", type=float, default=0.4, help="Probability of success.")
parser.add_argument("--output_name", type=Path, help="Output file name")
parser.add_argument("--tau", type=float, default=1, help="Temperature parameter for the softmax function.")
parser.add_argument("--output_name", type=Path, help="Output file name.")
args = parser.parse_args()

if len(args.x_labels) != args.n:
Expand All @@ -26,6 +28,7 @@ def main():
# Compute the binomial distribution
x = np.arange(args.n)
y = stats.binom.pmf(x, args.n, args.p)
y = softmax(y / args.tau)

# Plot the binomial distribution
if categorical_x := args.x_labels is not None:
Expand Down

0 comments on commit 949c663

Please sign in to comment.