Skip to content

Commit

Permalink
save current state
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtritt committed Jul 7, 2022
1 parent 1585876 commit 1eb3d07
Show file tree
Hide file tree
Showing 5 changed files with 538 additions and 628 deletions.
2 changes: 1 addition & 1 deletion activ/cca/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,6 @@ def cross_decomp_scatter(x, y, regressor=LinearRegression(), labels=None, fitlin
yfit = regressor.predict(xfit)

ax.plot(xfit, yfit, color='black')
x_pos, y_pos = (0.5, 0.1)
x_pos, y_pos = (0.7, 0.1)
ax.text(x_pos, y_pos, "$R^2$ = %0.2f" % cv_r2, size=fontsize, transform=ax.transAxes)
return ax
41 changes: 32 additions & 9 deletions activ/clustering/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def shade_around(data, ax=None, stretch=None, alpha=0.2, color='blue', convex=Tr
# zi = interpolator(Xi, Yi)
# cfset = ax.contourf(xi, yi, zi, colors=color, alpha=alpha)

cfset = ax.tricontourf(x, y, z, colors=color, alpha=alpha)
cfset = ax.tricontourf(x, y, z, colors=color, alpha=alpha, zorder=0)
else:
ch = concave_hull(data, 16)
gpd.GeoSeries([ch]).plot(ax=ax, alpha=alpha, color=color, edgecolor='none')
Expand Down Expand Up @@ -459,7 +459,7 @@ def plot_max1d_simdata_results(path, ax=None, flip=False, fontsize='x-large'):
sim_sweep_plot(est_noc_max1d, true_noc, ax=ax, flip=flip, fontsize=fontsize)


def make_clustered_plot(emb, n_clusters, feature_colors, weights=None, ax=None, stretch=None, fs='x-large', **kwargs):
def make_clustered_plot(emb, n_clusters, feature_colors, weights=None, ax=None, stretch=None, fs='x-large', add_labels=False, highlight=0, **kwargs):
"""
Plot 2-D UMAP embedding. Cluster embedding and shade around the
resulting clusters. Pass in weights to plot points as pie charts
Expand All @@ -474,6 +474,28 @@ def make_clustered_plot(emb, n_clusters, feature_colors, weights=None, ax=None,
else:
cluster_plot(emb, labels, colors=feature_colors, ax=ax, stretch=stretch, fs=fs, **kwargs)

if add_labels:
for label in np.unique(labels):
mask = labels == label
center = np.mean(emb[mask], axis=0)
ax.text(center[0], center[1], str(label), fontsize=fs)
elif highlight > 0:
if weights is None:
raise ValueError("to highlight clusters, please provide weights for determining homogeneity")
print(np.unique(labels, return_counts=True))
label_var = np.zeros(max(labels) + 1)
for label in np.unique(labels):
mask = labels == label
label_var[label] = np.var(weights[mask], axis=0).sum()
print('label_var before', label_var)
label_var = np.argsort(label_var)[:highlight]
print('label_var after', label_var)
for label in label_var:
mask = labels == label
print(mask.sum())
shade_around(emb[mask], ax=ax, color='lightgrey', convex=True)

return labels

def get_real_noc(tested_noc, foc, smooth=True, use_median=False, spread_asm=True, spread_foc=True):
ret = dict()
Expand Down Expand Up @@ -587,7 +609,7 @@ def plot_real_foc_results(path, ax=None, max1d_cutoff=False, ci=None, n_sigma=1,

ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_xlabel("# Outcome clusters", fontsize=fontsize)
return est
return est, med


def plot_real_accuracy_chance_results(path, ax=None, fontsize='x-large', ylabel="Prediction Accuracy"):
Expand All @@ -600,16 +622,16 @@ def plot_real_accuracy_chance_results(path, ax=None, fontsize='x-large', ylabel=


x_noc, lower, med, upper = flatten_summarize(np.arange(2, 51),
real_chance, smooth=False,
real_accuracy, smooth=False,
iqr=True)

plot_line(x_noc, med, lower=lower, upper=upper, ax=ax, color='gray')
plot_line(x_noc, med, lower=lower, upper=upper, ax=ax, color='black', label='Accuracy')
raw = med

x_noc, lower, med, upper = flatten_summarize(np.arange(2, 51),
real_accuracy, smooth=False,
real_chance, smooth=False,
iqr=True)

plot_line(x_noc, med, lower=lower, upper=upper, ax=ax, color='black')
plot_line(x_noc, med, lower=lower, upper=upper, ax=ax, color='gray', label='Chance accuracy')
chance = med

yticks = np.array([2, 4, 6, 8, 10], dtype=np.int)
ax.set_yticks(yticks/10)
Expand All @@ -620,6 +642,7 @@ def plot_real_accuracy_chance_results(path, ax=None, fontsize='x-large', ylabel=

ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_xlabel("# Outcome clusters", fontsize=fontsize)
return raw, chance


def entropy_across_clusterings(variable, labels):
Expand Down
4 changes: 2 additions & 2 deletions activ/nmf/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ def plot_umap_nmf_piechart(weights, umap_emb, s=100, ax=None, fontsize=None, pal
"""
weights_pie_scatter(weights, umap_emb, s=s, ax=ax, palette=palette)
ax.tick_params(labelsize=fontsize)
ax.set_xlabel('UMAP dimesion 1', fontsize=fontsize)
ax.set_ylabel('UMAP dimesion 2', fontsize=fontsize)
ax.set_xlabel('UMAP dimension 1', fontsize=fontsize)
ax.set_ylabel('UMAP dimension 2', fontsize=fontsize)


def plot_umap_nmf_max(emb, weights, bases_labels, right=False, min_dist=0.0, legend=True, ax=None, palette=None):
Expand Down
4 changes: 4 additions & 0 deletions activ/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from .analytics import cv_r2_score, linefit, _check_X_y


def remove_ticks(ax):
    """
    Turn off tick marks and primary tick labels on both axes of *ax*.

    For the x axis the bottom and top ticks and the bottom labels are
    disabled; for the y axis the left and right ticks and the left labels
    are disabled. Major and minor ticks are both affected (``which='both'``).

    Args:
        ax: object exposing a ``tick_params`` method accepting the keyword
            arguments used below (presumably a matplotlib Axes -- confirm
            against callers).
    """
    # Per-axis switches; insertion order is kept so the keyword arguments
    # reach tick_params in the same order as a hand-written call.
    per_axis_switches = (
        ('x', dict(bottom=False, top=False, labelbottom=False)),
        ('y', dict(left=False, right=False, labelleft=False)),
    )
    for axis_name, switches in per_axis_switches:
        ax.tick_params(axis=axis_name, which='both', **switches)


def get_nmf_colors():
"""
Expand Down
Loading

0 comments on commit 1eb3d07

Please sign in to comment.