Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase rendering performance #78

Merged
merged 7 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/show_large_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def generate_binary_tree(max_depth: int):
return nodes


nodes = generate_binary_tree(6)
nodes = generate_binary_tree(8)
print(f"{len(nodes)} total nodes")


Expand Down
5 changes: 3 additions & 2 deletions napari_arboretum/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TreeNode:
"""TreeNode."""

ID: int
t: Tuple[int, int]
t: np.ndarray
generation: int
children: List[int] = field(default_factory=list)

Expand Down Expand Up @@ -154,7 +154,8 @@ def build_subgraph(layer: napari.layers.Tracks, search_node: int) -> List[TreeNo
def _node_from_graph(_id):

idx = np.where(layer.data[:, 0] == _id)[0]
t = (np.min(layer.data[idx, 1]), np.max(layer.data[idx, 1]))
# t = (np.min(layer.data[idx, 1]), np.max(layer.data[idx, 1]))
t = layer.data[idx, 1]
node = TreeNode(ID=_id, t=t, generation=1)

if _id in reverse_graph:
Expand Down
2 changes: 1 addition & 1 deletion napari_arboretum/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def append_mouse_callback(self, track_layer: Tracks) -> None:
when the layer is clicked.
"""

@track_layer.mouse_drag_callbacks.append
@track_layer.mouse_double_click_callbacks.append
def show_tree(tracks: Tracks, event: Event) -> None:
self.tracks = tracks

Expand Down
25 changes: 11 additions & 14 deletions napari_arboretum/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Edge:
y: Tuple[float, float]
color: np.ndarray = WHITE
id: Optional[int] = None
node: Optional[TreeNode] = None


def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:
Expand Down Expand Up @@ -66,14 +67,15 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:
y = y_pos.pop(0)

# draw the root of the tree
edges.append(Edge(y=(y, y), x=(node.t[0], node.t[-1]), id=node.ID))
edges.append(Edge(y=(y, y), x=(node.t[0], node.t[-1]), id=node.ID, node=node))

if node.is_root:
annotations.append(Annotation(y=y, x=node.t[0], label=str(node.ID)))

# mark if this is an apoptotic tree
if node.is_leaf:
annotations.append(Annotation(y=y, x=node.t[-1], label=str(node.ID)))

if node.is_root:
annotations.append(Annotation(y=y, x=node.t[0], label=str(node.ID)))
continue

children = [t for t in nodes if t.ID in node.children]

Expand All @@ -94,6 +96,11 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:

# plot a linking line to the children
edges.append(Edge(y=(y, y_pos[-1]), x=(node.t[-1], child.t[0])))

# if it's a leaf don't plot the annotation
if child.is_leaf:
continue

annotations.append(
Annotation(
y=y_pos[-1],
Expand All @@ -102,14 +109,4 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:
)
)

# now that we have traversed the tree, calculate the span
tree_span = []
for edge in edges:
tree_span.append(edge.y[0])
tree_span.append(edge.y[1])

# # work out the span of the tree, we can modify positioning here
# min_x = min(tree_span)
# max_x = max(tree_span)

return edges, annotations
12 changes: 12 additions & 0 deletions napari_arboretum/visualisation/base_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from ..tree import Annotation, Edge, layout_tree
from ..util import TrackPropertyMixin

# from ..profiler import profiler

GUI_MAXIMUM_WIDTH = 600

__all__ = ["TreePlotterBase", "TreePlotterQWidgetBase"]
Expand Down Expand Up @@ -46,6 +48,7 @@ def draw_tree(self) -> None:
subgraph_nodes = build_subgraph(self.tracks, self.track_id)
self.draw_from_nodes(subgraph_nodes, self.track_id)

# @profiler("draw_from_nodes")
def draw_from_nodes(
self, tree_nodes: List[TreeNode], track_id: Optional[int] = None
):
Expand All @@ -61,6 +64,8 @@ def draw_from_nodes(
for a in self.annotations:
self.add_annotation(a)

self.draw_tree_visual()

def update_edge_colors(self, update_live: bool = True) -> None:
"""
Update tree edge colours from the track properties.
Expand Down Expand Up @@ -117,6 +122,13 @@ def draw_current_time_line(self, time: int) -> None:
"""
raise NotImplementedError

@abc.abstractmethod
def draw_tree_visual(self) -> None:
"""
Function to draw the visual after construction.
"""
raise NotImplementedError


class TreePlotterQWidgetBase(TreePlotterBase):
"""
Expand Down
130 changes: 107 additions & 23 deletions napari_arboretum/visualisation/vispy_plotter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional

import numpy as np
from qtpy.QtWidgets import QWidget
Expand All @@ -11,6 +10,10 @@
__all__ = ["VisPyPlotter"]


DEFAULT_TEXT_SIZE = 8
DEFAULT_BRANCH_WIDTH = 3


@dataclass
class Bounds:
xmin: float
Expand All @@ -19,6 +22,31 @@ class Bounds:
ymax: float


@dataclass
class TrackSubvisualProxy:
pos: np.ndarray
color: np.ndarray = np.array([1.0, 1.0, 1.0, 1.0])

@property
def connex(self):
connex = [True] * (self.pos.shape[0] - 1) + [False]
return connex

@property
def safe_color(self) -> np.ndarray:
if self.color.ndim != 2:
safe_color = np.repeat([self.color], self.pos.shape[0], axis=0)
return safe_color
return self.color


@dataclass
class AnnotationSubvisualProxy:
pos: np.ndarray
text: str
color: str = "white"


class VisPyPlotter(TreePlotterQWidgetBase):
"""
Tree plotter using pyqtgraph as the plotting backend.
Expand Down Expand Up @@ -71,6 +99,14 @@ def autoscale_view(self) -> None:
width * (1 + 2 * padding),
height * (1 + 2 * padding),
)

# change the aspect ratio of the camera if we have just a single branch
# this will centre the camera on the single branch, otherwise, set the
# aspect ratio to match the data
if width == 0:
self.view.camera.aspect = 1.0
else:
self.view.camera.aspect = None
self.view.camera.rect = rect

def update_colors(self) -> None:
Expand All @@ -85,7 +121,8 @@ def add_branch(self, e: Edge) -> None:
"""
Add a single branch to the tree.
"""
self.tree.add_track(e.id, np.column_stack((e.y, e.x)), e.color)
# self.tree.add_track(e.id, np.column_stack((e.y, e.x)), e.color)
self.tree.add_track(e)
self.autoscale_view()

def add_annotation(self, a: Annotation) -> None:
Expand All @@ -104,6 +141,12 @@ def draw_current_time_line(self, time: int) -> None:
pos=np.array([[bounds.xmin - padding, time], [bounds.xmax + padding, time]])
)

def draw_tree_visual(self) -> None:
"""
Draw the whole tree.
"""
self.tree.draw_tree()


class TreeVisual(scene.visuals.Compound):
"""
Expand All @@ -116,7 +159,22 @@ def __init__(self, parent):
self.unfreeze()
# Keep a reference to tracks we add so their colour can be changed later
self.tracks = {}
self.subvisuals = []
self.edges = []
self.annotations = []

subvisuals = [
scene.visuals.Line(color="white", width=DEFAULT_BRANCH_WIDTH),
scene.visuals.Text(
anchor_x="left",
anchor_y="top",
rotation=90,
font_size=DEFAULT_TEXT_SIZE,
color="white",
),
]

for visual in subvisuals:
self.add_subvisual(visual)

def get_branch_color(self, branch_id: int) -> np.ndarray:
return self.tracks[branch_id].color
Expand All @@ -125,9 +183,12 @@ def set_branch_color(self, branch_id: int, color: np.ndarray) -> None:
"""
Set the color of an individual branch.
"""
self.tracks[branch_id].set_data(color=color)
self.tracks[branch_id].color = color
self._subvisuals[0].set_data(
color=np.row_stack([e.safe_color for e in self.edges]),
)

def add_track(self, id: Optional[int], pos: np.ndarray, color: np.ndarray) -> None:
def add_track(self, e: Edge) -> None:
"""
Parameters
----------
Expand All @@ -139,35 +200,58 @@ def add_track(self, id: Optional[int], pos: np.ndarray, color: np.ndarray) -> No
Array of shape (n, 4) specifying RGBA values in range [0, 1] along
the track.
"""
if id is None:
visual = scene.visuals.Line(pos=pos, color=color, width=3)
color = e.color
pos = np.column_stack((e.y, e.x))

if e.node is None:
subvisual_proxy = TrackSubvisualProxy(
pos=pos,
color=np.array([1.0, 1.0, 1.0, 1.0]),
)
else:
# Split up line into individual time steps so color can vary
# along the line
ys = np.arange(pos[0, 1], pos[1, 1] + 1)
ys = np.asarray(e.node.t) # np.arange(pos[0, 1], pos[1, 1] + 1)
xs = np.ones(ys.size) * pos[0, 0]
visual = scene.visuals.Line(
pos=np.column_stack((xs, ys)), color=color, width=3
subvisual_proxy = TrackSubvisualProxy(
pos=np.column_stack((xs, ys)),
color=color,
)
self.tracks[id] = visual
# store a reference to this subvisual proxy
self.tracks[e.id] = subvisual_proxy

self.add_subvisual(visual)
self.subvisuals.append(visual)
self.edges.append(subvisual_proxy)

def add_annotation(self, x: float, y: float, label: str, color):
visual = scene.visuals.Text(

subvisual_proxy = AnnotationSubvisualProxy(
text=label,
color=color,
pos=[y, x, 0],
anchor_x="left",
anchor_y="top",
font_size=10,
)
self.add_subvisual(visual)
self.subvisuals.append(visual)

self.annotations.append(subvisual_proxy)

def clear(self) -> None:
"""Remove all tracks."""
while self.subvisuals:
subvisual = self.subvisuals.pop()
self.remove_subvisual(subvisual)
self.tracks = {}
self.edges = []
self.annotations = []

for visual in self._subvisuals:
visual._pos = None

if hasattr(visual, "_text"):
visual._text = None

def draw_tree(self) -> None:
"""Once the data is added, draw the tree."""

self._subvisuals[0].set_data(
pos=np.row_stack([e.pos for e in self.edges]),
color=np.row_stack([e.safe_color for e in self.edges]),
connect=np.concatenate([e.connex for e in self.edges]),
)

# TextVisual does not have a ``set_data`` method
self._subvisuals[1].pos = np.asarray([a.pos for a in self.annotations])
self._subvisuals[1].text = [a.text for a in self.annotations]