diff --git a/latentscope/scripts/umapper.py b/latentscope/scripts/umapper.py index 6f5c064..3a20441 100644 --- a/latentscope/scripts/umapper.py +++ b/latentscope/scripts/umapper.py @@ -20,11 +20,13 @@ def main(): parser.add_argument('embedding_id', type=str, help='Name of embedding model to use') parser.add_argument('neighbors', type=int, nargs="?", help='Output file', default=25) parser.add_argument('min_dist', type=float, nargs="?", help='Output file', default=0.075) + parser.add_argument('--init', type=str, help='Initialize with UMAP', default=None) + parser.add_argument('--align', type=str, help='Align UMAP with multiple embeddings', default=None) parser.add_argument('--save', action='store_true', help='Save the UMAP model') # Parse arguments args = parser.parse_args() - umapper(args.dataset_id, args.embedding_id, args.neighbors, args.min_dist, args.save) + umapper(args.dataset_id, args.embedding_id, args.neighbors, args.min_dist, save=args.save, init=args.init, align=args.align) # TODO move this into shared space @@ -39,7 +41,7 @@ def calculate_point_size(num_points, min_size=10, max_size=30, base_num_points=1 return min(min_size + min_size * np.log(num_points / base_num_points), max_size) -def umapper(dataset_id, embedding_id, neighbors=25, min_dist=0.1, save=False): +def umapper(dataset_id, embedding_id, neighbors=25, min_dist=0.1, save=False, init=None, align=None): DATA_DIR = get_data_dir() # read in the embeddings print("loading embeddings") @@ -63,62 +65,117 @@ def umapper(dataset_id, embedding_id, neighbors=25, min_dist=0.1, save=False): next_umap_number = 1 # make the umap name from the number, zero padded to 3 digits - umap_name = f"umap-{next_umap_number:03d}" - print("RUNNING:", umap_name) - - reducer = umap.UMAP( - n_neighbors=neighbors, - min_dist=min_dist, - metric='cosine', - random_state=42, - n_components=2, - verbose=True, - ) - + umap_id = f"umap-{next_umap_number:03d}" + print("RUNNING:", umap_id) + + def process_umap_embeddings(umap_id, umap_embeddings, emb_id): + min_values = np.min(umap_embeddings, axis=0) + max_values = np.max(umap_embeddings, axis=0) + + # Scale the embeddings to the range [0, 1] + umap_embeddings = (umap_embeddings - min_values) / (max_values - min_values) + + # Scale the embeddings to the range [-1, 1] + umap_embeddings = 2 * umap_embeddings - 1 + + print("writing normalized umap", umap_id) + # save umap embeddings to a parquet file with columns x,y + df = pd.DataFrame(umap_embeddings, columns=['x', 'y']) + output_file = os.path.join(umap_dir, f"{umap_id}.parquet") + df.to_parquet(output_file) + print("wrote", output_file) + + # generate a scatterplot of the umap embeddings and save it to a file + fig, ax = plt.subplots(figsize=(14.22, 14.22)) # 1024px by 1024px at 72 dpi + point_size = calculate_point_size(umap_embeddings.shape[0]) + print("POINT SIZE", point_size, "for", umap_embeddings.shape[0], "points") + plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=point_size, alpha=0.5) + plt.axis('off') # remove axis + plt.gca().set_position([0, 0, 1, 1]) # remove margins + plt.savefig(os.path.join(umap_dir, f"{umap_id}.png")) + + # save a json file with the umap parameters + with open(os.path.join(umap_dir, f'{umap_id}.json'), 'w') as f: + json.dump({ + "id": umap_id, + "embedding_id": emb_id, + "neighbors": neighbors, + "min_dist": min_dist, + "init": init, + "align": f"{embedding_id},{align}" if align is not None else None, + }, f, indent=2) + f.close() + + + if align is not None and align != "": + print("aligned umap", align) + # split the align string into umap names + embs = align.split(",") + # load each embedding from its h5 file + a_embedding_ids = [embedding_id] + a_embeddings = [embeddings] + for emb in embs: + print("loading", emb) + emb_path = os.path.join(DATA_DIR, dataset_id, "embeddings", f"{emb}.h5") + with h5py.File(emb_path, 'r') as f: + dataset = f["embeddings"] + a_emb = np.array(dataset) + print("loaded", emb, "shape", a_emb.shape) + a_embeddings.append(a_emb) + a_embedding_ids.append(emb) + + reducer = umap.AlignedUMAP( + n_neighbors=neighbors, + min_dist=min_dist, + metric='cosine', + random_state=42, + n_components=2, + verbose=True, + ) + print("a_embeddings", len(a_embeddings), len(a_embeddings[0])) + relations = [{j: j for j in range(len(a_embeddings[i]))} for i in range(len(a_embeddings)-1)] + print("relations", len(relations)) + aligned = reducer.fit_transform(a_embeddings, relations=relations) + print("ALIGNED", aligned) + for i,emb in enumerate(a_embedding_ids): + print("processing", emb, "umap", next_umap_number+i) + process_umap_embeddings(f"umap-{next_umap_number+i:03d}", aligned[i], emb) + + print("done with aligned umap") + return + + if init is not None and init != "": + print("loading umap", init) + initial_df = pd.read_parquet(os.path.join(umap_dir, f"{init}.parquet")) + initial = initial_df.to_numpy() + print("initial shape", initial.shape) + reducer = umap.UMAP( + init=initial, + n_neighbors=neighbors, + min_dist=min_dist, + metric='cosine', + random_state=42, + n_components=2, + verbose=True, + ) + else: + reducer = umap.UMAP( + n_neighbors=neighbors, + min_dist=min_dist, + metric='cosine', + random_state=42, + n_components=2, + verbose=True, + ) print("reducing", embeddings.shape[0], "embeddings to 2 dimensions") umap_embeddings = reducer.fit_transform(embeddings) - - min_values = np.min(umap_embeddings, axis=0) - max_values = np.max(umap_embeddings, axis=0) - - # Scale the embeddings to the range [0, 1] - umap_embeddings = (umap_embeddings - min_values) / (max_values - min_values) - - # Scale the embeddings to the range [-1, 1] - umap_embeddings = 2 * umap_embeddings - 1 - - print("writing normalized umap") - # save umap embeddings to a parquet file with columns x,y - df = pd.DataFrame(umap_embeddings, columns=['x', 'y']) - output_file = os.path.join(umap_dir, f"{umap_name}.parquet") - df.to_parquet(output_file) - print("wrote", output_file) - - # generate a scatterplot of the umap embeddings and save it to a file - - fig, ax = plt.subplots(figsize=(14.22, 14.22)) # 1024px by 1024px at 72 dpi - point_size = calculate_point_size(umap_embeddings.shape[0]) - print("POINT SIZE", point_size, "for", umap_embeddings.shape[0], "points") - plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=point_size, alpha=0.5) - plt.axis('off') # remove axis - plt.gca().set_position([0, 0, 1, 1]) # remove margins - plt.savefig(os.path.join(umap_dir, f"{umap_name}.png")) - - # save a json file with the umap parameters - with open(os.path.join(umap_dir, f'{umap_name}.json'), 'w') as f: - json.dump({ - "id": umap_name, - "embedding_id": embedding_id, - "neighbors": neighbors, - "min_dist": min_dist - }, f, indent=2) - f.close() + process_umap_embeddings(umap_id, umap_embeddings, embedding_id) if save: # save a pickle of the umap - with open(os.path.join(umap_dir, f'{umap_name}.pkl'), 'wb') as f: + with open(os.path.join(umap_dir, f'{umap_id}.pkl'), 'wb') as f: pickle.dump(reducer, f) - print("done with", umap_name) + print("done with", umap_id) if __name__ == "__main__": diff --git a/latentscope/server/jobs.py b/latentscope/server/jobs.py index c9ecf84..2c2edc1 100644 --- a/latentscope/server/jobs.py +++ b/latentscope/server/jobs.py @@ -207,10 +207,11 @@ def run_umap(): embedding_id = request.args.get('embedding_id') neighbors = request.args.get('neighbors') min_dist = request.args.get('min_dist') - print("run umap", dataset, embedding_id, neighbors, min_dist) + init = request.args.get('init') + print("run umap", dataset, embedding_id, neighbors, min_dist, init) job_id = str(uuid.uuid4()) - command = f'ls-umap {dataset} {embedding_id} {neighbors} {min_dist}' + command = f'ls-umap {dataset} {embedding_id} {neighbors} {min_dist} --init={init}' threading.Thread(target=run_job, args=(dataset, job_id, command)).start() return jsonify({"job_id": job_id}) diff --git a/web/src/components/Setup/Umap.jsx b/web/src/components/Setup/Umap.jsx index 19c8233..18f798e 100644 --- a/web/src/components/Setup/Umap.jsx +++ b/web/src/components/Setup/Umap.jsx @@ -13,6 +13,7 @@ Umap.propTypes = { }).isRequired, umap: PropTypes.object, embedding: PropTypes.object, + embeddings: PropTypes.array.isRequired, clusters: PropTypes.array.isRequired, onNew: PropTypes.func.isRequired, onChange: PropTypes.func.isRequired, @@ -20,11 +21,13 @@ Umap.propTypes = { // This component is responsible for the embeddings state // New embeddings update the list -function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) { +function Umap({ dataset, umap, embedding, embeddings, clusters, onNew, onChange}) { const [umapJob, setUmapJob] = useState(null); const { startJob: startUmapJob } = useStartJobPolling(dataset, setUmapJob, `${apiUrl}/jobs/umap`); const { startJob: deleteUmapJob } = useStartJobPolling(dataset, setUmapJob, `${apiUrl}/jobs/delete/umap`); + const [init, setInit] = useState("") + const [umaps, setUmaps] = useState([]); function fetchUmaps(datasetId, callback) { fetch(`${apiUrl}/datasets/${datasetId}/umaps`) @@ -36,7 +39,7 @@ function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) { url: `${apiUrl}/files/${datasetId}/umaps/${d.id}.png`, } }) - callback(array.reverse()) + callback(array) }); } useEffect(() => { @@ -61,14 +64,19 @@ function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) { } }, [umapJob, dataset, setUmaps, onNew]); + + const handleChangeInit = useCallback((e) => { + setInit(e.target.value) + }, [setInit]) + const handleNewUmap = useCallback((e) => { e.preventDefault() const form = e.target const data = new FormData(form) const neighbors = data.get('neighbors') const min_dist = data.get('min_dist') - startUmapJob({embedding_id: embedding?.id, neighbors, min_dist}) - }, [startUmapJob, embedding]) + startUmapJob({embedding_id: embedding?.id, neighbors, min_dist, init}) + }, [startUmapJob, embedding, init]) return (
@@ -81,6 +89,19 @@ function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) { Min Dist: + setUmapJob(null)}/> diff --git a/web/src/pages/Setup.jsx b/web/src/pages/Setup.jsx index ee86a54..2b6a290 100644 --- a/web/src/pages/Setup.jsx +++ b/web/src/pages/Setup.jsx @@ -373,16 +373,33 @@ function Setup() { const prevPointConfig = useRef() const [drawPoints, setDrawPoints] = useState([]); useEffect(() => { + // if cluster and umap are ready and haven't been drawn we update the draw points + // we also set hulls to empty so they can animate in if(clusterIndices.length && umap && cluster && cluster.umap_id == umap.id && clusterIndices.cluster_id == cluster.id) { const config = umap.id + cluster.id + clusterIndices.cluster_id if(prevPointConfig.current !== config) { fetch(`${apiUrl}/datasets/${datasetId}/umaps/${umap.id}/points`).then(response => response.json()).then(data => { let pts = data.map((d,i) => [d.x, d.y, clusterIndices[i].cluster]) setDrawPoints(pts) + // TODO: this doesn't always work out in the right timing + // the other useEffect above should be tied to this one somehow setHulls([]) }) prevPointConfig.current = config } + } else if(umap && !cluster) { + const config = umap.id + if(prevPointConfig.current !== config) { + fetch(`${apiUrl}/datasets/${datasetId}/umaps/${umap.id}/points`).then(response => response.json()).then(data => { + let pts = data.map((d) => [d.x, d.y, -1]) + setDrawPoints(pts) + setHulls([]) + }) + prevPointConfig.current = config + } + } else if(!umap && !cluster) { + setHulls([]) + setDrawPoints([]) } }, [clusterIndices, cluster, umap]) @@ -502,6 +519,7 @@ function Setup() { dataset={dataset} umap={umap} embedding={embedding} + embeddings={embeddings} clusters={clusters} onNew={handleNewUmaps} onChange={(ump) => dispatch({type: "SET_UMAP", payload: ump})} />