Skip to content

Commit

Permalink
added parameters for hypergraph function
Browse files Browse the repository at this point in the history
  • Loading branch information
tm4185s committed Dec 7, 2023
1 parent ac82d33 commit 443e402
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 18 deletions.
1 change: 0 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ numpy>=1.19.5
quantities>=0.14.1
matplotlib>=3.3.2
seaborn>=0.9.0
bokeh>=3.0.0
holoviews>=1.16.0
networkx>=3.0.0
42 changes: 35 additions & 7 deletions viziphant/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ def plot_patterns(spiketrains, patterns, circle_sizes=(3, 50, 70),
axes.yaxis.set_label_coords(-0.01, 0.5)
return axes

# TODO: add a parameter node_size
def plot_patterns_hypergraph(patterns, num_neurons=None):
def plot_patterns_hypergraph(patterns, node_size=3, pattern_size=None, num_neurons=None,\
highlight_patterns=None, mark_neuron=None, node_color='white'):
"""
Hypergraph visualization of spike patterns.
Expand Down Expand Up @@ -430,12 +430,19 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):
:func:`elephant.spade.spade` or
:func:`elephant.cell_assembly_detection.cell_assembly_detection`
pattern detectors.
node_size (optional): int
Change the size of the drawen nodes
pattern_size (optional): tuple or int
Only draw patterns that are in range of pattern_size
num_neurons: None or int
If None, only the neurons that are part of a pattern are shown. If an
integer is passed, it identifies the total number of recorded neurons
including non-pattern neurons to be additionally shown in the graph.
Default: None
highlight_patterns (optional) : int
Highlight pattern which includes neuron x
node_color (optional) : String
change the color of the nodes
Returns
-------
A handle to a matplotlib figure containing the hypergraph.
Expand Down Expand Up @@ -493,11 +500,33 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):

# Create one hypergraph per dataset
hyperedges = []

# Create "range" of pattern_size
if (isinstance(pattern_size, int)): pattern_size=(pattern_size, pattern_size)

# Create one hyperedge from every pattern
for pattern in patterns:
# A hyperedge is the set of neurons of a pattern
hyperedges.append(pattern['neurons'])

if pattern_size is None:
hyperedges.append(pattern['neurons'])
# check if hyperedge(pattern) is greater or equal to min_pattern_size
elif len(pattern['neurons']) >= pattern_size[0] and len(pattern['neurons']) <= pattern_size[1] or highlight_patterns in pattern['neurons']:
hyperedges.append(pattern['neurons'])

# check if neuron to highlight is in hyperedge
temp_hyperedges = []
if highlight_patterns is not None:
if isinstance(highlight_patterns, int):
for edge in hyperedges:
if highlight_patterns in edge:
temp_hyperedges.append(edge)
hyperedges = temp_hyperedges

# TODO: highlight_patterns as a list

if (len(hyperedges) == 0):
raise Exception('Could not find any hyperedges that match the given parameters')

# Currently, all hyperedges receive the same weights
weights = [weight] * len(hyperedges)

Expand All @@ -507,8 +536,7 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):
weights=weights,
repulse=repulsive)
hypergraphs.append(hg)

view = View(hypergraphs)
view = View(hypergraphs, node_size, mark_neuron, node_color)
fig = view.show(subset_style=VisualizationStyle.COLOR,
triangulation_style=VisualizationStyle.INVISIBLE)

Expand Down
37 changes: 27 additions & 10 deletions viziphant/patterns_src/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class View:
for the visualization of hypergraphs.
"""

def __init__(self, hypergraphs, title=None):
def __init__(self, hypergraphs, node_size=3, mark_neuron=None, node_color='white', title=None):
"""
Constructs a View object that handles the visualization
of the given hypergraphs.
Expand All @@ -45,6 +45,15 @@ def __init__(self, hypergraphs, title=None):
hypergraphs: list of Hypergraph objects
Hypergraphs to be visualized.
Each hypergraph should contain data of one data set.
node_size (optional) : int
Size of the nodes in the Hypergraphs
mark_neuron (optional) : int
Neuron with given number will be highlighted
node_color (optional) : String
change the color of the nodes
"""

# Hyperedge drawings
Expand All @@ -53,12 +62,21 @@ def __init__(self, hypergraphs, title=None):
# Which color of the color map to use next
self.current_color = 1

# Size of the vertices TODO: add as parameter
# radius of the hyperedges
self.node_radius = .2

# Size of the nodes (vertices of hypergraph)
self.node_size = node_size

# Color of the nodes
self.node_color = node_color

# Selected title of the figure
self.title = title

# Marked node will be in a different color
self.mark_neuron = mark_neuron

# If no data was provided, fill in dummy data
if hypergraphs:
self.hypergraphs = hypergraphs
Expand Down Expand Up @@ -107,7 +125,7 @@ def _setup_graph_visualization(self):
# The hv.Graph visualization is used for displaying the data
# hv.Graph displays the nodes (and optionally binary edges) of a graph
dynamic_map = hv.DynamicMap(hv.Graph, streams=[pipe])

# Define options for visualization
dynamic_map.opts(
# Some space around the Graph in order to avoid nodes being on the
Expand All @@ -124,8 +142,7 @@ def _setup_graph_visualization(self):
# All in black
cmap=['#ffffff', '#ffffff'] * 50,
# Size of the nodes
node_size=self.node_radius))

node_size=self.node_size, node_color=self.node_color, show_legend=True))
return dynamic_map, pipe

def _setup_hyperedge_drawing(self):
Expand Down Expand Up @@ -226,10 +243,10 @@ def show(self,
# Set size of the plot to a square to avoid distortions
self.plot = plot.redim.range(x=(-1, 11), y=(-1, 11))
# TODO: how to get axes? currently figure
axes = hv.render(plot, backend="matplotlib")
return axes

def draw_hyperedges(self,
fig = hv.render(plot, backend="matplotlib")
return fig
def draw_hyperedges(self, highlight_neuron=None,
subset_style=VisualizationStyle.COLOR,
triangulation_style=VisualizationStyle.INVISIBLE):
"""
Expand Down Expand Up @@ -396,7 +413,7 @@ def _update_nodes(self, data):
nodes = hv.Nodes((pos_x, pos_y, vertex_ids, vertex_labels),
extents=(0.01, 0.01, 0.01, 0.01),
vdims='Label')

new_data = ((edge_source, edge_target), nodes)
self.pipe.send(new_data)

Expand Down

0 comments on commit 443e402

Please sign in to comment.