Skip to content

Commit

Permalink
render cluster hulls and colors. improve styles on setup and explore. handle dark mode
Browse files Browse the repository at this point in the history
  • Loading branch information
enjalot committed Feb 1, 2024
1 parent 6315494 commit 23f6640
Show file tree
Hide file tree
Showing 23 changed files with 574 additions and 481 deletions.
8 changes: 5 additions & 3 deletions latentscope/scripts/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ def clusterer(dataset_id, umap_id, samples, min_samples):
# plot a convex hull around each cluster
hulls = []
for label in non_noise_labels:
points = umap_embeddings[cluster_labels == label]
indices = np.where(cluster_labels == label)[0]
points = umap_embeddings[indices]
# points = umap_embeddings[cluster_labels == label]
hull = ConvexHull(points)
hull_list = hull.vertices.tolist()
hull_list = [indices[s] for s in hull.vertices.tolist()]
hulls.append(hull_list)
for simplex in hull.simplices:
plt.plot(points[simplex, 0], points[simplex, 1], 'k-')
Expand Down Expand Up @@ -131,7 +133,7 @@ def clusterer(dataset_id, umap_id, samples, min_samples):
for cluster, indices in cluster_indices.items():
label = f"Cluster {cluster}"
description = f"This is cluster {cluster} with {len(indices)} items."
new_row = pd.DataFrame({'label': [label], 'description': [description], 'indices': [list(indices)], 'hulls': [hulls[cluster]]})
new_row = pd.DataFrame({'label': [label], 'description': [description], 'indices': [list(indices)], 'hull': [hulls[cluster]]})
slides_df = pd.concat([slides_df, new_row], ignore_index=True)

# write the df to parquet
Expand Down
26 changes: 15 additions & 11 deletions latentscope/server/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import sys
import json
import numpy as np
from flask import Blueprint, jsonify, request

Expand All @@ -11,7 +11,7 @@

# in memory cache of dataset metadata, embeddings, models and tokenizers
DATASETS = {}
MODELS = {}
EMBEDDINGS = {}

"""
Returns nearest neighbors for a given query string
Expand All @@ -20,29 +20,33 @@
@search_bp.route('/nn', methods=['GET'])
def nn():
dataset = request.args.get('dataset')
model_id = request.args.get('model')
embedding_id = request.args.get('embedding_id')

num = 150
if model_id not in MODELS:
print("loading model", model_id)
if embedding_id not in EMBEDDINGS:
print("loading model", embedding_id)
with open(os.path.join(DATA_DIR, dataset, "embeddings", embedding_id + ".json"), 'r') as f:
metadata = json.load(f)
model_id = metadata.get('model_id')
print("Model ID:", model_id)
model = get_embedding_model(model_id)
model.load_model()
MODELS[model_id] = model
EMBEDDINGS[embedding_id] = model
else:
model = MODELS[model_id]
model = EMBEDDINGS[embedding_id]

if dataset not in DATASETS or model_id not in DATASETS[dataset]:
if dataset not in DATASETS or embedding_id not in DATASETS[dataset]:
# load the dataset embeddings
embeddings = np.load(os.path.join(DATA_DIR, dataset, "embeddings", model_id + ".npy"))
embeddings = np.load(os.path.join(DATA_DIR, dataset, "embeddings", embedding_id + ".npy"))
print("fitting embeddings")
from sklearn.neighbors import NearestNeighbors
nne = NearestNeighbors(n_neighbors=num, metric="cosine")
nne.fit(embeddings)
if dataset not in DATASETS:
DATASETS[dataset] = {}
DATASETS[dataset][model_id] = nne
DATASETS[dataset][embedding_id] = nne
else:
nne = DATASETS[dataset][model_id]
nne = DATASETS[dataset][embedding_id]

# embed the query string and find the nearest neighbor
query = request.args.get('query')
Expand Down
Loading

0 comments on commit 23f6640

Please sign in to comment.