Skip to content

Commit

Permalink
Merge pull request #74 from BenTenmann/weighted-feather
Browse files Browse the repository at this point in the history
feat: added weighted graph capability to feather algo
  • Loading branch information
benedekrozemberczki authored Aug 4, 2021
2 parents 84e0694 + 8d60bca commit 3b0768a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 39 deletions.
21 changes: 11 additions & 10 deletions karateclub/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,62 @@

"""General Estimator base class."""


class Estimator(object):
"""Estimator base class with constructor and public methods."""
seed: int

def __init__(self):
"""Creatinng an estimator."""
"""Creating an estimator."""
pass


def fit(self):
"""Fitting a model."""
pass


def get_embedding(self):
"""Getting the embeddings (graph or node level)."""
pass


def get_memberships(self):
"""Getting the membership dictionary."""
pass


def get_cluster_centers(self):
"""Getting the cluster centers."""
pass


def _set_seed(self):
"""Creating the initial random seed."""
random.seed(self.seed)
np.random.seed(self.seed)

def _ensure_integrity(self, graph: nx.classes.graph.Graph) -> nx.classes.graph.Graph:
@staticmethod
def _ensure_integrity(graph: nx.classes.graph.Graph) -> nx.classes.graph.Graph:
"""Ensure walk traversal conditions."""
edge_list = [(index, index) for index in range(graph.number_of_nodes())]
graph.add_edges_from(edge_list)

return graph

def _check_indexing(self, graph: nx.classes.graph.Graph):
@staticmethod
def _check_indexing(graph: nx.classes.graph.Graph):
"""Checking the consecutive numeric indexing."""
numeric_indices = [index for index in range(graph.number_of_nodes())]
node_indices = sorted([node for node in graph.nodes()])
assert numeric_indices == node_indices, "The node indexing is wrong."

assert numeric_indices == node_indices, "The node indexing is wrong."

def _check_graph(self, graph: nx.classes.graph.Graph) -> nx.classes.graph.Graph:
"""Check the Karate Club assumptions about the graph."""
self._check_indexing(graph)
graph = self._ensure_integrity(graph)
return graph

return graph

def _check_graphs(self, graphs: List[nx.classes.graph.Graph]):
"""Check the Karate Club assumptions for a list of graphs."""
graphs = [self._check_graph(graph) for graph in graphs]

return graphs

99 changes: 70 additions & 29 deletions karateclub/graph_embedding/feathergraph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
import math
from typing import List
from functools import partial
from typing import List, Callable

import numpy as np
import networkx as nx
import scipy.sparse as sparse

from karateclub.estimator import Estimator


def _weighted_directed_degree(node: int, graph: nx.classes.graph.Graph) -> float:
out = graph.degree(node, weight='weight')

return out


def _unweighted_undirected_degree(node: int, graph: nx.classes.graph.Graph) -> float:
out = graph.degree[node]

return float(out)


def _get_degree_fn(graph) -> Callable:
"""Gets the function to calculate the graph node degree"""
fn = _weighted_directed_degree if nx.classes.function.is_weighted(graph) \
else _unweighted_undirected_degree
fn = partial(fn, graph=graph)

return fn


class FeatherGraph(Estimator):
r"""An implementation of `"FEATHER-G" <https://arxiv.org/abs/2005.07959>`_
from the CIKM '20 paper "Characteristic Functions on Graphs: Birds of a Feather,
Expand All @@ -21,16 +46,28 @@ class FeatherGraph(Estimator):
pooling (str): Permutation invariant pooling function, one of:
(:obj:`"mean"`, :obj:`"max"`, :obj:`"min"`). Default is "mean."
"""
def __init__(self, order: int=5, eval_points: int=25,
theta_max: float=2.5, seed: int=42, pooling: str="mean"):
n_nodes: int
degree_fn: Callable
_embedding: List[np.ndarray]

def __init__(self, order: int = 5, eval_points: int = 25,
theta_max: float = 2.5, seed: int = 42, pooling: str = "mean"):
super(FeatherGraph, self).__init__()

self.order = order
self.eval_points = eval_points
self.theta_max = theta_max
self.seed = seed
self.pooling = pooling

try:
pool_fn = getattr(np, pooling)
except AttributeError:
raise ValueError(f'{pooling.__repr__()} is not a valid pooling function')

self.pooling = pooling
self.pool_fn = partial(pool_fn, axis=0)

