diff --git a/figures/parameter_size_overview.png b/figures/parameter_size_overview.png index bf01dfc..19c07f0 100644 Binary files a/figures/parameter_size_overview.png and b/figures/parameter_size_overview.png differ diff --git a/figures/scripts/parameter_size_overview_generate.py b/figures/scripts/parameter_size_overview_generate.py index 8196fd2..991b1f1 100644 --- a/figures/scripts/parameter_size_overview_generate.py +++ b/figures/scripts/parameter_size_overview_generate.py @@ -47,11 +47,48 @@ "EN-unavailable": "#B8CDF7", } +labels = { + "JP-available": "Japanese (public)", + "JP-unavailable": "Japanese (private)", + "EN-available": "English (public)", + "EN-unavailable": "English (private)", +} + fig, ax = plt.subplots(figsize=(10, 8)) -for i in reversed(range(len(df))): - ax.scatter(df["Announced"][i], df["Parameters(B)"][i], color=colors[df["Type"][i]], s=100 if df["Type"][i].startswith("JP") else 30) +ax.scatter( + df[df["Type"] == "EN-unavailable"]["Announced"], + df[df["Type"] == "EN-unavailable"]["Parameters(B)"], + color=colors["EN-unavailable"], + label=labels["EN-unavailable"], + s=30 +) + +ax.scatter( + df[df["Type"] == "EN-available"]["Announced"], + df[df["Type"] == "EN-available"]["Parameters(B)"], + color=colors["EN-available"], + label=labels["EN-available"], + s=30 +) + +ax.scatter( + df[df["Type"] == "JP-unavailable"]["Announced"], + df[df["Type"] == "JP-unavailable"]["Parameters(B)"], + color=colors["JP-unavailable"], + label=labels["JP-unavailable"], + s=100 +) + +ax.scatter( + df[df["Type"] == "JP-available"]["Announced"], + df[df["Type"] == "JP-available"]["Parameters(B)"], + color=colors["JP-available"], + label=labels["JP-available"], + s=100 +) +for i in reversed(range(len(df))): if df["Type"][i].startswith("JP"): ax.text( df["Announced"][i], @@ -61,7 +98,6 @@ verticalalignment='bottom', horizontalalignment='center' ) - if df["Type"][i].startswith("EN") and df["Lab"][i] in BIGTECH_LIST and df["Parameters(B)"][i] > 100: ax.text( df["Announced"][i], @@ -87,4 +123,5 @@ plt.xticks(rotation=45) plt.subplots_adjust(left=0.075, right=0.975, bottom=0.1, top=0.975) +plt.legend() plt.savefig("../parameter_size_overview.png", dpi=144) \ No newline at end of file