diff --git a/examples/sp2gcl/evaluation_test.py b/examples/sp2gcl/evaluation_test.py
new file mode 100644
index 00000000..1ec265c5
--- /dev/null
+++ b/examples/sp2gcl/evaluation_test.py
@@ -0,0 +1,65 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # set before TensorLayerX (and TensorFlow) are imported
+import tensorlayerx as tlx
+import tensorlayerx.nn as nn
+from tensorlayerx.model import TrainOneStep, WithLoss
+
+
+class SemiSpvzLoss(WithLoss):
+    def __init__(self, net, loss_fn):
+        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)
+
+    def forward(self, data, label):
+        logits = self.backbone_network(data['x'])
+        train_logits = tlx.gather(logits, data['train_idx'])
+        train_y = tlx.gather(data['y'], data['train_idx'])
+        loss = self._loss_fn(train_logits, train_y)
+        return loss
+
+
+class LogReg(nn.Module):
+    def __init__(self, hid_dim, out_dim, name=None):
+        super(LogReg, self).__init__(name=name)
+        self.linear = nn.Linear(in_features=hid_dim, out_features=out_dim,
+                                W_init=tlx.initializers.xavier_uniform(),
+                                b_init=tlx.initializers.zeros())
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+def node_evaluation(emb, y, train_idx, test_idx, lr=1e-2, weight_decay=1e-4):
+    # Linear probe: fit a logistic-regression head on the frozen embeddings.
+    # The .item()/.float()/.size() calls assume the torch backend used in the README.
+    nclass = y.max().item() + 1
+    logreg = LogReg(hid_dim=emb.shape[1], out_dim=nclass, name="logreg")
+    opt = tlx.optimizers.Adam(lr=lr, weight_decay=weight_decay)
+    train_weights = logreg.trainable_weights
+    loss = tlx.losses.softmax_cross_entropy_with_logits
+    loss_func = SemiSpvzLoss(logreg, loss)
+    train_one_step = TrainOneStep(loss_func, opt, train_weights)
+
+    data = {
+        'x': emb,
+        'y': y,
+        'train_idx': train_idx,
+        'test_idx': test_idx
+    }
+
+    best_test_acc = 0
+    for epoch in range(100):
+        logreg.set_train()
+        loss = train_one_step(data=data, label=y)
+        logreg.set_eval()
+        test_logits = logreg(emb[test_idx])
+        test_preds = tlx.argmax(test_logits, axis=1)
+        test_acc = tlx.reduce_sum(test_preds == y[test_idx]).float() / test_idx.size(0)
+        if test_acc > best_test_acc:
+            best_test_acc = test_acc
+            logreg.save_weights("./" + logreg.name + ".npz", format='npz_dict')
+
+        # print("==Epoch [{:0>3d}] ".format(epoch + 1)
+        #       + " train loss: {:.4f}".format(loss.item())
+        #       + " acc: {:.4f}".format(test_acc.item()))
+
+    logreg.load_weights("./" + logreg.name + ".npz", format='npz_dict')
+    logreg.set_eval()
+    test_logits = logreg(emb[test_idx])
+    test_preds = tlx.argmax(test_logits, axis=1)
+    test_acc = tlx.reduce_sum(test_preds == y[test_idx]).float() / test_idx.size(0)
+
+    return test_acc.item()
diff --git a/examples/sp2gcl/readme.md b/examples/sp2gcl/readme.md
new file mode 100644
index 00000000..5f33211a
--- /dev/null
+++ b/examples/sp2gcl/readme.md
@@ -0,0 +1,33 @@
+# Graph Contrastive Learning with Stable and Scalable Spectral Encoding (Sp2GCL)
+
+- Paper link: [https://proceedings.neurips.cc/paper_files/paper/2023/file/8e9a6582caa59fda0302349702965171-Paper-Conference.pdf](https://proceedings.neurips.cc/paper_files/paper/2023/file/8e9a6582caa59fda0302349702965171-Paper-Conference.pdf)
+- Author's code repo: [https://github.com/bdy9527/Sp2GCL](https://github.com/bdy9527/Sp2GCL)
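+
+Sp2GCL contrasts a spatial view of each node (a two-layer GCN encoder) with a
+spectral view (an MLP over the Laplacian eigenvalues and eigenvectors). A
+minimal usage sketch, mirroring `sp2gcl_trainer.py` (hyperparameter values
+here are illustrative):
+
+```python
+import scipy.sparse.linalg
+import tensorlayerx as tlx
+from gammagl.models import SpaSpeNode
+from gammagl.utils import get_laplacian, to_scipy_sparse_matrix
+
+# data: a gammagl Graph; take the k smallest Laplacian eigenpairs
+index, attr = get_laplacian(data.edge_index, data.num_nodes, normalization="sym")
+L = to_scipy_sparse_matrix(index, attr)
+e, u = scipy.sparse.linalg.eigsh(L, k=100, which='SM', tol=1e-3)
+e = tlx.convert_to_tensor(e, dtype=tlx.float32)
+u = tlx.convert_to_tensor(u, dtype=tlx.float32)
+
+model = SpaSpeNode(input_dim=data.x.shape[1], spe_dim=100,
+                   hidden_dim=512, output_dim=32, period=10)
+h_spa, h_spe = model(data.x, data.edge_index, e, u)  # L2-normalized projections
+```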
+
+# Dataset Statistics
+
+| Dataset  | # Nodes | # Edges | # Classes |
+|----------|---------|---------|-----------|
+| PubMed   | 19,717  | 88,648  | 3         |
+| Wiki-CS  | 11,701  | 216,123 | 10        |
+| Facebook | 22,470  | 342,004 | 4         |
+
+Results
+-------
+
+```bash
+TL_BACKEND="torch" python sp2gcl_trainer.py --dataset pubmed --hidden_dim 4096 --spe_dim 30 --output_dim 32 --lr 0.001 --period 50 --l2_coef 5e-4 --n_epoch 5
+TL_BACKEND="torch" python sp2gcl_trainer.py --dataset wikics --hidden_dim 2048 --spe_dim 100 --output_dim 32 --lr 0.001 --period 10 --l2_coef 5e-1 --n_epoch 3
+TL_BACKEND="torch" python sp2gcl_trainer.py --dataset facebook --hidden_dim 1500 --spe_dim 100 --output_dim 32 --lr 0.001 --period 10 --l2_coef 5e-4 --n_epoch 5
+```
+
+| Dataset  | Paper Code | Out(th)    |
+|----------|------------|------------|
+| PubMed   | 82.3±0.3   | 78.66±0.76 |
+| Wiki-CS  | 79.42±0.19 | 78.64±0.20 |
+| Facebook | 90.43±0.13 | 87.53±0.34 |
diff --git a/examples/sp2gcl/sp2gcl_trainer.py b/examples/sp2gcl/sp2gcl_trainer.py
new file mode 100644
index 00000000..be4b7a31
--- /dev/null
+++ b/examples/sp2gcl/sp2gcl_trainer.py
@@ -0,0 +1,142 @@
+import os
+# os.environ['TL_BACKEND'] = 'torch'
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+import argparse
+import numpy as np
+import scipy.sparse.linalg
+import tensorlayerx as tlx
+from tensorlayerx.model import WithLoss, TrainOneStep
+from gammagl.models import SpaSpeNode
+from gammagl.datasets import FacebookPagePage, WikiCS, Planetoid
+from gammagl.utils import get_laplacian, to_scipy_sparse_matrix, get_train_val_test_split, mask_to_index
+from evaluation_test import node_evaluation
+
+
+def compute_laplacian(data, args):
+    edge_index = data.edge_index
+    num_nodes = data.num_nodes
+
+    # Symmetrically normalized Laplacian; keep the spe_dim smallest eigenpairs.
+    index, attr = get_laplacian(edge_index, num_nodes, normalization="sym")
+    L = to_scipy_sparse_matrix(index, attr)
+    e, u = scipy.sparse.linalg.eigsh(L, k=args.spe_dim, which='SM', tol=1e-3)
+
+    data.e = tlx.convert_to_tensor(e, dtype=tlx.float32)
+    data.u = tlx.convert_to_tensor(u, dtype=tlx.float32)
+
+    return data
+
+
+class ContrastiveLoss(WithLoss):
+    def __init__(self, model, loss_fn):
+        super(ContrastiveLoss, self).__init__(backbone=model, loss_fn=loss_fn)
+
+    def forward(self, data, label):
+        h1, h2 = self.backbone_network(data['x'], data['edge_index'], data['e'], data['u'])
+
+        logits = tlx.matmul(h1, tlx.transpose(h2, perm=(1, 0)))
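+
+        # Symmetric InfoNCE-style objective: exp(logits) holds pairwise
+        # similarity scores; the diagonal entries are the positive pairs
+        # (h1[i], h2[i]) and the remaining row/column entries their negatives.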
+        exp_logits = tlx.exp(logits)
+        diag = tlx.convert_to_tensor(np.diagonal(tlx.convert_to_numpy(exp_logits)))
+
+        sum_rows_a = tlx.reduce_sum(exp_logits, axis=1)
+        sum_rows_b = tlx.reduce_sum(exp_logits, axis=0)
+
+        log_prob1 = tlx.log(diag / (sum_rows_a - diag))
+        log_prob2 = tlx.log(diag / (sum_rows_b - diag))
+        loss = -0.5 * (log_prob1 + log_prob2).mean()
+
+        return loss
+
+
+def main(args):
+    if args.dataset in ['pubmed', 'wikics', 'facebook']:
+        if args.dataset == 'facebook':
+            dataset = FacebookPagePage()
+        elif args.dataset == 'wikics':
+            dataset = WikiCS()
+        elif args.dataset == 'pubmed':
+            dataset = Planetoid(name=args.dataset)
+    else:
+        raise ValueError('Unknown dataset: {}'.format(args.dataset))
+
+    data = dataset[0]
+    data = compute_laplacian(data, args)
+    e = tlx.convert_to_tensor(data.e[:args.spe_dim], dtype=tlx.float32)
+    u = tlx.convert_to_tensor(data.u[:, :args.spe_dim], dtype=tlx.float32)
+    if 'train_mask' in data.keys:
+        train_mask = data.train_mask
+        test_mask = data.test_mask
+        val_mask = data.val_mask
+    else:
+        train_mask, val_mask, test_mask = get_train_val_test_split(data, 0.1, 0.1)
+    train_idx = mask_to_index(train_mask)
+    test_idx = mask_to_index(test_mask)
+    val_idx = mask_to_index(val_mask)
+
+    model = SpaSpeNode(input_dim=data.x.shape[1],
+                       spe_dim=args.spe_dim,
+                       hidden_dim=args.hidden_dim,
+                       output_dim=args.output_dim,
+                       period=args.period,
+                       name="sp2gcl")
+
+    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
+    train_weights = model.trainable_weights
+    loss_func = ContrastiveLoss(model, tlx.losses.softmax_cross_entropy_with_logits)
+    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)
+
+    data_all = {
+        'x': data.x,
+        'edge_index': data.edge_index,
+        'train_idx': train_idx,
+        'valid_idx': val_idx,
+        'test_idx': test_idx,
+        'e': data.e,
+        'u': data.u
+    }
+
+    best_val_acc = 0
+    for epoch in range(args.n_epoch):
+        model.set_train()
+        train_loss = train_one_step(data=data_all, label=data.y)
+        model.set_eval()
+        spa_emb = tlx.detach(model.spa_encoder(data.x, data.edge_index))
+        spe_emb = tlx.detach(model.spe_encoder(e, u))
+        val_acc = node_evaluation(tlx.concat((spa_emb, spe_emb), axis=-1), data.y, train_idx, val_idx)
+        if val_acc > best_val_acc:
+            best_val_acc = val_acc
+            model.save_weights(args.best_model_path + model.name + ".npz", format='npz_dict')
+
+        print("Epoch [{:0>3d}] ".format(epoch + 1)
+              + " train loss: {:.4f}".format(train_loss.item())
+              + " val acc: {:.4f}".format(val_acc))
+
+    model.load_weights(args.best_model_path + model.name + ".npz", format='npz_dict')
+    model.set_eval()
+    spa_emb = tlx.detach(model.spa_encoder(data.x, data.edge_index))
+    spe_emb = tlx.detach(model.spe_encoder(e, u))
+    test_acc = node_evaluation(tlx.concat((spa_emb, spe_emb), axis=-1), data.y, train_idx, test_idx)
+    print("Test acc: {:.4f}".format(test_acc))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', default='wikics', help='dataset')
+    parser.add_argument('--dataset_path', type=str, default=r'', help="path to save dataset")
+    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
+    parser.add_argument('--spe_dim', type=int, default=100)
+    parser.add_argument('--period', type=int, default=256)
+    parser.add_argument('--hidden_dim', type=int, default=512, help="dimension of hidden layers")
+    parser.add_argument('--output_dim', type=int, default=512)
+    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
+    parser.add_argument('--l2_coef', type=float, default=5e-4, help="l2 loss coefficient")
+    parser.add_argument('--n_epoch', type=int, default=3, help="number of epochs")
+    parser.add_argument("--gpu", type=int, default=0)
+    args = parser.parse_args()
+    if args.gpu >= 0:
+        tlx.set_device("GPU", args.gpu)
+    else:
+        tlx.set_device("CPU")
+
+    main(args)
diff --git a/gammagl/datasets/__init__.py b/gammagl/datasets/__init__.py
index 2dc598bb..5531fc7c 100644
--- a/gammagl/datasets/__init__.py
+++ b/gammagl/datasets/__init__.py
@@ -20,6 +20,7 @@ from .wikics import WikiCS
 from .blogcatalog import BlogCatalog
 from .molecule_net import MoleculeNet
+from .facebook import FacebookPagePage
 from .acm4heco import ACM4HeCo
 from .yelp import Yelp
@@ -45,6 +46,7 @@
     'PolBlogs',
     'WikiCS',
     'MoleculeNet',
+    'FacebookPagePage',
     'NGSIM_US_101',
     'Yelp'
 ]
diff --git a/gammagl/datasets/facebook.py b/gammagl/datasets/facebook.py
new file mode 100644
index 00000000..47a30476
--- /dev/null
+++ b/gammagl/datasets/facebook.py
@@ -0,0 +1,70 @@
+from typing import Callable, Optional
+import numpy as np
+import tensorlayerx as tlx
+
+from gammagl.data import Graph, InMemoryDataset, download_url
+
+
+class FacebookPagePage(InMemoryDataset):
+    r"""The Facebook Page-Page network dataset introduced in the
+    `"Multi-scale Attributed Node Embedding"
+    <https://arxiv.org/abs/1909.13021>`_ paper.
+    Nodes represent verified pages on Facebook and edges are mutual likes.
+    It contains 22,470 nodes, 342,004 edges, 128 node features and 4 classes.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`gammagl.data.Graph` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`gammagl.data.Graph` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
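+
+    Example:
+        >>> from gammagl.datasets import FacebookPagePage
+        >>> dataset = FacebookPagePage()
+        >>> graph = dataset[0]
+        >>> graph.num_nodes
+        22470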
+    """
+
+    url = 'https://graphmining.ai/datasets/ptg/facebook.npz'
+
+    def __init__(
+        self,
+        root: str = None,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+        super().__init__(root, transform, pre_transform, force_reload=force_reload)
+        self.data, self.slices = self.load_data(self.processed_paths[0])
+
+    @property
+    def raw_file_names(self) -> str:
+        return 'facebook.npz'
+
+    @property
+    def processed_file_names(self) -> str:
+        return tlx.BACKEND + '_data.pt'
+
+    def download(self) -> None:
+        download_url(self.url, self.raw_dir)
+
+    def process(self) -> None:
+        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)
+        x = tlx.convert_to_tensor(data['features'], dtype=tlx.float32)
+        y = tlx.convert_to_tensor(data['target'], dtype=tlx.int64)
+        # Edges are stored as (num_edges, 2); transpose to (2, num_edges).
+        edge_index = tlx.convert_to_tensor(data['edges'], dtype=tlx.int64)
+        if tlx.BACKEND == 'mindspore':
+            edge_index = edge_index.T
+        else:
+            edge_index = tlx.ops.transpose(edge_index)
+
+        data = Graph(x=x, edge_index=edge_index, y=y)
+
+        if self.pre_transform is not None:
+            data = self.pre_transform(data)
+
+        self.save_data(self.collate([data]), self.processed_paths[0])
diff --git a/gammagl/datasets/wikics.py b/gammagl/datasets/wikics.py
index f2347ff3..a6746ede 100644
--- a/gammagl/datasets/wikics.py
+++ b/gammagl/datasets/wikics.py
@@ -81,14 +81,14 @@ def process(self):
         edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])]
         edges = list(chain(*edges))
-        edge_index = np.ascontiguousarray(np.array(edges, dtype=np.int64).T)
+        edge_index = tlx.transpose(tlx.convert_to_tensor(edges, dtype=tlx.int64))
         if self.is_undirected:
-            edge_index = to_undirected(tlx.convert_to_tensor(edge_index), num_nodes=tlx.get_tensor_shape(x[0]))
+            edge_index = to_undirected(tlx.convert_to_tensor(edge_index), num_nodes=tlx.get_tensor_shape(x)[0])
 
-        train_mask = np.ascontiguousarray(np.array(data['train_masks'], dtype=np.bool).T)
-        val_mask = np.ascontiguousarray(np.array(data['val_masks'], dtype=np.bool).T)
-        test_mask = np.ascontiguousarray(np.array(data['test_mask'], dtype=np.bool).T)
-        stopping_mask = np.ascontiguousarray(np.array(data['stopping_masks'], dtype=np.bool).T)
+        train_mask = tlx.transpose(tlx.convert_to_tensor(data['train_masks'], dtype=tlx.bool))
+        val_mask = tlx.transpose(tlx.convert_to_tensor(data['val_masks'], dtype=tlx.bool))
+        test_mask = tlx.transpose(tlx.convert_to_tensor(data['test_mask'], dtype=tlx.bool))
+        stopping_mask = tlx.transpose(tlx.convert_to_tensor(data['stopping_masks'], dtype=tlx.bool))
 
         data = Graph(x=x, y=y, edge_index=edge_index, train_mask=train_mask,
                      val_mask=val_mask, test_mask=test_mask,
@@ -96,4 +96,4 @@ def process(self):
         if self.pre_transform is not None:
             data = self.pre_transform(data)
-        self.save_data(self.collate([data]), self.processed_paths[0])
\ No newline at end of file
+        self.save_data(self.collate([data]), self.processed_paths[0])
diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py
index 9d564bfb..7e78e2ba 100644
--- a/gammagl/models/__init__.py
+++ b/gammagl/models/__init__.py
@@ -53,6 +53,7 @@ from .sfgcn import SFGCNModel
 from .grace_spco import Grace_Spco_Encoder, Grace_Spco_Model
 from .graphormer import Graphormer
+from .sp2gcl import SpaSpeNode, Encoder, EigenMLP
 from .heco import HeCo
 from .fusedgat import FusedGATModel
 from .hid_net import Hid_net
@@ -115,6 +116,9 @@
     'Specformer',
     'SFGCNModel',
     'Graphormer',
+    'Encoder',
+    'EigenMLP',
+    'SpaSpeNode',
     'FusedGATModel',
     'hid_net',
     'HEAT',
diff --git a/gammagl/models/sp2gcl.py b/gammagl/models/sp2gcl.py
new file mode 100644
index 00000000..0716389a
--- /dev/null
+++ b/gammagl/models/sp2gcl.py
@@ -0,0 +1,88 @@
+import tensorlayerx as tlx
+import tensorlayerx.nn as nn
+from gammagl.layers.conv import GCNConv
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels, hidden_channels):
+        super().__init__()
+        self.conv1 = GCNConv(in_channels, hidden_channels)
+        self.bn1 = nn.BatchNorm1d(num_features=hidden_channels, momentum=0.01)
+        self.prelu1 = nn.PRelu()
+        self.conv2 = GCNConv(hidden_channels, hidden_channels)
+        self.bn2 = nn.BatchNorm1d(num_features=hidden_channels, momentum=0.01)
+        self.prelu2 = nn.PRelu()
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x = self.conv1(x, edge_index, edge_weight=edge_weight)
+        x = self.prelu1(self.bn1(x))
+        x = self.conv2(x, edge_index, edge_weight=edge_weight)
+        x = self.prelu2(self.bn2(x))
+        return x
+
+
+class EigenMLP(nn.Module):
+    def __init__(self, input_dim, hidden_dim, period):
+        super(EigenMLP, self).__init__()
+
+        self.period = period
+
+        self.phi = nn.Sequential(nn.Linear(in_features=1, out_features=hidden_dim), nn.ReLU(),
+                                 nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
+        self.psi = nn.Sequential(nn.Linear(in_features=hidden_dim, out_features=hidden_dim), nn.ReLU(),
+                                 nn.Linear(in_features=hidden_dim, out_features=1))
+        self.mlp = nn.Sequential(nn.Linear(in_features=2 * period, out_features=hidden_dim), nn.ReLU(),
+                                 nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
+
+        self.relu = nn.ReLU()
+
+    def forward(self, e, u):
+        # Sign-invariant eigenvector transform: psi(phi(u) + phi(-u)) is
+        # unchanged when any eigenvector flips sign.
+        u = tlx.expand_dims(u, axis=-1)
+        u_transformed = self.psi(self.phi(u) + self.phi(-u))
+        u = tlx.ops.squeeze(u_transformed, axis=-1)
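+
+        # Fourier features of the eigenvalues: each eigenvalue e is encoded as
+        # [sin(t * e), cos(t * e)] for t = 1..period, giving 2*period features.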
+        # e = e * 100
+        period_term = tlx.arange(1, self.period + 1)
+        e_unsqueeze = tlx.expand_dims(e, axis=1)
+        fourier_e = tlx.reshape(tlx.stack((tlx.sin(e_unsqueeze * period_term), tlx.cos(e_unsqueeze * period_term)), axis=1), (-1, self.period * 2))
+
+        h = tlx.matmul(u, fourier_e)
+        h = self.mlp(h)
+        return h
+
+
+class SpaSpeNode(nn.Module):
+    def __init__(self, input_dim, spe_dim, hidden_dim, output_dim, period, name=None):
+        super().__init__(name=name)
+        self.input_dim = input_dim
+        self.spe_dim = spe_dim
+        self.hidden_dim = hidden_dim
+        self.output_dim = output_dim
+        self.period = period
+
+        self.spa_encoder = Encoder(self.input_dim, self.hidden_dim)
+        self.spe_encoder = EigenMLP(self.spe_dim, self.hidden_dim, self.period)
+
+        self.spa_projection_head = nn.Sequential(
+            nn.Linear(in_features=hidden_dim, out_features=output_dim, W_init='xavier_uniform'),
+            nn.PRelu(),
+            nn.Linear(in_features=output_dim, out_features=output_dim, W_init='xavier_uniform')
+        )
+
+        self.spe_projection_head = nn.Sequential(
+            nn.Linear(in_features=hidden_dim, out_features=output_dim, W_init='xavier_uniform'),
+            nn.PRelu(),
+            nn.Linear(in_features=output_dim, out_features=output_dim, W_init='xavier_uniform')
+        )
+
+    def forward(self, x, edge_index, e, u):
+        x_node_spa = self.spa_encoder(x, edge_index)
+        x_node_spe = self.spe_encoder(e, u)
+
+        h_node_spa = self.spa_projection_head(x_node_spa)
+        h_node_spe = self.spe_projection_head(x_node_spe)
+
+        h1 = tlx.l2_normalize(h_node_spa, axis=-1, eps=1e-12)
+        h2 = tlx.l2_normalize(h_node_spe, axis=-1, eps=1e-12)
+
+        return h1, h2
diff --git a/gammagl/utils/get_laplacian.py b/gammagl/utils/get_laplacian.py
index 27419011..06f990dd 100644
--- a/gammagl/utils/get_laplacian.py
+++ b/gammagl/utils/get_laplacian.py
@@ -50,7 +50,10 @@ def get_laplacian(edge_index, num_nodes, edge_weight=None, normalization=None):
         # Compute A_norm = -D^{-1/2} A D^{-1/2}.
         deg_inv_sqrt = tlx.pow(deg, -0.5)
         # deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
-        edge_weight = tlx.gather(deg_inv_sqrt, row) * edge_weight * tlx.gather(deg_inv_sqrt, col)
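+        # edge_weight may be shape (E,) or (E, 1); lift deg_inv_sqrt to the
+        # same rank so the D^{-1/2} scaling broadcasts in both cases.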
+        if len(tlx.get_tensor_shape(tlx.gather(deg_inv_sqrt, row))) + 1 == len(tlx.get_tensor_shape(edge_weight)):
+            edge_weight = tlx.expand_dims(tlx.gather(deg_inv_sqrt, row), axis=-1) * edge_weight * tlx.expand_dims(tlx.gather(deg_inv_sqrt, col), axis=-1)
+        else:
+            edge_weight = tlx.gather(deg_inv_sqrt, row) * edge_weight * tlx.gather(deg_inv_sqrt, col)
 
         # L = I - A_norm.
         edge_weight = tlx.reshape(edge_weight, (-1, 1))
diff --git a/tests/datasets/test_facebook.py b/tests/datasets/test_facebook.py
new file mode 100644
index 00000000..414894c8
--- /dev/null
+++ b/tests/datasets/test_facebook.py
@@ -0,0 +1,11 @@
+from gammagl.datasets import FacebookPagePage
+
+
+def test_facebook():
+    dataset = FacebookPagePage()
+    g = dataset[0]
+    assert len(dataset) == 1
+    assert g.num_nodes == 22470
+    assert g.num_edges == 342004
+    assert g.x.shape[1] == 128
+    assert len(set(g.y.numpy())) == 4
diff --git a/tests/models/test_sp2gcl.py b/tests/models/test_sp2gcl.py
new file mode 100644
index 00000000..043dfec5
--- /dev/null
+++ b/tests/models/test_sp2gcl.py
@@ -0,0 +1,34 @@
+import numpy as np
+import scipy.sparse.linalg
+from scipy.sparse import csr_matrix
+import tensorlayerx as tlx
+from gammagl.models import SpaSpeNode, Encoder, EigenMLP
+
+
+def test_spaspenode():
+    if tlx.BACKEND == "tensorflow":
+        return
+    edge_index = tlx.convert_to_tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tlx.int64)
+    x = tlx.convert_to_tensor(np.random.randn(3, 5), dtype=tlx.float32)
+    y = tlx.convert_to_tensor([0, 1, 2], dtype=tlx.int64)
+
+    num_nodes = x.shape[0]
+    row, col = edge_index
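+    # Build L = I - D^{-1/2} A D^{-1/2} with scipy; this mirrors what
+    # gammagl.utils.get_laplacian(normalization="sym") computes.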
+    data_adj = csr_matrix((np.ones(len(row)), (row, col)), shape=(num_nodes, num_nodes))
+    degree = np.array(data_adj.sum(axis=1)).flatten()
+    deg_inv_sqrt = 1.0 / np.sqrt(degree)
+    deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0
+    I = csr_matrix(np.eye(num_nodes))
+    D_inv_sqrt = csr_matrix((deg_inv_sqrt, (np.arange(num_nodes), np.arange(num_nodes))))
+    L = I - D_inv_sqrt.dot(data_adj).dot(D_inv_sqrt)
+
+    k = min(2, num_nodes - 1)
+    e, u = scipy.sparse.linalg.eigsh(L, k=k, which='SM', tol=1e-3)
+    e = tlx.convert_to_tensor(e, dtype=tlx.float32)
+    u = tlx.convert_to_tensor(u, dtype=tlx.float32)
+
+    # spe_dim is only stored by the model; EigenMLP consumes however many
+    # eigenpairs are passed at forward time, so k=2 eigenpairs still work here.
+    model = SpaSpeNode(input_dim=x.shape[1], spe_dim=20, hidden_dim=32, output_dim=16, period=20)
+
+    h_node_spa, h_node_spe = model(x, edge_index, e, u)
+    assert h_node_spa.shape == (num_nodes, 16)
+    assert h_node_spe.shape == (num_nodes, 16)