Skip to content

Commit

Permalink
barycenter algo for same grouplvl nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
FloSch62 committed Dec 21, 2024
1 parent ed8fe2e commit 91e8866
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 229 deletions.
296 changes: 179 additions & 117 deletions core/layout/horizontal_layout.py
Original file line number Diff line number Diff line change
@@ -1,132 +1,194 @@
import logging
from core.layout.layout_manager import LayoutManager
from collections import defaultdict

from core.layout.layout_manager import LayoutManager

logger = logging.getLogger(__name__)

class HorizontalLayout(LayoutManager):
"""
Applies a horizontal layout strategy to arrange nodes.
Applies an iterative "barycenter" layout strategy to arrange nodes horizontally.
Each node's y-position is determined purely from how it connects to other nodes,
including same-level and cross-level edges.
"""

def apply(self, diagram, verbose=False) -> None:
logger.debug("Applying horizontal layout...")
"""
Main entry point: Called by clab2drawio code to apply a horizontal layout.
:param diagram: CustomDrawioDiagram instance with .nodes
:param verbose: Whether to log extra info
"""
logger.debug("Applying iterative barycenter layout (horizontal)...")
self.diagram = diagram
self.verbose = verbose
self._calculate_positions()

def _calculate_positions(self):
nodes = self.diagram.nodes
nodes = sorted(nodes.values(), key=lambda node: (node.graph_level, node.name))

# Get padding from styles
padding_x = self.diagram.styles['padding_x']
padding_y = self.diagram.styles['padding_y']

x_start, y_start = 100, 100

logger.debug("Nodes before calculate_positions:" + str(nodes))

def prioritize_placement(nodes, level):
diagram = self.diagram
if level == diagram.get_max_level():
ordered_nodes = sorted(nodes, key=lambda node: node.name)
else:
multi_connection_nodes = [node for node in nodes if node.get_connection_count_within_level() > 1]
single_connection_nodes = [node for node in nodes if node.get_connection_count_within_level() == 1]
zero_connection_nodes = [node for node in nodes if node.get_connection_count_within_level() == 0]

multi_connection_nodes_with_lateral = []
multi_connection_nodes_without_lateral = []
for node in multi_connection_nodes:
if any(
link.target in multi_connection_nodes
for link in node.get_lateral_links()
):
multi_connection_nodes_with_lateral.append(node)
else:
multi_connection_nodes_without_lateral.append(node)

sorted_multi_connection_nodes_with_lateral = []
while multi_connection_nodes_with_lateral:
node = multi_connection_nodes_with_lateral.pop(0)
sorted_multi_connection_nodes_with_lateral.append(node)
for link in node.get_lateral_links():
if link.target in multi_connection_nodes_with_lateral:
multi_connection_nodes_with_lateral.remove(link.target)
sorted_multi_connection_nodes_with_lateral.append(link.target)

multi_connection_nodes_without_lateral = sorted(
multi_connection_nodes_without_lateral, key=lambda node: node.name
)
sorted_multi_connection_nodes_with_lateral = sorted(
sorted_multi_connection_nodes_with_lateral, key=lambda node: node.name
)
single_connection_nodes = sorted(single_connection_nodes, key=lambda node: node.name)

