Commit c709a45: support aligned umap in backend
enjalot committed Feb 28, 2024 · 1 parent bde255f

Showing 4 changed files with 155 additions and 58 deletions.
161 changes: 109 additions & 52 deletions latentscope/scripts/umapper.py
@@ -20,11 +20,13 @@ def main():
parser.add_argument('embedding_id', type=str, help='Name of embedding model to use')
parser.add_argument('neighbors', type=int, nargs="?", help='Number of nearest neighbors (n_neighbors)', default=25)
parser.add_argument('min_dist', type=float, nargs="?", help='Minimum distance between points (min_dist)', default=0.075)
parser.add_argument('--init', type=str, help='Initialize from an existing UMAP (umap id)', default=None)
parser.add_argument('--align', type=str, help='Align UMAP across multiple embeddings (comma-separated embedding ids)', default=None)
parser.add_argument('--save', action='store_true', help='Save the UMAP model')

# Parse arguments
args = parser.parse_args()
umapper(args.dataset_id, args.embedding_id, args.neighbors, args.min_dist, args.save)
umapper(args.dataset_id, args.embedding_id, args.neighbors, args.min_dist, save=args.save, init=args.init, align=args.align)


# TODO move this into shared space
@@ -39,7 +41,7 @@ def calculate_point_size(num_points, min_size=10, max_size=30, base_num_points=1
return min(min_size + min_size * np.log(num_points / base_num_points), max_size)


def umapper(dataset_id, embedding_id, neighbors=25, min_dist=0.1, save=False):
def umapper(dataset_id, embedding_id, neighbors=25, min_dist=0.1, save=False, init=None, align=None):
DATA_DIR = get_data_dir()
# read in the embeddings
print("loading embeddings")
@@ -63,62 +65,117 @@ def umapper(dataset_id, embedding_id, neighbors=25, min_dist=0.1, save=False):
next_umap_number = 1

# make the umap name from the number, zero padded to 3 digits
umap_name = f"umap-{next_umap_number:03d}"
print("RUNNING:", umap_name)

reducer = umap.UMAP(
n_neighbors=neighbors,
min_dist=min_dist,
metric='cosine',
random_state=42,
n_components=2,
verbose=True,
)

umap_id = f"umap-{next_umap_number:03d}"
print("RUNNING:", umap_id)

def process_umap_embeddings(umap_id, umap_embeddings, emb_id):
min_values = np.min(umap_embeddings, axis=0)
max_values = np.max(umap_embeddings, axis=0)

# Scale the embeddings to the range [0, 1]
umap_embeddings = (umap_embeddings - min_values) / (max_values - min_values)

# Scale the embeddings to the range [-1, 1]
umap_embeddings = 2 * umap_embeddings - 1

print("writing normalized umap", umap_id)
# save umap embeddings to a parquet file with columns x,y
df = pd.DataFrame(umap_embeddings, columns=['x', 'y'])
output_file = os.path.join(umap_dir, f"{umap_id}.parquet")
df.to_parquet(output_file)
print("wrote", output_file)

# generate a scatterplot of the umap embeddings and save it to a file
fig, ax = plt.subplots(figsize=(14.22, 14.22)) # 1024px by 1024px at 72 dpi
point_size = calculate_point_size(umap_embeddings.shape[0])
print("POINT SIZE", point_size, "for", umap_embeddings.shape[0], "points")
plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=point_size, alpha=0.5)
plt.axis('off') # remove axis
plt.gca().set_position([0, 0, 1, 1]) # remove margins
plt.savefig(os.path.join(umap_dir, f"{umap_id}.png"))

# save a json file with the umap parameters
with open(os.path.join(umap_dir, f'{umap_id}.json'), 'w') as f:
json.dump({
"id": umap_id,
"embedding_id": emb_id,
"neighbors": neighbors,
"min_dist": min_dist,
"init": init,
"align": f"{embedding_id},{align}" if align is not None else None,
}, f, indent=2)


if align is not None and align != "":
print("aligned umap", align)
# split the align string into embedding ids
embs = align.split(",")
# load each embedding from its h5 file
a_embedding_ids = [embedding_id]
a_embeddings = [embeddings]
for emb in embs:
print("loading", emb)
emb_path = os.path.join(DATA_DIR, dataset_id, "embeddings", f"{emb}.h5")
with h5py.File(emb_path, 'r') as f:
dataset = f["embeddings"]
a_emb = np.array(dataset)
print("loaded", emb, "shape", a_emb.shape)
a_embeddings.append(a_emb)
a_embedding_ids.append(emb)

reducer = umap.AlignedUMAP(
n_neighbors=neighbors,
min_dist=min_dist,
metric='cosine',
random_state=42,
n_components=2,
verbose=True,
)
print("a_embeddings", len(a_embeddings), len(a_embeddings[0]))
relations = [{j: j for j in range(len(a_embeddings[i]))} for i in range(len(a_embeddings)-1)]
print("relations", len(relations))
aligned = reducer.fit_transform(a_embeddings, relations=relations)
print("ALIGNED", aligned)
for i,emb in enumerate(a_embedding_ids):
print("processing", emb, "umap", next_umap_number+i)
process_umap_embeddings(f"umap-{next_umap_number+i:03d}", aligned[i], emb)

print("done with aligned umap")
return

if init is not None and init != "":
print("loading umap", init)
initial_df = pd.read_parquet(os.path.join(umap_dir, f"{init}.parquet"))
initial = initial_df.to_numpy()
print("initial shape", initial.shape)
reducer = umap.UMAP(
init=initial,
n_neighbors=neighbors,
min_dist=min_dist,
metric='cosine',
random_state=42,
n_components=2,
verbose=True,
)
else:
reducer = umap.UMAP(
n_neighbors=neighbors,
min_dist=min_dist,
metric='cosine',
random_state=42,
n_components=2,
verbose=True,
)
print("reducing", embeddings.shape[0], "embeddings to 2 dimensions")
umap_embeddings = reducer.fit_transform(embeddings)

min_values = np.min(umap_embeddings, axis=0)
max_values = np.max(umap_embeddings, axis=0)

# Scale the embeddings to the range [0, 1]
umap_embeddings = (umap_embeddings - min_values) / (max_values - min_values)

# Scale the embeddings to the range [-1, 1]
umap_embeddings = 2 * umap_embeddings - 1

print("writing normalized umap")
# save umap embeddings to a parquet file with columns x,y
df = pd.DataFrame(umap_embeddings, columns=['x', 'y'])
output_file = os.path.join(umap_dir, f"{umap_name}.parquet")
df.to_parquet(output_file)
print("wrote", output_file)

# generate a scatterplot of the umap embeddings and save it to a file

fig, ax = plt.subplots(figsize=(14.22, 14.22)) # 1024px by 1024px at 72 dpi
point_size = calculate_point_size(umap_embeddings.shape[0])
print("POINT SIZE", point_size, "for", umap_embeddings.shape[0], "points")
plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=point_size, alpha=0.5)
plt.axis('off') # remove axis
plt.gca().set_position([0, 0, 1, 1]) # remove margins
plt.savefig(os.path.join(umap_dir, f"{umap_name}.png"))

# save a json file with the umap parameters
with open(os.path.join(umap_dir, f'{umap_name}.json'), 'w') as f:
json.dump({
"id": umap_name,
"embedding_id": embedding_id,
"neighbors": neighbors,
"min_dist": min_dist
}, f, indent=2)
f.close()
process_umap_embeddings(umap_id, umap_embeddings, embedding_id)

if save:
# save a pickle of the umap
with open(os.path.join(umap_dir, f'{umap_name}.pkl'), 'wb') as f:
with open(os.path.join(umap_dir, f'{umap_id}.pkl'), 'wb') as f:
pickle.dump(reducer, f)
print("done with", umap_name)
print("done with", umap_id)


if __name__ == "__main__":
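For context, here is a minimal runnable sketch of the AlignedUMAP call this script now makes, assuming two embeddings of the same 100 rows. The random arrays and their dimensions (384/768) are stand-ins for the real h5-loaded embeddings, not part of the commit.

import numpy as np
import umap

emb_a = np.random.rand(100, 384).astype(np.float32)
emb_b = np.random.rand(100, 768).astype(np.float32)

# One relations dict per consecutive pair of datasets; row j in emb_a
# corresponds to row j in emb_b, so the mapping is the identity.
relations = [{j: j for j in range(emb_a.shape[0])}]

reducer = umap.AlignedUMAP(
    n_neighbors=25,
    min_dist=0.1,
    metric="cosine",
    random_state=42,
    n_components=2,
)
aligned = reducer.fit_transform([emb_a, emb_b], relations=relations)

# Each entry is one 2D layout; the script min-max scales each to [-1, 1]
# and writes one umap-XXX.parquet per embedding.
for layout in aligned:
    lo, hi = layout.min(axis=0), layout.max(axis=0)
    scaled = 2 * (layout - lo) / (hi - lo) - 1
    print(scaled.shape, scaled.min(axis=0), scaled.max(axis=0))

Passing identity relations is what keeps the layouts comparable across embeddings: AlignedUMAP regularizes corresponding rows toward similar positions in each projection.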
5 changes: 3 additions & 2 deletions latentscope/server/jobs.py
@@ -207,10 +207,11 @@ def run_umap():
embedding_id = request.args.get('embedding_id')
neighbors = request.args.get('neighbors')
min_dist = request.args.get('min_dist')
print("run umap", dataset, embedding_id, neighbors, min_dist)
init = request.args.get('init')
print("run umap", dataset, embedding_id, neighbors, min_dist, init)

job_id = str(uuid.uuid4())
command = f'ls-umap {dataset} {embedding_id} {neighbors} {min_dist}'
command = f'ls-umap {dataset} {embedding_id} {neighbors} {min_dist} --init={init}'
threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
return jsonify({"job_id": job_id})

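The route reads its parameters from the query string and shells out to the ls-umap CLI; note that it forwards --init but not --align, so aligned runs go through the script directly. Below is a hypothetical client call as a sketch; the base URL and the dataset and embedding ids are assumptions for illustration.

import requests

resp = requests.get(
    "http://localhost:5001/jobs/umap",    # assumed base URL; path matches the frontend's `${apiUrl}/jobs/umap`
    params={
        "dataset": "my-dataset",          # hypothetical dataset id
        "embedding_id": "embedding-001",  # hypothetical embedding id
        "neighbors": 25,
        "min_dist": 0.075,
        "init": "umap-001",               # optional: existing UMAP to initialize from
    },
)
print(resp.json())  # {"job_id": "<uuid>"}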
29 changes: 25 additions & 4 deletions web/src/components/Setup/Umap.jsx
@@ -13,18 +13,21 @@ Umap.propTypes = {
}).isRequired,
umap: PropTypes.object,
embedding: PropTypes.object,
embeddings: PropTypes.array.isRequired,
clusters: PropTypes.array.isRequired,
onNew: PropTypes.func.isRequired,
onChange: PropTypes.func.isRequired,
};

// This component is responsible for the umaps state
// New umaps update the list
function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) {
function Umap({ dataset, umap, embedding, embeddings, clusters, onNew, onChange}) {
const [umapJob, setUmapJob] = useState(null);
const { startJob: startUmapJob } = useStartJobPolling(dataset, setUmapJob, `${apiUrl}/jobs/umap`);
const { startJob: deleteUmapJob } = useStartJobPolling(dataset, setUmapJob, `${apiUrl}/jobs/delete/umap`);

const [init, setInit] = useState("")

const [umaps, setUmaps] = useState([]);
function fetchUmaps(datasetId, callback) {
fetch(`${apiUrl}/datasets/${datasetId}/umaps`)
@@ -36,7 +39,7 @@ function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) {
url: `${apiUrl}/files/${datasetId}/umaps/${d.id}.png`,
}
})
callback(array.reverse())
callback(array)
});
}
useEffect(() => {
@@ -61,14 +64,19 @@ }, [umapJob, dataset, setUmaps, onNew]);
}
}, [umapJob, dataset, setUmaps, onNew]);


const handleChangeInit = useCallback((e) => {
setInit(e.target.value)
}, [setInit])

const handleNewUmap = useCallback((e) => {
e.preventDefault()
const form = e.target
const data = new FormData(form)
const neighbors = data.get('neighbors')
const min_dist = data.get('min_dist')
startUmapJob({embedding_id: embedding?.id, neighbors, min_dist})
}, [startUmapJob, embedding])
startUmapJob({embedding_id: embedding?.id, neighbors, min_dist, init})
}, [startUmapJob, embedding, init])

return (
<div className="dataset--umaps-new">
Expand All @@ -81,6 +89,19 @@ function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) {
Min Dist:
<input type="text" name="min_dist" defaultValue="0.1" disabled={!!umapJob} />
</label>
<label>
Initialize from UMAP:
<select name="init" disabled={!!umapJob} onChange={handleChangeInit}>
<option value="">None</option>
{umaps.map((um, index) => {
let emb = embeddings.find(d => um.embedding_id == d.id)
return (
<option key={index} value={um.id}>
{um.embedding_id} - {um.id} - {emb?.model_id} [{emb?.dimensions}]
</option>
)})}
</select>
</label>
<button type="submit" disabled={!!umapJob}>New UMAP</button>
</form>
<JobProgress job={umapJob} clearJob={()=> setUmapJob(null)}/>
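On mount the component fetches the dataset's existing UMAPs and offers them in the "Initialize from UMAP" select. A hedged sketch of that listing from the Python side: the field names mirror the JSON files umapper() writes, while the base URL and dataset id are assumptions.

import requests

umaps = requests.get("http://localhost:5001/datasets/my-dataset/umaps").json()
for um in umaps:
    # e.g. umap-001 embedding-001 25 0.075
    print(um["id"], um["embedding_id"], um["neighbors"], um["min_dist"])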
18 changes: 18 additions & 0 deletions web/src/pages/Setup.jsx
@@ -373,16 +373,33 @@ function Setup() {
const prevPointConfig = useRef()
const [drawPoints, setDrawPoints] = useState([]);
useEffect(() => {
// if cluster and umap are ready and haven't been drawn we update the draw points
// we also set hulls to empty so they can animate in
if(clusterIndices.length && umap && cluster && cluster.umap_id == umap.id && clusterIndices.cluster_id == cluster.id) {
const config = umap.id + cluster.id + clusterIndices.cluster_id
if(prevPointConfig.current !== config) {
fetch(`${apiUrl}/datasets/${datasetId}/umaps/${umap.id}/points`).then(response => response.json()).then(data => {
let pts = data.map((d,i) => [d.x, d.y, clusterIndices[i].cluster])
setDrawPoints(pts)
// TODO: this doesn't always work out in the right timing
// the other useEffect above should be tied to this one somehow
setHulls([])
})
prevPointConfig.current = config
}
} else if(umap && !cluster) {
const config = umap.id
if(prevPointConfig.current !== config) {
fetch(`${apiUrl}/datasets/${datasetId}/umaps/${umap.id}/points`).then(response => response.json()).then(data => {
let pts = data.map((d) => [d.x, d.y, -1])
setDrawPoints(pts)
setHulls([])
})
prevPointConfig.current = config
}
} else if(!umap && !cluster) {
setHulls([])
setDrawPoints([])
}
}, [clusterIndices, cluster, umap])

@@ -502,6 +519,7 @@
dataset={dataset}
umap={umap}
embedding={embedding}
embeddings={embeddings}
clusters={clusters}
onNew={handleNewUmaps} onChange={(ump) => dispatch({type: "SET_UMAP", payload: ump})} />
</Stage>
