Skip to content

Commit

Permalink
v0.2.13
Browse files: browse the repository at this point in the history
  • Loading branch information
yh202109 committed Jul 3, 2024
1 parent eba70ec commit 42d2922
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions mtbp3/statlab/kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def bootstrap_cohen_ci(self, n_iterations=1000, confidence_level=0.95, outfmt='s
else:
return [self.cohen_kappa, n_iterations, confidence_level, lower_bound, upper_bound]

def create_bubble_plot(self, out_path="", title="", axis_label=[], max_size_ratio=100):
def create_bubble_plot(self, out_path="", axis_label=[], max_size_ratio=0, hist=False):
"""
Creates a bubble plot based on the y_count_sq matrix.
Expand All @@ -268,29 +268,40 @@ def create_bubble_plot(self, out_path="", title="", axis_label=[], max_size_rati
if self.n_rater == 2 and self.y_count_sq is not None and self.y_count_sq.shape[0] == self.y_count_sq.shape[1] and self.y_count_sq.shape[0] > 0:
categories = self.y_count_sq.columns
n_categories = len(categories)
r1 = []
max_size_ratio = max_size_ratio if max_size_ratio >= 1 else max(1,int(150 / n_categories))

r1 = []
r2 = []
sizes = []
for i1, c1 in enumerate(categories):
for i2, c2 in enumerate(categories):
r1.append(c1)
r2.append(c2)
sizes.append(self.y_count_sq.iloc[i1, i2])
data = pd.DataFrame({'r1': r1, 'r2': r2, 'sizes': sizes})
sns.scatterplot(data=data, x="r1", y="r2", size="sizes", sizes=(min(sizes), max(sizes)*max_size_ratio), legend=False)
for i in range(len(data)):
plt.text(data['r1'][i], data['r2'][i], data['sizes'][i], ha='center', va='center')
df0 = pd.DataFrame({'r1': r1, 'r2': r2, 'sizes': sizes})
if hist:
sns.jointplot(
data=df0, x="r1", y="r2", kind="scatter",
height=5, ratio=3, marginal_ticks=True,
marginal_kws={"weights": sizes, "shrink":.5},
joint_kws={"size": sizes, "legend": False, "sizes":(min(sizes), max(sizes)*max_size_ratio)}
)
#sns.jointplot(data=df0, x="r1", y="r2", size="sizes", kind="scatter")
else:
sns.scatterplot(data=df0, x="r1", y="r2", size="sizes", sizes=(min(sizes), max(sizes)*max_size_ratio), legend=False)
tmp1 = plt.xlim()
tmp1d = ((tmp1[1] - tmp1[0])/n_categories)
plt.xlim(tmp1[0] - tmp1d, tmp1[1] + tmp1d)
plt.ylim(tmp1[0] - tmp1d, tmp1[1] + tmp1d)

for i in range(len(df0)):
plt.text(df0['r1'][i], df0['r2'][i], df0['sizes'][i], ha='center', va='center')

if not axis_label:
axis_label = ['Rater 1', 'Rater 2']
plt.xlabel(axis_label[0])
plt.ylabel(axis_label[1])
if not title:
title = 'Bubble Plot'
plt.title(title)
tmp1 = plt.xlim()
tmp1d = ((tmp1[1] - tmp1[0])/n_categories)
plt.xlim(tmp1[0] - tmp1d, tmp1[1] + tmp1d)
plt.ylim(tmp1[0] - tmp1d, tmp1[1] + tmp1d)

plt.tight_layout()
if out_path:
try:
Expand All @@ -317,5 +328,5 @@ def create_bubble_plot(self, out_path="", title="", axis_label=[], max_size_rati
print("Number of rating categories: "+str(kappa.n_category))
print("Number of sample: "+str(kappa.y_count.shape[0]))

kappa.create_bubble_plot(out_path='statlab_kappa_fig1.svg')
kappa.create_bubble_plot(hist=True)

0 comments on commit 42d2922

Please sign in to comment.