ordered_nodes = (
single_connection_nodes[: len(single_connection_nodes) // 2]
+ multi_connection_nodes_without_lateral
+ sorted_multi_connection_nodes_with_lateral
+ single_connection_nodes[len(single_connection_nodes) // 2 :]
+ zero_connection_nodes
)

return ordered_nodes

nodes_by_graphlevel = defaultdict(list)
for node in nodes:
nodes_by_graphlevel[node.graph_level].append(node)

for graphlevel, graphlevel_nodes in nodes_by_graphlevel.items():
ordered_nodes = prioritize_placement(graphlevel_nodes, graphlevel)
for i, node in enumerate(ordered_nodes):
# horizontal layout
node.pos_x = x_start + graphlevel * padding_x
node.pos_y = y_start + i * padding_y

self._center_align_nodes(nodes_by_graphlevel, layout="horizontal", verbose=self.verbose)
intermediaries_x, intermediaries_y = self.diagram.get_nodes_between_interconnected()
self._adjust_intermediary_nodes(intermediaries_y, layout="horizontal", verbose=self.verbose)

def _adjust_intermediary_nodes(self, intermediaries, layout, verbose=False):
if not intermediaries:
return
intermediaries_by_level = defaultdict(list)
for node in intermediaries:
intermediaries_by_level[node.graph_level].append(node)

selected_level = max(
intermediaries_by_level.keys(),
key=lambda lvl: len(intermediaries_by_level[lvl]),
)
selected_group = intermediaries_by_level[selected_level]

if len(selected_group) == 1:
node = selected_group[0]
if layout == "vertical":
node.pos_x = node.pos_x - 100
else:
node.pos_y = node.pos_y - 100
else:
for i, node in enumerate(selected_group):
if layout == "vertical":
node.pos_x = node.pos_x - 100 + i * 200
# Put nodes into bins by their graph_level
nodes_by_level = defaultdict(list)
for n in self.diagram.nodes.values():
nodes_by_level[n.graph_level].append(n)

# Initialize: give each level a naive top->bottom order
sorted_levels = sorted(nodes_by_level.keys())
for level in sorted_levels:
# sort by name
nodes_by_level[level].sort(key=lambda nd: nd.name)
# assign an initial y from top->bottom
for i, nd in enumerate(nodes_by_level[level]):
nd.pos_y = float(100 + i * self.diagram.styles["padding_y"])

# We'll do N passes. Each pass: left->right, then right->left
# In each pass, we compute barycenter for each node (based on neighbors)
# and reorder the level by that barycenter.

def compute_barycenters_at_level(level_nodes):
"""
For each node in level_nodes, compute the average y of all its neighbors,
ignoring their levels (so it includes same-level edges and multi-level edges).
Store the result in node._bary (temp).
"""
for nd in level_nodes:
# gather y positions of neighbors
neighbor_y_positions = []
for nbr in nd.get_neighbors():
# skip if neighbor has no numeric pos_y
try:
ny = float(nbr.pos_y)
except (TypeError, ValueError):
ny = 0.0
neighbor_y_positions.append(ny)

if neighbor_y_positions:
nd._bary = sum(neighbor_y_positions) / len(neighbor_y_positions)
else:
node.pos_y = node.pos_y - 100 + i * 200

def _center_align_nodes(self, nodes_by_graphlevel, layout="horizontal", verbose=False):
attr_x, attr_y = ("pos_x", "pos_y") if layout == "vertical" else ("pos_y", "pos_x")

prev_graphlevel_center = None
for graphlevel, nodes in sorted(nodes_by_graphlevel.items()):
graphlevel_centers = [getattr(node, attr_x) for node in nodes]

if prev_graphlevel_center is None:
prev_graphlevel_center = (min(graphlevel_centers) + max(graphlevel_centers)) / 2
nd._bary = 0.0

def reorder_by_barycenter(level_nodes):
"""
Sort the list of nodes by the barycenter we just computed.
Keep it stable so that ties won't cause random reordering.
"""
level_nodes.sort(key=lambda nd: nd._bary)

num_passes = 4 # or more, typically 4~6 is enough
for _iter in range(num_passes):
# left->right sweep
for level in sorted_levels:
level_nodes = nodes_by_level[level]
compute_barycenters_at_level(level_nodes)
reorder_by_barycenter(level_nodes)
# reassign y after sorting
for i, nd in enumerate(level_nodes):
nd.pos_y = float(100 + i * self.diagram.styles["padding_y"])

# right->left sweep
for level in reversed(sorted_levels):
level_nodes = nodes_by_level[level]
compute_barycenters_at_level(level_nodes)
reorder_by_barycenter(level_nodes)
# reassign y
for i, nd in enumerate(level_nodes):
nd.pos_y = float(100 + i * self.diagram.styles["padding_y"])

# Assign final pos_x from graph_level, keep the pos_y from the last iteration
for level in sorted_levels:
for node in nodes_by_level[level]:
node.pos_x = float(100 + level * self.diagram.styles["padding_x"])

self._center_align_nodes(nodes_by_level)
self._adjust_intermediary_nodes(diagram)

logger.debug("Iterative barycenter layout complete (horizontal).")

def _center_align_nodes(self, nodes_by_level):
"""
Shift each level vertically so they are centered
around a consistent "global_center" or around the previous column.
"""
sorted_levels = sorted(nodes_by_level.keys())
# pick a global center, or compute from the leftmost column
global_center = 300.0

prev_center = None
for level in sorted_levels:
level_nodes = nodes_by_level[level]
if not level_nodes:
continue
# find min & max y
min_y = min(nd.pos_y for nd in level_nodes)
max_y = max(nd.pos_y for nd in level_nodes)
col_center = (min_y + max_y) / 2.0

if prev_center is None:
# for the leftmost column, align it to global_center
offset = global_center - col_center
for nd in level_nodes:
nd.pos_y += offset
# update col_center after shift
min_y = min(nd.pos_y for nd in level_nodes)
max_y = max(nd.pos_y for nd in level_nodes)
col_center = (min_y + max_y) / 2.0
prev_center = col_center
else:
graphlevel_center = sum(graphlevel_centers) / len(nodes)
offset = prev_graphlevel_center - graphlevel_center
for node in nodes:
setattr(node, attr_x, getattr(node, attr_x) + offset)
prev_graphlevel_center = sum(getattr(node, attr_x) for node in nodes) / len(nodes)
# for subsequent columns, line them up with the previous column's center
offset = prev_center - col_center
for nd in level_nodes:
nd.pos_y += offset
# update col_center
min_y = min(nd.pos_y for nd in level_nodes)
max_y = max(nd.pos_y for nd in level_nodes)
col_center = (min_y + max_y) / 2.0
prev_center = col_center

def _adjust_intermediary_nodes(self, diagram, offset=100.0):
"""
After the main layout, push nodes out of the way if they lie directly on
a horizontal or vertical line that connects other nodes.
:param diagram: The CustomDrawioDiagram with .nodes and .get_links_from_nodes()
:param offset: How many pixels to shift a node if it is detected "on" a link.
"""
all_links = diagram.get_links_from_nodes()
nodes = list(diagram.nodes.values())

# minimal bounding-box approach
for nd in nodes:
nd.half_w = float(nd.width) / 2.0 if nd.width else 20.0
nd.half_h = float(nd.height) / 2.0 if nd.height else 20.0

for link in all_links:
A = link.source
B = link.target

# If the link is horizontal (A.y ~ B.y):
if abs(A.pos_y - B.pos_y) < 1e-5:
left_x = min(A.pos_x, B.pos_x)
right_x = max(A.pos_x, B.pos_x)
for N in nodes:
if N not in (A, B):
# bounding box check
Ny_top = N.pos_y - N.half_h
Ny_bot = N.pos_y + N.half_h
# "in line" if A.y in that range
if Ny_top <= A.pos_y <= Ny_bot:
Nx_left = N.pos_x - N.half_w
Nx_right = N.pos_x + N.half_w
# Overlaps horizontally?
if (Nx_left < right_x and Nx_right > left_x):
# SHIFT N vertically out of the way
N.pos_y -= offset

# If the link is vertical (A.x ~ B.x):
elif abs(A.pos_x - B.pos_x) < 1e-5:
top_y = min(A.pos_y, B.pos_y)
bot_y = max(A.pos_y, B.pos_y)
for N in nodes:
if N not in (A, B):
Nx_left = N.pos_x - N.half_w
Nx_right = N.pos_x + N.half_w
# "in line" if A.x in that range
if Nx_left <= A.pos_x <= Nx_right:
Ny_top = N.pos_y - N.half_h
Ny_bot = N.pos_y + N.half_h
if (Ny_top < bot_y and Ny_bot > top_y):
# SHIFT N horizontally out of the way
N.pos_x -= offset
Loading

0 comments on commit 91e8866

Please sign in to comment.