-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
barycenter algo for same grouplvl nodes
- Loading branch information
Showing
2 changed files
with
363 additions
and
229 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,132 +1,194 @@ | ||
import logging | ||
from core.layout.layout_manager import LayoutManager | ||
from collections import defaultdict | ||
|
||
from core.layout.layout_manager import LayoutManager | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class HorizontalLayout(LayoutManager): | ||
""" | ||
Applies a horizontal layout strategy to arrange nodes. | ||
Applies an iterative "barycenter" layout strategy to arrange nodes horizontally. | ||
Each node's y-position is determined purely from how it connects to other nodes, | ||
including same-level and cross-level edges. | ||
""" | ||
|
||
def apply(self, diagram, verbose=False) -> None: | ||
logger.debug("Applying horizontal layout...") | ||
""" | ||
Main entry point: Called by clab2drawio code to apply a horizontal layout. | ||
:param diagram: CustomDrawioDiagram instance with .nodes | ||
:param verbose: Whether to log extra info | ||
""" | ||
logger.debug("Applying iterative barycenter layout (horizontal)...") | ||
self.diagram = diagram | ||
self.verbose = verbose | ||
self._calculate_positions() | ||
|
||
def _calculate_positions(self): | ||
nodes = self.diagram.nodes | ||
nodes = sorted(nodes.values(), key=lambda node: (node.graph_level, node.name)) | ||
|
||
# Get padding from styles | ||
padding_x = self.diagram.styles['padding_x'] | ||
padding_y = self.diagram.styles['padding_y'] | ||
|
||
x_start, y_start = 100, 100 | ||
|
||
logger.debug("Nodes before calculate_positions:" + str(nodes)) | ||
|
||
def prioritize_placement(nodes, level): | ||
diagram = self.diagram | ||
if level == diagram.get_max_level(): | ||
ordered_nodes = sorted(nodes, key=lambda node: node.name) | ||
else: | ||
multi_connection_nodes = [node for node in nodes if node.get_connection_count_within_level() > 1] | ||
single_connection_nodes = [node for node in nodes if node.get_connection_count_within_level() == 1] | ||
zero_connection_nodes = [node for node in nodes if node.get_connection_count_within_level() == 0] | ||
|
||
multi_connection_nodes_with_lateral = [] | ||
multi_connection_nodes_without_lateral = [] | ||
for node in multi_connection_nodes: | ||
if any( | ||
link.target in multi_connection_nodes | ||
for link in node.get_lateral_links() | ||
): | ||
multi_connection_nodes_with_lateral.append(node) | ||
else: | ||
multi_connection_nodes_without_lateral.append(node) | ||
|
||
sorted_multi_connection_nodes_with_lateral = [] | ||
while multi_connection_nodes_with_lateral: | ||
node = multi_connection_nodes_with_lateral.pop(0) | ||
sorted_multi_connection_nodes_with_lateral.append(node) | ||
for link in node.get_lateral_links(): | ||
if link.target in multi_connection_nodes_with_lateral: | ||
multi_connection_nodes_with_lateral.remove(link.target) | ||
sorted_multi_connection_nodes_with_lateral.append(link.target) | ||
|
||
multi_connection_nodes_without_lateral = sorted( | ||
multi_connection_nodes_without_lateral, key=lambda node: node.name | ||
) | ||
sorted_multi_connection_nodes_with_lateral = sorted( | ||
sorted_multi_connection_nodes_with_lateral, key=lambda node: node.name | ||
) | ||
single_connection_nodes = sorted(single_connection_nodes, key=lambda node: node.name) | ||
|
||
ordered_nodes = ( | ||
single_connection_nodes[: len(single_connection_nodes) // 2] | ||
+ multi_connection_nodes_without_lateral | ||
+ sorted_multi_connection_nodes_with_lateral | ||
+ single_connection_nodes[len(single_connection_nodes) // 2 :] | ||
+ zero_connection_nodes | ||
) | ||
|
||
return ordered_nodes | ||
|
||
nodes_by_graphlevel = defaultdict(list) | ||
for node in nodes: | ||
nodes_by_graphlevel[node.graph_level].append(node) | ||
|
||
for graphlevel, graphlevel_nodes in nodes_by_graphlevel.items(): | ||
ordered_nodes = prioritize_placement(graphlevel_nodes, graphlevel) | ||
for i, node in enumerate(ordered_nodes): | ||
# horizontal layout | ||
node.pos_x = x_start + graphlevel * padding_x | ||
node.pos_y = y_start + i * padding_y | ||
|
||
self._center_align_nodes(nodes_by_graphlevel, layout="horizontal", verbose=self.verbose) | ||
intermediaries_x, intermediaries_y = self.diagram.get_nodes_between_interconnected() | ||
self._adjust_intermediary_nodes(intermediaries_y, layout="horizontal", verbose=self.verbose) | ||
|
||
def _adjust_intermediary_nodes(self, intermediaries, layout, verbose=False): | ||
if not intermediaries: | ||
return | ||
intermediaries_by_level = defaultdict(list) | ||
for node in intermediaries: | ||
intermediaries_by_level[node.graph_level].append(node) | ||
|
||
selected_level = max( | ||
intermediaries_by_level.keys(), | ||
key=lambda lvl: len(intermediaries_by_level[lvl]), | ||
) | ||
selected_group = intermediaries_by_level[selected_level] | ||
|
||
if len(selected_group) == 1: | ||
node = selected_group[0] | ||
if layout == "vertical": | ||
node.pos_x = node.pos_x - 100 | ||
else: | ||
node.pos_y = node.pos_y - 100 | ||
else: | ||
for i, node in enumerate(selected_group): | ||
if layout == "vertical": | ||
node.pos_x = node.pos_x - 100 + i * 200 | ||
# Put nodes into bins by their graph_level | ||
nodes_by_level = defaultdict(list) | ||
for n in self.diagram.nodes.values(): | ||
nodes_by_level[n.graph_level].append(n) | ||
|
||
# Initialize: give each level a naive top->bottom order | ||
sorted_levels = sorted(nodes_by_level.keys()) | ||
for level in sorted_levels: | ||
# sort by name | ||
nodes_by_level[level].sort(key=lambda nd: nd.name) | ||
# assign an initial y from top->bottom | ||
for i, nd in enumerate(nodes_by_level[level]): | ||
nd.pos_y = float(100 + i * self.diagram.styles["padding_y"]) | ||
|
||
# We'll do N passes. Each pass: left->right, then right->left | ||
# In each pass, we compute barycenter for each node (based on neighbors) | ||
# and reorder the level by that barycenter. | ||
|
||
def compute_barycenters_at_level(level_nodes): | ||
""" | ||
For each node in level_nodes, compute the average y of all its neighbors, | ||
ignoring their levels (so it includes same-level edges and multi-level edges). | ||
Store the result in node._bary (temp). | ||
""" | ||
for nd in level_nodes: | ||
# gather y positions of neighbors | ||
neighbor_y_positions = [] | ||
for nbr in nd.get_neighbors(): | ||
# skip if neighbor has no numeric pos_y | ||
try: | ||
ny = float(nbr.pos_y) | ||
except (TypeError, ValueError): | ||
ny = 0.0 | ||
neighbor_y_positions.append(ny) | ||
|
||
if neighbor_y_positions: | ||
nd._bary = sum(neighbor_y_positions) / len(neighbor_y_positions) | ||
else: | ||
node.pos_y = node.pos_y - 100 + i * 200 | ||
|
||
def _center_align_nodes(self, nodes_by_graphlevel, layout="horizontal", verbose=False): | ||
attr_x, attr_y = ("pos_x", "pos_y") if layout == "vertical" else ("pos_y", "pos_x") | ||
|
||
prev_graphlevel_center = None | ||
for graphlevel, nodes in sorted(nodes_by_graphlevel.items()): | ||
graphlevel_centers = [getattr(node, attr_x) for node in nodes] | ||
|
||
if prev_graphlevel_center is None: | ||
prev_graphlevel_center = (min(graphlevel_centers) + max(graphlevel_centers)) / 2 | ||
nd._bary = 0.0 | ||
|
||
def reorder_by_barycenter(level_nodes): | ||
""" | ||
Sort the list of nodes by the barycenter we just computed. | ||
Keep it stable so that ties won't cause random reordering. | ||
""" | ||
level_nodes.sort(key=lambda nd: nd._bary) | ||
|
||
num_passes = 4 # or more, typically 4~6 is enough | ||
for _iter in range(num_passes): | ||
# left->right sweep | ||
for level in sorted_levels: | ||
level_nodes = nodes_by_level[level] | ||
compute_barycenters_at_level(level_nodes) | ||
reorder_by_barycenter(level_nodes) | ||
# reassign y after sorting | ||
for i, nd in enumerate(level_nodes): | ||
nd.pos_y = float(100 + i * self.diagram.styles["padding_y"]) | ||
|
||
# right->left sweep | ||
for level in reversed(sorted_levels): | ||
level_nodes = nodes_by_level[level] | ||
compute_barycenters_at_level(level_nodes) | ||
reorder_by_barycenter(level_nodes) | ||
# reassign y | ||
for i, nd in enumerate(level_nodes): | ||
nd.pos_y = float(100 + i * self.diagram.styles["padding_y"]) | ||
|
||
# Assign final pos_x from graph_level, keep the pos_y from the last iteration | ||
for level in sorted_levels: | ||
for node in nodes_by_level[level]: | ||
node.pos_x = float(100 + level * self.diagram.styles["padding_x"]) | ||
|
||
self._center_align_nodes(nodes_by_level) | ||
self._adjust_intermediary_nodes(diagram) | ||
|
||
logger.debug("Iterative barycenter layout complete (horizontal).") | ||
|
||
def _center_align_nodes(self, nodes_by_level): | ||
""" | ||
Shift each level vertically so they are centered | ||
around a consistent "global_center" or around the previous column. | ||
""" | ||
sorted_levels = sorted(nodes_by_level.keys()) | ||
# pick a global center, or compute from the leftmost column | ||
global_center = 300.0 | ||
|
||
prev_center = None | ||
for level in sorted_levels: | ||
level_nodes = nodes_by_level[level] | ||
if not level_nodes: | ||
continue | ||
# find min & max y | ||
min_y = min(nd.pos_y for nd in level_nodes) | ||
max_y = max(nd.pos_y for nd in level_nodes) | ||
col_center = (min_y + max_y) / 2.0 | ||
|
||
if prev_center is None: | ||
# for the leftmost column, align it to global_center | ||
offset = global_center - col_center | ||
for nd in level_nodes: | ||
nd.pos_y += offset | ||
# update col_center after shift | ||
min_y = min(nd.pos_y for nd in level_nodes) | ||
max_y = max(nd.pos_y for nd in level_nodes) | ||
col_center = (min_y + max_y) / 2.0 | ||
prev_center = col_center | ||
else: | ||
graphlevel_center = sum(graphlevel_centers) / len(nodes) | ||
offset = prev_graphlevel_center - graphlevel_center | ||
for node in nodes: | ||
setattr(node, attr_x, getattr(node, attr_x) + offset) | ||
prev_graphlevel_center = sum(getattr(node, attr_x) for node in nodes) / len(nodes) | ||
# for subsequent columns, line them up with the previous column's center | ||
offset = prev_center - col_center | ||
for nd in level_nodes: | ||
nd.pos_y += offset | ||
# update col_center | ||
min_y = min(nd.pos_y for nd in level_nodes) | ||
max_y = max(nd.pos_y for nd in level_nodes) | ||
col_center = (min_y + max_y) / 2.0 | ||
prev_center = col_center | ||
|
||
def _adjust_intermediary_nodes(self, diagram, offset=100.0): | ||
""" | ||
After the main layout, push nodes out of the way if they lie directly on | ||
a horizontal or vertical line that connects other nodes. | ||
:param diagram: The CustomDrawioDiagram with .nodes and .get_links_from_nodes() | ||
:param offset: How many pixels to shift a node if it is detected "on" a link. | ||
""" | ||
all_links = diagram.get_links_from_nodes() | ||
nodes = list(diagram.nodes.values()) | ||
|
||
# minimal bounding-box approach | ||
for nd in nodes: | ||
nd.half_w = float(nd.width) / 2.0 if nd.width else 20.0 | ||
nd.half_h = float(nd.height) / 2.0 if nd.height else 20.0 | ||
|
||
for link in all_links: | ||
A = link.source | ||
B = link.target | ||
|
||
# If the link is horizontal (A.y ~ B.y): | ||
if abs(A.pos_y - B.pos_y) < 1e-5: | ||
left_x = min(A.pos_x, B.pos_x) | ||
right_x = max(A.pos_x, B.pos_x) | ||
for N in nodes: | ||
if N not in (A, B): | ||
# bounding box check | ||
Ny_top = N.pos_y - N.half_h | ||
Ny_bot = N.pos_y + N.half_h | ||
# "in line" if A.y in that range | ||
if Ny_top <= A.pos_y <= Ny_bot: | ||
Nx_left = N.pos_x - N.half_w | ||
Nx_right = N.pos_x + N.half_w | ||
# Overlaps horizontally? | ||
if (Nx_left < right_x and Nx_right > left_x): | ||
# SHIFT N vertically out of the way | ||
N.pos_y -= offset | ||
|
||
# If the link is vertical (A.x ~ B.x): | ||
elif abs(A.pos_x - B.pos_x) < 1e-5: | ||
top_y = min(A.pos_y, B.pos_y) | ||
bot_y = max(A.pos_y, B.pos_y) | ||
for N in nodes: | ||
if N not in (A, B): | ||
Nx_left = N.pos_x - N.half_w | ||
Nx_right = N.pos_x + N.half_w | ||
# "in line" if A.x in that range | ||
if Nx_left <= A.pos_x <= Nx_right: | ||
Ny_top = N.pos_y - N.half_h | ||
Ny_bot = N.pos_y + N.half_h | ||
if (Ny_top < bot_y and Ny_bot > top_y): | ||
# SHIFT N horizontally out of the way | ||
N.pos_x -= offset |
Oops, something went wrong.