diff --git a/latentscope/server/datasets.py b/latentscope/server/datasets.py index c77ae35..91ef6cd 100644 --- a/latentscope/server/datasets.py +++ b/latentscope/server/datasets.py @@ -1,6 +1,7 @@ import os import re import json +import fnmatch import pandas as pd from flask import Blueprint, jsonify, request @@ -31,14 +32,14 @@ def get_datasets(): """ Get all metadata files from the given a directory. """ -def scan_for_json_files(directory_path): +def scan_for_json_files(directory_path, match_pattern=r".*\.json$"): try: files = os.listdir(directory_path) except OSError as err: print('Unable to scan directory:', err) return jsonify({"error": "Unable to scan directory"}), 500 - json_files = [file for file in files if file.endswith('.json')] + json_files = [file for file in files if re.match(match_pattern, file)] json_files.sort() print("files", files) print("json", json_files) @@ -72,28 +73,15 @@ def update_dataset_meta(dataset): json.dump(json_contents, json_file) return jsonify(json_contents) - @datasets_bp.route('//embeddings', methods=['GET']) def get_dataset_embeddings(dataset): directory_path = os.path.join(DATA_DIR, dataset, "embeddings") # directory_path = os.path.join(DATA_DIR, dataset, "umaps") - print("dataset", dataset, directory_path) return scan_for_json_files(directory_path) - # print("dataset", dataset, directory_path) - # try: - # files = sorted(os.listdir(directory_path), key=lambda x: os.path.getmtime(os.path.join(directory_path, x)), reverse=True) - # except OSError as err: - # print('Unable to scan directory:', err) - # return jsonify({"error": "Unable to scan directory"}), 500 - - # npy_files = [file.replace(".npy", "") for file in files if file.endswith('.npy')] - # return jsonify(npy_files) - @datasets_bp.route('//umaps', methods=['GET']) def get_dataset_umaps(dataset): directory_path = os.path.join(DATA_DIR, dataset, "umaps") - print("dataset", dataset, directory_path) return scan_for_json_files(directory_path) @datasets_bp.route('//umaps/', methods=['GET']) @@ -112,8 +100,7 @@ def get_dataset_umap_points(dataset, umap): @datasets_bp.route('//clusters', methods=['GET']) def get_dataset_clusters(dataset): directory_path = os.path.join(DATA_DIR, dataset, "clusters") - print("dataset", dataset, directory_path) - return scan_for_json_files(directory_path) + return scan_for_json_files(directory_path, match_pattern=r"cluster-\d+\.json") @datasets_bp.route('//clusters/', methods=['GET']) def get_dataset_cluster(dataset, cluster): @@ -136,25 +123,25 @@ def get_dataset_cluster_indices(dataset, cluster): df = pd.read_parquet(file_path) return df.to_json(orient="records") -@datasets_bp.route('//clusters//labels/', methods=['GET']) -def get_dataset_cluster_labels(dataset, cluster, model): +@datasets_bp.route('//clusters//labels/', methods=['GET']) +def get_dataset_cluster_labels(dataset, cluster, id): # if model == "default": # return get_dataset_cluster_labels_default(dataset, cluster) - file_name = cluster + "-labels-" + model + ".parquet" + file_name = cluster + "-labels-" + id + ".parquet" file_path = os.path.join(DATA_DIR, dataset, "clusters", file_name) df = pd.read_parquet(file_path) df.reset_index(inplace=True) return df.to_json(orient="records") -@datasets_write_bp.route('//clusters//labels//label/', methods=['GET']) -def overwrite_dataset_cluster_label(dataset, cluster, model, index): +@datasets_write_bp.route('//clusters//labels//label/', methods=['GET']) +def overwrite_dataset_cluster_label(dataset, cluster, id, index): index = int(index) new_label = request.args.get('label') print("write label", index, new_label) if new_label is None: return jsonify({"error": "Missing 'label' in request data"}), 400 - file_name = cluster + "-labels-" + model + ".parquet" + file_name = cluster + "-labels-" + id + ".parquet" file_path = os.path.join(DATA_DIR, dataset, "clusters", file_name) try: df = pd.read_parquet(file_path) @@ -173,15 +160,16 @@ def overwrite_dataset_cluster_label(dataset, cluster, model, index): @datasets_bp.route('//clusters//labels_available', methods=['GET']) def get_dataset_cluster_labels_available(dataset, cluster): directory_path = os.path.join(DATA_DIR, dataset, "clusters") - try: - files = sorted(os.listdir(directory_path), key=lambda x: os.path.getmtime(os.path.join(directory_path, x)), reverse=True) - except OSError as err: - print('Unable to scan directory:', err) - return jsonify({"error": "Unable to scan directory"}), 500 + return scan_for_json_files(directory_path, match_pattern=rf"{cluster}-labels-.*\.json") + # try: + # files = sorted(os.listdir(directory_path), key=lambda x: os.path.getmtime(os.path.join(directory_path, x)), reverse=True) + # except OSError as err: + # print('Unable to scan directory:', err) + # return jsonify({"error": "Unable to scan directory"}), 500 - pattern = re.compile(r'^' + cluster + '-labels-(.*).parquet$') - model_names = [pattern.match(file).group(1) for file in files if pattern.match(file)] - return jsonify(model_names) + # pattern = re.compile(r'^' + cluster + '-labels-(.*).parquet$') + # model_names = [pattern.match(file).group(1) for file in files if pattern.match(file)] + # return jsonify(model_names) def get_next_scopes_number(dataset): diff --git a/web/package-lock.json b/web/package-lock.json index 5679361..2a19aab 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -10,6 +10,10 @@ "dependencies": { "d3-scale": "^4.0.2", "d3-scale-chromatic": "^3.0.0", + "d3-selection": "^3.0.0", + "d3-shape": "^3.2.0", + "d3-transition": "^3.0.1", + "flubber": "^0.4.2", "react": "^18.2.0", "react-dom": "^18.2.0", "react-router-dom": "^6.20.1", @@ -1812,6 +1816,11 @@ "node": ">= 0.8" } }, + "node_modules/commander": { + "version": "2.20.3", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", + "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==" + }, "node_modules/concat-map": { "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", @@ -1868,6 +1877,22 @@ "node": ">=12" } }, + "node_modules/d3-dispatch": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-dispatch/-/d3-dispatch-3.0.1.tgz", + "integrity": "sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-ease": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz", + "integrity": "sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==", + "engines": { + "node": ">=12" + } + }, "node_modules/d3-format": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/d3-format/-/d3-format-3.1.0.tgz", @@ -1887,6 +1912,19 @@ "node": ">=12" } }, + "node_modules/d3-path": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-path/-/d3-path-3.1.0.tgz", + "integrity": "sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-polygon": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/d3-polygon/-/d3-polygon-1.0.6.tgz", + "integrity": "sha512-k+RF7WvI08PC8reEoXa/w2nSg5AUMTi+peBD9cmFc+0ixHfbs4QmxxkarVal1IkVkgxVuk9JSHhJURHiyHKAuQ==" + }, "node_modules/d3-scale": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/d3-scale/-/d3-scale-4.0.2.tgz", @@ -1914,6 +1952,25 @@ "node": ">=12" } }, + "node_modules/d3-selection": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", + "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-shape": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz", + "integrity": "sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==", + "dependencies": { + "d3-path": "^3.1.0" + }, + "engines": { + "node": ">=12" + } + }, "node_modules/d3-time": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/d3-time/-/d3-time-3.1.0.tgz", @@ -1936,6 +1993,32 @@ "node": ">=12" } }, + "node_modules/d3-timer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-timer/-/d3-timer-3.0.1.tgz", + "integrity": "sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-transition": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-transition/-/d3-transition-3.0.1.tgz", + "integrity": "sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==", + "dependencies": { + "d3-color": "1 - 3", + "d3-dispatch": "1 - 3", + "d3-ease": "1 - 3", + "d3-interpolate": "1 - 3", + "d3-timer": "1 - 3" + }, + "engines": { + "node": ">=12" + }, + "peerDependencies": { + "d3-selection": "2 - 3" + } + }, "node_modules/dashdash": { "version": "1.14.1", "resolved": "https://registry.npmjs.org/dashdash/-/dashdash-1.14.1.tgz", @@ -2033,6 +2116,11 @@ "gl-matrix": "^3.3.0" } }, + "node_modules/earcut": { + "version": "2.2.4", + "resolved": "https://registry.npmjs.org/earcut/-/earcut-2.2.4.tgz", + "integrity": "sha512-/pjZsA1b4RPHbeWZQn66SWS8nZZWLQQ23oE3Eam7aroEFGEvwKAsJfZ9ytiEMycfzXWpca4FA9QIOehf7PocBQ==" + }, "node_modules/ecc-jsbn": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/ecc-jsbn/-/ecc-jsbn-0.1.2.tgz", @@ -2606,6 +2694,24 @@ "integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==", "dev": true }, + "node_modules/flubber": { + "version": "0.4.2", + "resolved": "https://registry.npmjs.org/flubber/-/flubber-0.4.2.tgz", + "integrity": "sha512-79RkJe3rA4nvRCVc2uXjj7U/BAUq84TS3KHn6c0Hr9K64vhj83ZNLUziNx4pJoBumSPhOl5VjH+Z0uhi+eE8Uw==", + "dependencies": { + "d3-array": "^1.2.0", + "d3-polygon": "^1.0.3", + "earcut": "^2.1.1", + "svg-path-properties": "^0.2.1", + "svgpath": "^2.2.1", + "topojson-client": "^3.0.0" + } + }, + "node_modules/flubber/node_modules/d3-array": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/d3-array/-/d3-array-1.2.4.tgz", + "integrity": "sha512-KHW6M86R+FUPYGb3R5XiYjXPq7VzwxZ22buHhAEVG5ztoEcZZMLov530mmccaqA1GghZArjQV46fuc8kUqhhHw==" + }, "node_modules/for-each": { "version": "0.3.3", "resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.3.tgz", @@ -4471,6 +4577,19 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/svg-path-properties": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/svg-path-properties/-/svg-path-properties-0.2.2.tgz", + "integrity": "sha512-GmrB+b6woz6CCdQe6w1GHs/1lt25l7SR5hmhF8jRdarpv/OgjLyuQygLu1makJapixeb1aQhP/Oa1iKi93o/aQ==" + }, + "node_modules/svgpath": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/svgpath/-/svgpath-2.6.0.tgz", + "integrity": "sha512-OIWR6bKzXvdXYyO4DK/UWa1VA1JeKq8E+0ug2DG98Y/vOmMpfZNj+TIG988HjfYSqtcy/hFOtZq/n/j5GSESNg==", + "funding": { + "url": "https://github.com/fontello/svg2ttf?sponsor=1" + } + }, "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", @@ -4486,6 +4605,19 @@ "node": ">=4" } }, + "node_modules/topojson-client": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/topojson-client/-/topojson-client-3.1.0.tgz", + "integrity": "sha512-605uxS6bcYxGXw9qi62XyrV6Q3xwbndjachmNxu8HWTtVPxZfEJN9fd/SZS1Q54Sn2y0TMyMxFj/cJINqGHrKw==", + "dependencies": { + "commander": "2" + }, + "bin": { + "topo2geo": "bin/topo2geo", + "topomerge": "bin/topomerge", + "topoquantize": "bin/topoquantize" + } + }, "node_modules/tough-cookie": { "version": "2.5.0", "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz", diff --git a/web/package.json b/web/package.json index f82f9c7..7cbde44 100644 --- a/web/package.json +++ b/web/package.json @@ -13,6 +13,10 @@ "dependencies": { "d3-scale": "^4.0.2", "d3-scale-chromatic": "^3.0.0", + "d3-selection": "^3.0.0", + "d3-shape": "^3.2.0", + "d3-transition": "^3.0.1", + "flubber": "^0.4.2", "react": "^18.2.0", "react-dom": "^18.2.0", "react-router-dom": "^6.20.1", diff --git a/web/src/components/HullPlot.jsx b/web/src/components/HullPlot.jsx index 3a19f08..855b34e 100644 --- a/web/src/components/HullPlot.jsx +++ b/web/src/components/HullPlot.jsx @@ -1,67 +1,243 @@ import React, { useEffect, useRef } from 'react'; import { scaleLinear } from 'd3-scale'; +import { line, curveLinearClosed, curveCatmullRomClosed } from 'd3-shape'; +import { select } from 'd3-selection'; +import { transition } from 'd3-transition'; +import { easeExpOut, easeExpIn, easeCubicInOut} from 'd3-ease'; +import { interpolate } from 'flubber'; import "./HullPlot.css" -const HullPlot = ({ - points, + +const HullPlot = ({ + // points, hulls, fill, stroke, - size, + delay = 0, + duration = 2000, + strokeWidth, symbol, - xDomain, - yDomain, - width, + xDomain, + yDomain, + width, height }) => { - const container = useRef(); - + const svgRef = useRef(); + const prevPoints = useRef(); + const prevHulls = useRef(); + const prevMod = useRef(); + useEffect(() => { - if(xDomain && yDomain) { - const xScale = scaleLinear() - .domain(xDomain) - .range([0, width]) - const yScale = scaleLinear() - .domain(yDomain) - .range([height, 0]) - - const zScale = (t) => t/(.1 + xDomain[1] - xDomain[0]) - const canvas = container.current - const ctx = canvas.getContext('2d') - ctx.clearRect(0, 0, width, height) - ctx.fillStyle = fill - ctx.strokeStyle = stroke - ctx.font = `${zScale(size)}px monospace` - ctx.globalAlpha = 0.75 - let rw = zScale(size) - if(!hulls.length || !points.length) return - hulls.forEach(hull => { - // a hull is a list of indices into points - if(!hull) return; - ctx.beginPath() - hull.forEach((index, i) => { - if(i === 0) { - ctx.moveTo(xScale(points[index][0]), yScale(points[index][1])) - } else { - ctx.lineTo(xScale(points[index][0]), yScale(points[index][1])) - } - }) - ctx.lineTo(xScale(points[hull[0]][0]), yScale(points[hull[0]][1])) - if(fill) - ctx.fill() - if(stroke) - ctx.stroke() + if (!xDomain || !yDomain || !hulls.length) return; + + // console.log("NO PRE HULLS CURRENT", !prevHulls.current) + const hullsChanged = !prevHulls.current || (JSON.stringify(hulls[0]) !== JSON.stringify(prevHulls.current[0])) + // const pointsChanged = !prevPoints.current || (JSON.stringify(points[0]) !== JSON.stringify(prevPoints.current[0])) + + // console.log("HULLS CHANGED", hullsChanged) + if(!hullsChanged) return; + // if(!hullsChanged || !pointsChanged) { + // return + // } + + const svg = select(svgRef.current); + // Calculate scale factors + // The scale factors are calculated to fit the -1 to 1 domain within the current xDomain and yDomain + const xScaleFactor = width / (xDomain[1] - xDomain[0]); + const yScaleFactor = height / (yDomain[1] - yDomain[0]); + + // Calculate translation to center the drawing at (0,0) + // This centers the view at (0,0) and accounts for the SVG's inverted y-axis + const xOffset = width / 2 - (xScaleFactor * (xDomain[1] + xDomain[0]) / 2); + const yOffset = height / 2 + (yScaleFactor * (yDomain[1] + yDomain[0]) / 2); + + // Calculate a scaled stroke width + const scaledStrokeWidth = strokeWidth / Math.sqrt(xScaleFactor * yScaleFactor); + + const g = svg.select("g.hull-container"); + g.attr('transform', `translate(${xOffset}, ${yOffset}) scale(${xScaleFactor}, ${yScaleFactor})`); + + const draw = line() + .x(d => d[0]) + .y(d => -d[1]) + // .curve(curveCatmullRomClosed); + .curve(curveLinearClosed); + + let sel = g.selectAll("path.hull") + .data(hulls) + + const exit = sel.exit() + .transition() + .duration(duration) + .delay(delay) + .ease(easeExpOut) + .style("opacity", 0) + .remove() + + const enter = sel.enter() + .append("path") + .classed("hull", true) + .attr("d", draw) + .style("fill", fill) + .style("stroke", stroke) + .style("stroke-width", scaledStrokeWidth) + .style("opacity", 0.) + .transition() + .delay(delay + 100) + .duration(duration - 100) + .ease(easeExpOut) + .style("opacity", 0.75) + + const update = sel + .transition() + .duration(duration) + .delay(delay) + .ease(easeCubicInOut) + .style("opacity", 0.75) + // .attr("d", draw) + .attrTween("d", function(d,i) { + // console.log("d,i", d, i) + // console.log(d.hull, prevHulls.current.find(h => h.index == d.index).hull) + const prev = prevHulls.current ? prevHulls.current[i] : null + // console.log(d, prev) + if(!prev) return () => draw(d) + const inter = interpolate( + draw(prev), + draw(d) + ); + return function(t) { + return inter(t) + } }) - } - }, [points, hulls, fill, stroke, size, xDomain, yDomain, width, height]) - return ; + setTimeout(() => { + prevHulls.current = hulls + // prevHulls.current = mod + // prevPoints.current = points + }, duration) + + }, [hulls]) + + // This effect will rerender instantly when the fill, stroke, strokeWidth, or domain changes + useEffect(() => { + if (!xDomain || !yDomain || !hulls.length ) return; + const svg = select(svgRef.current); + + // Calculate scale factors + // The scale factors are calculated to fit the -1 to 1 domain within the current xDomain and yDomain + const xScaleFactor = width / (xDomain[1] - xDomain[0]); + const yScaleFactor = height / (yDomain[1] - yDomain[0]); + + // Calculate translation to center the drawing at (0,0) + // This centers the view at (0,0) and accounts for the SVG's inverted y-axis + const xOffset = width / 2 - (xScaleFactor * (xDomain[1] + xDomain[0]) / 2); + const yOffset = height / 2 + (yScaleFactor * (yDomain[1] + yDomain[0]) / 2); + + // Calculate a scaled stroke width + const scaledStrokeWidth = strokeWidth / Math.sqrt(xScaleFactor * yScaleFactor); + + const g = svg.select("g.hull-container"); + g.attr('transform', `translate(${xOffset}, ${yOffset}) scale(${xScaleFactor}, ${yScaleFactor})`); + + const draw = line() + .x(d => d[0]) + .y(d => -d[1]) + // .curve(curveCatmullRomClosed); + .curve(curveLinearClosed); + + // Draw hulls + let sel = g.selectAll("path.hull") + .data(hulls) + sel.enter() + .append("path") + .classed("hull", true) + .attr("d", draw) + .style("fill", fill) + .style("stroke", stroke) + .attr("stroke-width", scaledStrokeWidth) + .style("opacity", 0.75) + + sel.exit().remove() + + sel.attr("d", draw) + + }, [fill, stroke, strokeWidth, xDomain, yDomain, width, height]) + + return ( + + ); }; export default HullPlot; + + + +// const HullPlotCanvas = ({ +// points, +// hulls, +// fill, +// stroke, +// strokeWidth, +// symbol, +// xDomain, +// yDomain, +// width, +// height +// }) => { +// const container = useRef(); + +// useEffect(() => { +// if(xDomain && yDomain) { +// const xScale = scaleLinear() +// .domain(xDomain) +// .range([0, width]) +// const yScale = scaleLinear() +// .domain(yDomain) +// .range([height, 0]) + +// const zScale = (t) => t/(.1 + xDomain[1] - xDomain[0]) +// const canvas = container.current +// const ctx = canvas.getContext('2d') +// ctx.clearRect(0, 0, width, height) +// ctx.fillStyle = fill +// ctx.strokeStyle = stroke +// ctx.font = `${zScale(strokeWidth)}px monospace` +// ctx.globalAlpha = 0.75 +// let rw = zScale(strokeWidth) +// if(!hulls.length || !points.length) return +// hulls.forEach(hull => { +// // a hull is a list of indices into points +// if(!hull) return; +// ctx.beginPath() +// hull.forEach((index, i) => { +// if(i === 0) { +// ctx.moveTo(xScale(points[index][0]), yScale(points[index][1])) +// } else { +// ctx.lineTo(xScale(points[index][0]), yScale(points[index][1])) +// } +// }) +// ctx.lineTo(xScale(points[hull[0]][0]), yScale(points[hull[0]][1])) +// if(fill) +// ctx.fill() +// if(stroke) +// ctx.stroke() +// }) +// } + +// }, [points, hulls, fill, stroke, strokeWidth, xDomain, yDomain, width, height]) + +// return ; +// }; + +// export default HullPlot; diff --git a/web/src/components/IndexDataTable.jsx b/web/src/components/IndexDataTable.jsx index 5d6fe2a..1a672bf 100644 --- a/web/src/components/IndexDataTable.jsx +++ b/web/src/components/IndexDataTable.jsx @@ -51,7 +51,7 @@ function IndexDataTable({dataset, indices, distances = [], clusterIndices = [], useEffect(() => { if(indices && indices.length) { - console.log("refetching hydrate") + // console.log("refetching hydrate") hydrateIndices(indices) } }, [indices]) diff --git a/web/src/components/Scatter.jsx b/web/src/components/Scatter.jsx index 66887e6..26ab9f0 100644 --- a/web/src/components/Scatter.jsx +++ b/web/src/components/Scatter.jsx @@ -10,10 +10,11 @@ import styles from "./Scatter.module.css" import PropTypes from 'prop-types'; ScatterPlot.propTypes = { - points: PropTypes.array.isRequired, - colors: PropTypes.array, + points: PropTypes.array.isRequired, // an array of [x,y] points + colors: PropTypes.array, // an array of integer values width: PropTypes.number.isRequired, height: PropTypes.number.isRequired, + duration: PropTypes.number, onScatter: PropTypes.func, onView: PropTypes.func, onSelect: PropTypes.func, @@ -46,14 +47,16 @@ const calculatePointOpacity = (numPoints) => { function ScatterPlot ({ points, - colors, + categories, width, height, + duration = 0, onScatter, onView, onSelect, onHover, }) { + const container = useRef(); const xDomain = useRef([-1, 1]); const yDomain = useRef([-1, 1]); @@ -62,74 +65,102 @@ function ScatterPlot ({ const yScale = scaleLinear() .domain(yDomain.current) + const scatterplotRef = useRef(null); + // setup the scatterplot on first render useEffect(() => { - if(points && points.length){ + const scatterSettings = { + canvas: container.current, + width, + height, + pointColorHover: [0.1, 0.1, 0.1, 0.5], + xScale, + yScale, + } + const scatterplot = createScatterplot(scatterSettings); + scatterplotRef.current = scatterplot; + + onView && onView(xScale, yScale) + scatterplot.subscribe( + "view", + ({ camera, view, xScale: xs, yScale: ys }) => { + xDomain.current = xs.domain(); + yDomain.current = ys.domain(); + onView && onView(xDomain.current, yDomain.current) + } + ); + scatterplot.subscribe("select", ({ points }) => { + onSelect && onSelect(points) + }); + scatterplot.subscribe("deselect", () => { + onSelect && onSelect([]) + }); + scatterplot.subscribe("pointOver", (pointIndex) => { + onHover && onHover(pointIndex) + }); + scatterplot.subscribe("pointOut", () => { + onHover && onHover(null) + }); + + onScatter && onScatter(scatterplot) + + return () => { + scatterplotRef.current = null; + scatterplot.destroy(); + }; + }, [width, height, onScatter, onView, onSelect, onHover]) + + const prevPointsRef = useRef(); + useEffect(() => { + const scatterplot = scatterplotRef.current; + const prevPoints = prevPointsRef.current; + if(scatterplot && points && points.length){ const pointSize = calculatePointSize(points.length); const opacity = calculatePointOpacity(points.length); // console.log("point size", pointSize, opacity) let pointColor = [250/255, 128/255, 114/255, 1] //salmon - let drawPoints = points - if(colors?.length) { - drawPoints = points.map((p, i) => { - return [p[0], p[1], colors[i]] - }) - const uniques = groups(colors, d => d).map(d => d[0]).sort((a,b) => a - b) + + // let drawPoints = points + let categories = points[0].length === 3 ? true : false + if(categories) { + // drawPoints = points.map((p, i) => { + // return [p[0], p[1], categories[i]] + // }) + const uniques = groups(points.map(d => d[2]), d => d).map(d => d[0]).sort((a,b) => a - b) + // TODO: colors should already be chosen before passing in here // const colorScale = scaleSequential(interpolateViridis) // const colorScale = scaleSequential(interpolateTurbo) const colorScale = scaleSequential(interpolateCool) .domain(extent(uniques).reverse()); pointColor = uniques.map(u => rgb(colorScale(u)).hex()) } - const scatterSettings = { - canvas: container.current, - width, - height, - pointSize, - opacity, - pointColor, - pointColorHover: [0.1, 0.1, 0.1, 0.5], - xScale, - yScale, - } - if(colors?.length){ - scatterSettings.colorBy = 'valueA' - } - const scatterplot = createScatterplot(scatterSettings); - - scatterplot.draw(drawPoints); - - onView && onView(xScale, yScale) - scatterplot.subscribe( - "view", - ({ camera, view, xScale: xs, yScale: ys }) => { - xDomain.current = xs.domain(); - yDomain.current = ys.domain(); - onView && onView(xDomain.current, yDomain.current) + scatterplot.set({ + opacity: opacity, + pointSize: pointSize, + }) + if(categories){ + scatterplot.set({colorBy: 'valueA'}); } - ); - scatterplot.subscribe("select", ({ points }) => { - onSelect && onSelect(points) - }); - scatterplot.subscribe("deselect", () => { - onSelect && onSelect([]) - }); - scatterplot.subscribe("pointOver", (pointIndex) => { - onHover && onHover(pointIndex) - }); - scatterplot.subscribe("pointOut", () => { - onHover && onHover(null) - }); - - // TODO: this may not be proper React - onScatter && onScatter(scatterplot) - - return () => { - scatterplot.destroy(); - }; + if(prevPoints && prevPoints.length === points.length) { + // console.log("transitioning scatterplot") + scatterplot.draw(points, { transition: true, transitionDuration: duration}).then(() => { + // don't color till after + scatterplot.set({ + pointColor: pointColor, + }) + scatterplot.draw(points, { transition: false }); + }) + } else { + // console.log("fresh draw scatterplot") + scatterplot.set({ + pointColor: pointColor, + }) + scatterplot.draw(points, { transition: false }); + } + prevPointsRef.current = points; } - }, [points, colors, width, height]); + }, [points, categories, width, height]); return ; } diff --git a/web/src/components/Setup/ClusterLabels.jsx b/web/src/components/Setup/ClusterLabels.jsx index 9184404..591d8b8 100644 --- a/web/src/components/Setup/ClusterLabels.jsx +++ b/web/src/components/Setup/ClusterLabels.jsx @@ -16,14 +16,14 @@ ClusterLabels.propTypes = { selectedLabelId: PropTypes.string, onChange: PropTypes.func.isRequired, onLabels: PropTypes.func, - onLabelIds: PropTypes.func, + onLabelSets: PropTypes.func, onHoverLabel: PropTypes.func, onClickLabel: PropTypes.func, }; // This component is responsible for the embeddings state // New embeddings update the list -function ClusterLabels({ dataset, cluster, selectedLabelId, onChange, onLabels, onLabelIds, onHoverLabel, onClickLabel}) { +function ClusterLabels({ dataset, cluster, selectedLabelId, onChange, onLabels, onLabelSets, onHoverLabel, onClickLabel}) { const [clusterLabelsJob, setClusterLabelsJob] = useState(null); const { startJob: startClusterLabelsJob } = useStartJobPolling(dataset, setClusterLabelsJob, `${apiUrl}/jobs/cluster_label`); const { startJob: rerunClusterLabelsJob } = useStartJobPolling(dataset, setClusterLabelsJob, `${apiUrl}/jobs/rerun`); @@ -41,15 +41,16 @@ function ClusterLabels({ dataset, cluster, selectedLabelId, onChange, onLabels, }, []); // the models used to label a particular cluster (the ones the user has run) - const [clusterLabelModels, setClusterLabelModels] = useState([]); + const [clusterLabelSets, setClusterLabelSets] = useState([]); // the actual labels for the given cluster const [clusterLabels, setClusterLabels] = useState([]); useEffect(() => { - console.log("in cluster labels", dataset, cluster, selectedLabelId) if(dataset && cluster && selectedLabelId) { - fetch(`${apiUrl}/datasets/${dataset.id}/clusters/${cluster.id}/labels/${selectedLabelId}`) + const id = selectedLabelId.split("-")[3] || selectedLabelId + fetch(`${apiUrl}/datasets/${dataset.id}/clusters/${cluster.id}/labels/${id}`) .then(response => response.json()) .then(data => { + data.cluster_id = cluster.id setClusterLabels(data) }).catch(err => { console.log(err) @@ -58,39 +59,50 @@ function ClusterLabels({ dataset, cluster, selectedLabelId, onChange, onLabels, } else { setClusterLabels([]) } - }, [selectedLabelId, setClusterLabels, dataset, cluster, clusterLabelModels]) + }, [selectedLabelId, setClusterLabels, dataset, cluster, clusterLabelSets]) useEffect(() => { if(cluster) { fetch(`${apiUrl}/datasets/${dataset.id}/clusters/${cluster.id}/labels_available`) .then(response => response.json()) .then(data => { - console.log("cluster changed, set label models fetched", cluster.id, data, clusterLabelsJob) + // console.log("cluster changed, labels available", cluster.id, data) + const labelsAvailable = data.filter(d => d.cluster_id == cluster.id) + let lbl; if(clusterLabelsJob) { - let lbl; if(clusterLabelsJob?.job_name == "label"){ - let label_id = clusterLabelsJob.run_id.split("-")[3] - lbl = data.find(d => d == label_id) - console.log("label_id", label_id, lbl) + let label_id = clusterLabelsJob.run_id//.split("-")[3] + let found = labelsAvailable.find(d => d.id == label_id) + if(found) lbl = found } else if(clusterLabelsJob.job_name == "rm") { lbl = data[0] } - onLabelIds(data.map(id => ({cluster_id: cluster.id, id: id})), lbl) // onChange(lbl) - } else { - onLabelIds(data.map(id => ({cluster_id: cluster.id, id: id}))) + } else if(selectedLabelId){ + if(selectedLabelId == "default" && labelsAvailable[0]) { + lbl = labelsAvailable[0] + } else if(selectedLabelId.indexOf(cluster.id) < 0 && labelsAvailable[0]) { + lbl = labelsAvailable[0] + } else { + lbl = labelsAvailable.find(d => d.id == selectedLabelId) || { id: "default" } + } + } else if(labelsAvailable[0]) { + lbl = labelsAvailable[0] + } else { + lbl = { id: "default" } } - setClusterLabelModels(data) + onLabelSets(labelsAvailable, lbl) + setClusterLabelSets(labelsAvailable) }).catch(err => { console.log(err) - setClusterLabelModels([]) - onLabelIds([]) + setClusterLabelSets([]) + onLabelSets([]) }) } else { - setClusterLabelModels([]) - onLabelIds([]) + setClusterLabelSets([]) + onLabelSets([]) } - }, [dataset, cluster, clusterLabelsJob, setClusterLabelModels, onLabelIds]) + }, [dataset, cluster, clusterLabelsJob, setClusterLabelSets, onLabelSets]) useEffect(() => { if(clusterLabels?.length) { @@ -123,7 +135,7 @@ function ClusterLabels({ dataset, cluster, selectedLabelId, onChange, onLabels,