def _create_D_inverse(self, graph):
def _create_d_inverse(self) -> sparse.coo_matrix:
"""
Creating a sparse inverse degree matrix.
Expand All @@ -40,14 +77,14 @@ def _create_D_inverse(self, graph):
Return types:
* **D_inverse** *(Scipy array)* - Diagonal inverse degree matrix.
"""
index = np.arange(graph.number_of_nodes())
values = np.array([1.0/graph.degree[node] for node in range(graph.number_of_nodes())])
shape = (graph.number_of_nodes(), graph.number_of_nodes())
index = np.arange(self.n_nodes)
values = np.array([1.0 / self.degree_fn(node) for node in range(self.n_nodes)]) # <- ?

shape = (self.n_nodes, self.n_nodes)
D_inverse = sparse.coo_matrix((values, (index, index)), shape=shape)
return D_inverse


def _get_normalized_adjacency(self, graph):
def _get_normalized_adjacency(self, graph: nx.classes.graph.Graph) -> sparse.coo_matrix:
"""
Calculating the normalized adjacency matrix.
Expand All @@ -57,13 +94,13 @@ def _get_normalized_adjacency(self, graph):
Return types:
* **A_hat** *(SciPy array)* - The scattering matrix of the graph.
"""
A = nx.adjacency_matrix(graph, nodelist=range(graph.number_of_nodes()))
D_inverse = self._create_D_inverse(graph)
A = nx.adjacency_matrix(graph, nodelist=range(self.n_nodes))
D_inverse = self._create_d_inverse()

A_hat = D_inverse.dot(A)
return A_hat


def _create_node_feature_matrix(self, graph):
def _create_node_feature_matrix(self, graph: nx.classes.graph.Graph) -> np.ndarray:
"""
Calculating the node features.
Expand All @@ -73,13 +110,18 @@ def _create_node_feature_matrix(self, graph):
Return types:
* **X** *(NumPy array)* - The node features.
"""
log_degree = np.array([math.log(graph.degree(node)+1) for node in range(graph.number_of_nodes())]).reshape(-1, 1)
clustering_coefficient = np.array([nx.clustering(graph, node) for node in range(graph.number_of_nodes())]).reshape(-1, 1)
log_degree = np.array([math.log(self.degree_fn(node) + 1)
for node in range(self.n_nodes)])
log_degree = log_degree.reshape(-1, 1)

clustering_coefficient = np.array([nx.clustering(graph, node)
for node in range(self.n_nodes)])
clustering_coefficient = clustering_coefficient.reshape(-1, 1)

X = np.concatenate([log_degree, clustering_coefficient], axis=1)
return X


def _calculate_feather(self, graph):
def _calculate_feather(self, graph: nx.classes.graph.Graph) -> np.ndarray:
"""
Calculating the characteristic function features of a graph.
Expand All @@ -89,29 +131,28 @@ def _calculate_feather(self, graph):
Return types:
* **features** *(Numpy vector)* - The embedding of a single graph.
"""
self.n_nodes = graph.number_of_nodes()
self.degree_fn = _get_degree_fn(graph)

A_tilde = self._get_normalized_adjacency(graph)
X = self._create_node_feature_matrix(graph)
theta = np.linspace(0.01, self.theta_max, self.eval_points)

X = np.outer(X, theta)
X = X.reshape(graph.number_of_nodes(), -1)
X = np.concatenate([np.cos(X), np.sin(X)], axis=1)

feature_blocks = []
for _ in range(self.order):
X = A_tilde.dot(X)
feature_blocks.append(X)

feature_blocks = np.concatenate(feature_blocks, axis=1)
if self.pooling == "mean":
feature_blocks = np.mean(feature_blocks, axis=0)
elif self.pooling == "min":
feature_blocks = np.min(feature_blocks, axis=0)
elif self.pooling == "max":
feature_blocks = np.max(feature_blocks, axis=0)
else:
raise ValueError("Wrong pooling function.")
return feature_blocks
feature_blocks = self.pool_fn(feature_blocks)

return feature_blocks

def fit(self, graphs: List[nx.classes.graph.Graph]):
def fit(self, graphs: List[nx.classes.graph.Graph]) -> None:
"""
Fitting a graph level FEATHER model.
Expand All @@ -122,11 +163,11 @@ def fit(self, graphs: List[nx.classes.graph.Graph]):
graphs = self._check_graphs(graphs)
self._embedding = [self._calculate_feather(graph) for graph in graphs]


def get_embedding(self) -> np.array:
r"""Getting the embedding of graphs.
Return types:
* **embedding** *(Numpy array)* - The embedding of graphs.
"""
return np.array(self._embedding)

0 comments on commit 3b0768a

Please sign in to comment.