Skip to content

Commit

Permalink
wip
Browse files (browse the repository at this point in the history)
  • Loading branch information
thomasdavis committed Dec 19, 2024
1 parent 4516051 commit 8c25d9b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 153 deletions.
256 changes: 105 additions & 151 deletions apps/registry/app/job-similarity/page.js
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ const algorithms = {
},
pathfinder: {
name: 'Pathfinder Network',
compute: (nodes, r = 2, q = 2) => {
compute: (nodes, r = 2) => {
const n = nodes.length;
const distances = Array(n).fill().map(() => Array(n).fill(Infinity));

Expand Down Expand Up @@ -463,6 +463,7 @@ const Header = memo(() => (
</div>
</div>
));
Header.displayName = 'Header';

const Controls = memo(({ dataSource, setDataSource, algorithm, setAlgorithm }) => (
<div className="prose max-w-3xl mx-auto mb-8">
Expand Down Expand Up @@ -499,6 +500,7 @@ const Controls = memo(({ dataSource, setDataSource, algorithm, setAlgorithm }) =
</div>
</div>
));
Controls.displayName = 'Controls';

const GraphContainer = ({ dataSource, algorithm }) => {
const [graphData, setGraphData] = useState(null);
Expand All @@ -508,170 +510,121 @@ const GraphContainer = ({ dataSource, algorithm }) => {
const [highlightLinks, setHighlightLinks] = useState(new Set());
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
const [edges, setEdges] = useState([]);

// Fetch data when dataSource changes
useEffect(() => {
const fetchData = async () => {
setLoading(true);
setError(null);
try {
const endpoint = dataSource === 'jobs' ? '/api/job-similarity' : '/api/similarity';
const response = await fetch(endpoint);
if (!response.ok) {
throw new Error('Failed to fetch data');
}
const jsonData = await response.json();

// Filter out items without valid embeddings
const validData = jsonData.filter(item => {
const embedding = dataSource === 'jobs' ?
item.embedding :
(typeof item.embedding === 'string' ? JSON.parse(item.embedding) : item.embedding);
return Array.isArray(embedding) && embedding.length > 0;
});

// Group similar items
const groups = {};
// Hover handler for graph nodes: highlights the hovered node together with
// every link that touches it, or clears all highlights when `node` is falsy.
const handleNodeHover = useCallback((node) => {
  if (!node) {
    setHighlightNodes(new Set());
    setHighlightLinks(new Set());
    setHoverNode(null);
    return;
  }

  // A link endpoint may be a plain node id or a node object, depending on
  // whether the graph library has already resolved it (assumption — confirm
  // against the renderer in use). Comparing by id works for both forms.
  const endpointId = (end) => (end !== null && typeof end === 'object' ? end.id : end);
  const touchingLinks = edges.filter(
    (link) => endpointId(link.source) === node.id || endpointId(link.target) === node.id
  );

  setHighlightNodes(new Set([node]));
  setHighlightLinks(new Set(touchingLinks));
  setHoverNode(node);
}, [edges]);

validData.forEach(item => {
const label = dataSource === 'jobs'
? item.title
: (item.position || 'Unknown Position');

if (!groups[label]) {
groups[label] = [];
}
groups[label].push(item);
// Click handler for graph nodes: opens the page of the first item grouped
// under the clicked node in a new tab. Does nothing when there is no node or
// the node carries no identifiers.
const handleNodeClick = useCallback((node) => {
  if (!node) return;

  // Nodes are keyed by `id` and carry their member identifiers in `uuids`;
  // they have no `title` or `url` field, so look the node up by id and build
  // the URL from the first identifier. Fall back to the clicked object itself
  // if it is not found in rawNodes.
  const nodeData = rawNodes?.find((n) => n.id === node.id) ?? node;
  const [firstId] = nodeData.uuids ?? [];
  if (firstId == null) return;

  const baseUrl = dataSource === 'jobs' ? '/jobs/' : '/';
  window.open(`${baseUrl}${firstId}`, '_blank');
}, [rawNodes, dataSource]);

const processData = useCallback((data) => {
// Filter out items without valid embeddings
const validData = data.filter(item => {
const embedding = dataSource === 'jobs' ?
item.embedding :
(typeof item.embedding === 'string' ? JSON.parse(item.embedding) : item.embedding);
return Array.isArray(embedding) && embedding.length > 0;
});

// Group similar items
const groups = {};

validData.forEach(item => {
const label = dataSource === 'jobs'
? item.title
: (item.position || 'Unknown Position');

if (!groups[label]) {
groups[label] = [];
}
groups[label].push(item);
});

// Create nodes with normalized embeddings
const nodes = Object.entries(groups)
.map(([label, items], index) => {
const embeddings = items.map(item => {
if (dataSource === 'jobs') return item.embedding;
return typeof item.embedding === 'string' ?
JSON.parse(item.embedding) : item.embedding;
});

// Create nodes with normalized embeddings
const nodes = Object.entries(groups)
.map(([label, items], index) => {
const embeddings = items.map(item => {
if (dataSource === 'jobs') return item.embedding;
return typeof item.embedding === 'string' ?
JSON.parse(item.embedding) : item.embedding;
});

const normalizedEmbeddings = embeddings
.map(emb => normalizeVector(emb))
.filter(emb => emb !== null);

if (normalizedEmbeddings.length === 0) return null;

const avgEmbedding = getAverageEmbedding(normalizedEmbeddings);
if (!avgEmbedding) return null;

return {
id: label,
group: index,
size: Math.log(items.length + 1) * 3,
count: items.length,
uuids: items.map(item => dataSource === 'jobs' ? item.uuid : item.username),
usernames: dataSource === 'jobs' ? null : [...new Set(items.map(item => item.username))],
avgEmbedding,
color: `hsl(${Math.random() * 360}, 70%, 50%)`,
companies: dataSource === 'jobs' ? [...new Set(items.map(item => item.company || 'Unknown Company'))] : null,
countryCodes: dataSource === 'jobs' ? [...new Set(items.map(item => item.countryCode || 'Unknown Location'))] : null
};
})
.filter(node => node !== null);

if (nodes.length === 0) {
throw new Error('No valid data found with embeddings');
}

setRawNodes(nodes);
} catch (err) {
console.error('Error in fetchData:', err);
setError(err.message);
} finally {
setLoading(false);
}
};
const normalizedEmbeddings = embeddings
.map(emb => normalizeVector(emb))
.filter(emb => emb !== null);

if (normalizedEmbeddings.length === 0) return null;

const avgEmbedding = getAverageEmbedding(normalizedEmbeddings);
if (!avgEmbedding) return null;

return {
id: label,
group: index,
size: Math.log(items.length + 1) * 3,
count: items.length,
uuids: items.map(item => dataSource === 'jobs' ? item.uuid : item.username),
usernames: dataSource === 'jobs' ? null : [...new Set(items.map(item => item.username))],
avgEmbedding,
color: `hsl(${Math.random() * 360}, 70%, 50%)`,
companies: dataSource === 'jobs' ? [...new Set(items.map(item => item.company || 'Unknown Company'))] : null,
countryCodes: dataSource === 'jobs' ? [...new Set(items.map(item => item.countryCode || 'Unknown Location'))] : null
};
})
.filter(node => node !== null);

if (nodes.length === 0) {
throw new Error('No valid data found with embeddings');
}

fetchData();
return nodes;
}, [dataSource]);

// Compute links when algorithm changes or when we have new nodes
useEffect(() => {
if (!rawNodes) return;

const links = [];
const threshold = 0.7; // Similarity threshold

// Different algorithms for computing links
if (algorithm === 'mst') {
// Kruskal's algorithm for MST
const parent = new Array(rawNodes.length).fill(0).map((_, i) => i);

function find(x) {
if (parent[x] !== x) parent[x] = find(parent[x]);
return parent[x];
}

function union(x, y) {
parent[find(x)] = find(y);
}

// Create all possible edges with weights
const edges = [];
for (let i = 0; i < rawNodes.length; i++) {
for (let j = i + 1; j < rawNodes.length; j++) {
const similarity = cosineSimilarity(rawNodes[i].avgEmbedding, rawNodes[j].avgEmbedding);
if (similarity > threshold) {
edges.push({ i, j, similarity });
}
}
// Loads the similarity payload for the current data source, turns it into
// graph nodes with processData and stores the result in rawNodes.
// The loading flag is raised for the duration of the request; on any failure
// the error is logged and its message is stored in the error state.
const fetchData = useCallback(async () => {
  setLoading(true);
  setError(null);

  // 'jobs' is served from /api/job-similarity, everything else from /api/similarity.
  const routePrefix = dataSource === 'jobs' ? 'job-' : '';
  const requestUrl = `/api/${routePrefix}similarity?limit=250&algorithm=${algorithm}`;

  try {
    const res = await fetch(requestUrl);
    if (res.ok === false || !res.ok) {
      throw new Error('Failed to fetch data');
    }
    const payload = await res.json();
    setRawNodes(processData(payload));
  } catch (err) {
    console.error('Error fetching data:', err);
    setError(err.message);
  } finally {
    // Always drop the loading flag, whether the request succeeded or not.
    setLoading(false);
  }
}, [dataSource, algorithm, processData]);

// Sort edges by similarity (descending)
edges.sort((a, b) => b.similarity - a.similarity);
const processLinks = useCallback(() => {
if (!rawNodes) return;

// Build MST
edges.forEach(({ i, j, similarity }) => {
if (find(i) !== find(j)) {
union(i, j);
links.push({
source: rawNodes[i].id,
target: rawNodes[j].id,
value: similarity
});
}
});
} else if (algorithm === 'threshold') {
// Simple threshold algorithm
for (let i = 0; i < rawNodes.length; i++) {
for (let j = i + 1; j < rawNodes.length; j++) {
const similarity = cosineSimilarity(rawNodes[i].avgEmbedding, rawNodes[j].avgEmbedding);
if (similarity > threshold) {
links.push({
source: rawNodes[i].id,
target: rawNodes[j].id,
value: similarity
});
}
}
}
}
const { compute } = algorithms[algorithm];
const links = compute(rawNodes);

setGraphData({ nodes: rawNodes, links });
setEdges(links);
}, [rawNodes, algorithm]);

const handleNodeHover = useCallback(node => {
setHighlightNodes(new Set(node ? [node] : []));
setHighlightLinks(new Set(graphData?.links.filter(link =>
link.source.id === node?.id || link.target.id === node?.id
) || []));
setHoverNode(node || null);
}, [graphData]);
// Run the fetch after mount and again whenever the memoized fetchData
// callback is recreated (i.e. when one of its dependencies changes).
useEffect(() => {
fetchData();
}, [fetchData]);

const handleNodeClick = useCallback(node => {
if (node.uuids && node.uuids.length > 0) {
const baseUrl = dataSource === 'jobs' ? '/jobs/' : '/';
window.open(`${baseUrl}${node.uuids[0]}`, '_blank');
}
}, [dataSource]);
// Recompute the graph links whenever the memoized processLinks callback is
// recreated (i.e. when one of its dependencies changes).
useEffect(() => {
processLinks();
}, [processLinks]);

if (loading) return (
<div className="prose max-w-3xl mx-auto h-[calc(100vh-32rem)] flex items-start justify-center bg-white pt-16">
Expand Down Expand Up @@ -703,8 +656,9 @@ const GraphContainer = ({ dataSource, algorithm }) => {
nodeColor={node => highlightNodes.has(node) ? '#ff0000' : node.color}
nodeCanvasObject={(node, ctx, globalScale) => {
// Draw node
const size = node.size * (4 / Math.max(1, globalScale));
ctx.beginPath();
ctx.arc(node.x, node.y, node.size * 2, 0, 2 * Math.PI);
ctx.arc(node.x, node.y, size, 0, 2 * Math.PI);
ctx.fillStyle = highlightNodes.has(node) ? '#ff0000' : node.color;
ctx.fill();

Expand Down
5 changes: 3 additions & 2 deletions apps/registry/app/similarity/page.js
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,16 @@ export default function SimilarityPage() {
nodeColor={node => highlightNodes.has(node) ? '#ff0000' : node.color}
nodeCanvasObject={(node, ctx, globalScale) => {
// Draw node
const size = node.size * (4 / Math.max(1, globalScale));
ctx.beginPath();
ctx.arc(node.x, node.y, node.size * 2, 0, 2 * Math.PI);
ctx.arc(node.x, node.y, size, 0, 2 * Math.PI);
ctx.fillStyle = highlightNodes.has(node) ? '#ff0000' : node.color;
ctx.fill();

// Only draw label if node is highlighted
if (highlightNodes.has(node)) {
const label = node.id;
const fontSize = Math.max(14, node.size * 1.5);
const fontSize = Math.max(14, size * 1.5);
ctx.font = `${fontSize}px Sans-Serif`;
const textWidth = ctx.measureText(label).width;
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
Expand Down

0 comments on commit 8c25d9b

Please sign in to comment.