diff --git a/README.md b/README.md
index eab90c2..3fe8e81 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,8 @@
[![pytorch](https://img.shields.io/badge/PyTorch_2.1+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
[![pyg](https://img.shields.io/badge/PyG_2.4+-3C2179?logo=pyg&logoColor=#3C2179)](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
-[![arxiv](http://img.shields.io/badge/arxiv-2310.04562-yellow.svg)](https://arxiv.org/abs/2310.04562)
+[![ULTRA arxiv](http://img.shields.io/badge/arxiv-2310.04562-yellow.svg)](https://arxiv.org/abs/2310.04562)
+[![UltraQuery arxiv](http://img.shields.io/badge/arxiv-2404.07198-yellow.svg)](https://arxiv.org/abs/2404.07198)
[![HuggingFace Hub](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-black)](https://huggingface.co/collections/mgalkin/ultra-65699bb28369400a5827669d)
![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)
@@ -37,6 +38,7 @@ This repository is based on PyTorch 2.1 and PyTorch-Geometric 2.4.
* [Pre-train](#pretraining) ULTRA on your own mixture of graphs.
* Run [evaluation on many datasets](#run-on-many-datasets) sequentially.
* Use the pre-trained checkpoints to run inference and fine-tuning on [your own KGs](#adding-your-own-graph).
+* (NEW) Execute complex logical queries on any KG with [UltraQuery](#ultraquery)
Table of contents:
* [Installation](#installation)
@@ -47,8 +49,10 @@ Table of contents:
* [Pretraining](#pretraining)
* [Datasets](#datasets)
* [Adding custom datasets](#adding-your-own-graph)
+* [UltraQuery](#ultraquery)
## Updates
+* **Apr 23rd, 2024**: Release of [UltraQuery](#ultraquery) for complex multi-hop logical query answering on _any_ KG (with new checkpoint and 23 datasets).
* **Jan 15th, 2024**: Accepted at [ICLR 2024](https://openreview.net/forum?id=jVEoydFOl9)!
* **Dec 4th, 2023**: Added a new ULTRA checkpoint `ultra_50g` pre-trained on 50 graphs. Averaged over 16 larger transductive graphs, it delivers 0.389 MRR / 0.549 Hits@10 compared to 0.329 MRR / 0.479 Hits@10 of the `ultra_3g` checkpoint. The inductive performance is still as good! Use this checkpoint for inference on larger graphs.
* **Dec 4th, 2023**: Pre-trained ULTRA models (3g, 4g, 50g) are now also available on the [HuggingFace Hub](https://huggingface.co/collections/mgalkin/ultra-65699bb28369400a5827669d)!
@@ -340,17 +344,188 @@ class CustomDataset(InductiveDataset):
TSV / CSV files are supported by setting a delimiter (eg, `delimiter = "\t"`) in the class definition.
After adding your own dataset, you can immediately run 0-shot inference or fine-tuning of any ULTRA checkpoint.
+## UltraQuery ##
+
+You can now run complex logical queries on any KG with UltraQuery, an inductive query answering approach that uses any Ultra checkpoint with non-parametric fuzzy logic operators. Read more in the [new preprint](https://arxiv.org/abs/2404.07198).
+
+Similar to Ultra, UltraQuery transfers to any KG in the zero-shot fashion and sets a few SOTA results on a variety of query answering benchmarks.
+
+### Checkpoint ###
+
+Any existing ULTRA checkpoint is compatible with UltraQuery but we also ship a newly trained `ultraquery.pth` checkpoint in the `ckpts` folder.
+
+* A new `ultraquery.pth` checkpoint trained on complex queries from the `FB15k237LogicalQuery` dataset for 40,000 steps, the config is in `config/ultraquery/pretrain.yaml` - the same ULTRA architecture but tuned for the multi-source propagation needed in complex queries (no need for score thresholding)
+* You can use any existing ULTRA checkpoint (`3g` / `4g` / `50g`) for starters - don't forget to set the `--threshold` argument to 0.8 or higher (depending on the dataset). Score thresholding is required because those models were trained on simple one-hop link prediction and there are certain issues (namely, the multi-source propagation issue, read Section 4.1 in the [new preprint](https://arxiv.org/abs/2404.07198) for more details)
+
+### Performance
+
+The numbers reported in the preprint were obtained with a model trained with TorchDrug. In this PyG implementation, we managed to get even better performance across the board with the `ultraquery.pth` checkpoint.
+
+`EPFO` is the averaged performance over 9 queries with relation projection, intersection, and union. `Neg` is the averaged performance over 5 queries with negation.
+
+
+
+ Model |
+ Total Average (23 datasets) |
+ Transductive (3 datasets) |
+ Inductive (e) (9 graphs) |
+ Inductive (e,r) (11 graphs) |
+
+
+ EPFO MRR |
+ EPFO Hits@10 |
+ Neg MRR |
+ Neg Hits@10 |
+ EPFO MRR |
+ EPFO Hits@10 |
+ Neg MRR |
+ Neg Hits@10 |
+ EPFO MRR |
+ EPFO Hits@10 |
+ Neg MRR |
+ Neg Hits@10 |
+ EPFO MRR |
+ EPFO Hits@10 |
+ Neg MRR |
+ Neg Hits@10 |
+
+
+ UltraQuery Paper |
+ 0.301 |
+ 0.428 |
+ 0.152 |
+ 0.264 |
+ 0.335 |
+ 0.467 |
+ 0.132 |
+ 0.260 |
+ 0.321 |
+ 0.479 |
+ 0.156 |
+ 0.291 |
+ 0.275 |
+ 0.375 |
+ 0.153 |
+ 0.242 |
+
+
+ UltraQuery PyG |
+ 0.309 |
+ 0.432 |
+ 0.178 |
+ 0.286 |
+ 0.411 |
+ 0.518 |
+ 0.240 |
+ 0.352 |
+ 0.312 |
+ 0.468 |
+ 0.139 |
+ 0.262 |
+ 0.280 |
+ 0.380 |
+ 0.193 |
+ 0.288 |
+
+
+
+In particular, we reach SOTA on FB15k queries (0.764 MRR & 0.834 Hits@10 on EPFO; 0.567 MRR & 0.725 Hits@10 on negation) compared to much larger and heavier baselines (such as QTO).
+
+### Run Inference ###
+
+The running format is similar to the KG completion pipeline - use `run_query.py` and `run_query_many` for running a single expriment on one dataset or on a sequence of datasets.
+Due to the size of the datasets and query complexity, it is recommended to run inference on a GPU.
+
+An example command for running transductive inference with UltraQuery on FB15k237 queries
+
+```bash
+python script/run_query.py -c config/ultraquery/transductive.yaml --dataset FB15k237LogicalQuery --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.0 --ultra_ckpt null --qe_ckpt /path/to/ultra/ckpts/ultraquery.pth
+```
+
+An example command for running transductive inference with a vanilla Ultra 4g on FB15k237 queries with scores thresholding
+
+```bash
+python script/run_query.py -c config/ultraquery/transductive.yaml --dataset FB15k237LogicalQuery --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.8 --ultra_ckpt /path/to/ultra/ckpts/ultra_4g.pth --qe_ckpt null
+```
+
+An example command for running inductive inference with UltraQuery on `InductiveFB15k237Query:550` queries
+
+```bash
+python script/run_query.py -c config/ultraquery/inductive.yaml --dataset InductiveFB15k237Query --version 550 --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.0 --ultra_ckpt null --qe_ckpt /path/to/ultra/ckpts/ultraquery.pth
+```
+
+New arguments for `_query` scripts:
+* `--threshold`: set to 0.0 when using the main UltraQuery checkpoint `ultraquery.pth` or 0.8 (and higher) when using vanilla Ultra checkpoints
+* `--qe_ckpt`: path to the UltraQuery checkpoint, set to `null` if you want to run vanilla Ultra checkpoints
+* `--ultra_ckpt`: path to the original Ultra checkpoints, set to `null` if you want to run the UltraQuery checkpoint
+
+### Datasets ###
+
+23 new datasets available in `datasets_query.py` that will be automatically downloaded upon the first launch.
+All datasets include 14 standard query types (`1p`, `2p`, `3p`, `2i`, `3i`, `ip`, `pi`, `2u-DNF`, `up-DNF`, `2in`, `3in`,`inp`, `pin`, `pni`).
+
+The standard protocol is training on 10 patterns without unions and `ip`,`pi` queries (`1p`, `2p`, `3p`, `2i`, `3i`, `2in`, `3in`,`inp`, `pin`, `pni`) and running evaluation on all 14 patterns including `2u`, `up`, `ip`, `pi`.
+
+
+Transductive query datasets (3)
+
+All are the [BetaE](https://arxiv.org/abs/2010.11465) versions of the datasets including queries with negation and limiting the max number of answers to 100
+* `FB15k237LogicalQuery`, `FB15kLogicalQuery`, `NELL995LogicalQuery`
+
+
+
+
+Inductive (e) query datasets (9)
+
+9 inductive datasets extracted from FB15k237 - first proposed in [Inductive Logical Query Answering in Knowledge Graphs](https://openreview.net/forum?id=-vXEN5rIABY) (NeurIPS 2022)
+
+`InductiveFB15k237Query` with 9 versions where the number shows the how large is the inference graph compared to the train graph (in the number of nodes):
+* `550`, `300`, `217`, `175`, `150`, `134`, `122`, `113`, `106`
+
+In addition, we include the `InductiveFB15k237QueryExtendedEval` dataset with the same versions. Those are supposed to be inference-only datasets that measure the _faithfulness_ of complex query answering approaches. In each split, as validation and test graphs extend the train graphs with more nodes and edges, training queries now have more true answers achievable by simple edge traversal (no missing link prediction required) - the task is to measure how well CLQA models can retrieve new easy answers on training queries but on larger unseen graphs.
+
+
+
+
+Inductive (e,r) query datasets (11)
+
+11 new inductive query datasets (WikiTopics-CLQA) that we built specifically for testing UltraQuery.
+The queries were sampled from the WikiTopics splits proposed in [Double Equivariance for Inductive Link Prediction for Both New Nodes and New Relation Types](https://arxiv.org/abs/2302.01313)
+
+`WikiTopicsQuery` with 11 versions
+* `art`, `award`, `edu`, `health`, `infra`, `loc`, `org`, `people`, `sci`, `sport`, `tax`
+
+
+
+### Metrics
+
+New metrics include `auroc`, `spearmanr`, `mape`. We don't support Mean Rank `mr` in complex queries. If you ever see `nan` in one of those metrics, consider reducing the batch size as those metrics are computed with the variadic functions that might be numerically unstable on large batches.
+
## Citation ##
-If you find this codebase useful in your research, please cite the original paper.
+If you find this codebase useful in your research, please cite the original papers.
+
+The main ULTRA paper:
+
+```bibtex
+@inproceedings{galkin2023ultra,
+ title={Towards Foundation Models for Knowledge Graph Reasoning},
+ author={Mikhail Galkin and Xinyu Yuan and Hesham Mostafa and Jian Tang and Zhaocheng Zhu},
+ booktitle={The Twelfth International Conference on Learning Representations},
+ year={2024},
+ url={https://openreview.net/forum?id=jVEoydFOl9}
+}
+```
+
+UltraQuery:
```bibtex
-@article{galkin2023ultra,
- title={Towards Foundation Models for Knowledge Graph Reasoning},
- author={Mikhail Galkin and Xinyu Yuan and Hesham Mostafa and Jian Tang and Zhaocheng Zhu},
- year={2023},
- eprint={2310.04562},
+@article{galkin2024ultraquery,
+ title={Zero-shot Logical Query Reasoning on any Knowledge Graph},,
+ author={Mikhail Galkin and Jincheng Zhou and Bruno Ribeiro and Jian Tang and Zhaocheng Zhu},
+ year={2024},
+ eprint={2404.07198},
archivePrefix={arXiv},
- primaryClass={cs.CL}
+ primaryClass={cs.AI}
}
```
diff --git a/ckpts/ultraquery.pth b/ckpts/ultraquery.pth
new file mode 100644
index 0000000..dc459d7
Binary files /dev/null and b/ckpts/ultraquery.pth differ
diff --git a/config/inductive/inference.yaml b/config/inductive/inference.yaml
index d19c187..1054f12 100644
--- a/config/inductive/inference.yaml
+++ b/config/inductive/inference.yaml
@@ -8,7 +8,7 @@ dataset:
model:
class: Ultra
relation_model:
- class: NBFNet
+ class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
@@ -16,7 +16,7 @@ model:
short_cut: yes
layer_norm: yes
entity_model:
- class: IndNBFNet
+ class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
diff --git a/config/transductive/inference.yaml b/config/transductive/inference.yaml
index 03f1b6a..3a955ff 100644
--- a/config/transductive/inference.yaml
+++ b/config/transductive/inference.yaml
@@ -7,7 +7,7 @@ dataset:
model:
class: Ultra
relation_model:
- class: NBFNet
+ class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
@@ -15,7 +15,7 @@ model:
short_cut: yes
layer_norm: yes
entity_model:
- class: IndNBFNet
+ class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
diff --git a/config/transductive/pretrain_3g.yaml b/config/transductive/pretrain_3g.yaml
index ed2d21d..09cd6b7 100644
--- a/config/transductive/pretrain_3g.yaml
+++ b/config/transductive/pretrain_3g.yaml
@@ -8,7 +8,7 @@ dataset:
model:
class: Ultra
relation_model:
- class: NBFNet
+ class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
@@ -16,7 +16,7 @@ model:
short_cut: yes
layer_norm: yes
entity_model:
- class: IndNBFNet
+ class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
diff --git a/config/transductive/pretrain_4g.yaml b/config/transductive/pretrain_4g.yaml
index 01609cb..4127027 100644
--- a/config/transductive/pretrain_4g.yaml
+++ b/config/transductive/pretrain_4g.yaml
@@ -8,7 +8,7 @@ dataset:
model:
class: Ultra
relation_model:
- class: NBFNet
+ class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
@@ -16,7 +16,7 @@ model:
short_cut: yes
layer_norm: yes
entity_model:
- class: IndNBFNet
+ class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
diff --git a/config/ultraquery/inductive.yaml b/config/ultraquery/inductive.yaml
new file mode 100644
index 0000000..010d81f
--- /dev/null
+++ b/config/ultraquery/inductive.yaml
@@ -0,0 +1,53 @@
+output_dir: ~/git/ULTRA/output
+
+dataset:
+ class: {{ dataset }}
+ root: ~/git/ULTRA/query-datasets/
+ version: {{ version }} # specify dataset version here or when running the script
+
+model:
+ class: UltraQuery
+ model:
+ class: Ultra
+ relation_model:
+ class: RelNBFNet
+ input_dim: 64
+ hidden_dims: [64, 64, 64, 64, 64, 64]
+ message_func: distmult
+ aggregate_func: sum
+ short_cut: yes
+ layer_norm: yes
+ entity_model:
+ class: QueryNBFNet
+ input_dim: 64
+ hidden_dims: [64, 64, 64, 64, 64, 64]
+ message_func: distmult
+ aggregate_func: sum
+ short_cut: yes
+ layer_norm: yes
+ logic: product
+ dropout_ratio: 0.5
+ threshold: {{ threshold }}
+ more_dropout: 0.0
+
+task:
+ name: InductiveInference
+ strict_negative: yes
+ adversarial_temperature: 0.2
+ sample_weight: no
+ metric: [mrr, hits@1, hits@3, hits@10, auroc, spearmanr] # mape is supported as well
+
+optimizer:
+ class: Adam
+ lr: 5.0e-4
+
+train:
+ gpus: {{ gpus }}
+ batch_size: {{ bs }} # reduce if doesn't fit on a GPU
+ num_epoch: {{ epochs }} # total number of optimization steps will be num_epochs * batch_per_epoch
+ batch_per_epoch: {{ bpe }} # number of batches to be considered as "one epoch"
+ log_interval: 100
+ fast_test: 1000 # UltraQuery is slower in inference, use this option for a random subsample of valid data
+
+ultra_ckpt: {{ ultra_ckpt }} # Ultra checkpoint pre-trained on simple link prediction
+ultraquery_ckpt: {{ qe_ckpt }} # UltraQuery checkpoint trained on complex queries
\ No newline at end of file
diff --git a/config/ultraquery/pretrain.yaml b/config/ultraquery/pretrain.yaml
new file mode 100644
index 0000000..1053216
--- /dev/null
+++ b/config/ultraquery/pretrain.yaml
@@ -0,0 +1,53 @@
+output_dir: ~/git/ULTRA/output
+
+dataset:
+ class: FB15k237LogicalQuery
+ root: ~/git/ULTRA/query-datasets/
+
+
+model:
+ class: UltraQuery
+ model:
+ class: Ultra
+ relation_model:
+ class: RelNBFNet
+ input_dim: 64
+ hidden_dims: [64, 64, 64, 64, 64, 64]
+ message_func: distmult
+ aggregate_func: sum
+ short_cut: yes
+ layer_norm: yes
+ entity_model:
+ class: QueryNBFNet
+ input_dim: 64
+ hidden_dims: [64, 64, 64, 64, 64, 64]
+ message_func: distmult
+ aggregate_func: sum
+ short_cut: yes
+ layer_norm: yes
+ logic: product
+ dropout_ratio: 0.25
+ threshold: 0.0
+ more_dropout: 0.0
+
+task:
+ name: TransductiveInference
+ strict_negative: yes
+ adversarial_temperature: 0.2
+ sample_weight: no
+ metric: [mrr, hits@1, hits@3, hits@10] # auroc, spearmanr, mape are supported as well
+
+optimizer:
+ class: Adam
+ lr: 5.0e-4
+
+train:
+ gpus: {{ gpus }}
+ batch_size: {{ bs }} # 32 for 4x 3090 (24 GB), adjust total num_steps accordingly
+ num_epoch: 10 # total number of optimization steps will be num_epochs * batch_per_epoch
+ batch_per_epoch: 4000 # number of batches to be considered as "one epoch"
+ log_interval: 400
+ fast_test: 1000 # UltraQuery is slower in inference, use this option for a random subsample of valid data
+
+ultra_ckpt: ~/git/ULTRA/ckpts/ultra_4g.pth # Initialize with Ultra 4g
+ultraquery_ckpt: null # UltraQuery checkpoint trained on complex queries
\ No newline at end of file
diff --git a/config/ultraquery/transductive.yaml b/config/ultraquery/transductive.yaml
new file mode 100644
index 0000000..55fe924
--- /dev/null
+++ b/config/ultraquery/transductive.yaml
@@ -0,0 +1,53 @@
+output_dir: ~/git/ULTRA/output
+
+dataset:
+ class: {{ dataset }}
+ root: ~/git/ULTRA/query-datasets/
+
+
+model:
+ class: UltraQuery
+ model:
+ class: Ultra
+ relation_model:
+ class: RelNBFNet
+ input_dim: 64
+ hidden_dims: [64, 64, 64, 64, 64, 64]
+ message_func: distmult
+ aggregate_func: sum
+ short_cut: yes
+ layer_norm: yes
+ entity_model:
+ class: QueryNBFNet
+ input_dim: 64
+ hidden_dims: [64, 64, 64, 64, 64, 64]
+ message_func: distmult
+ aggregate_func: sum
+ short_cut: yes
+ layer_norm: yes
+ logic: product
+ dropout_ratio: 0.25
+ threshold: {{ threshold }}
+ more_dropout: 0.0
+
+task:
+ name: TransductiveInference
+ strict_negative: yes
+ adversarial_temperature: 0.2
+ sample_weight: no
+ metric: [mrr, hits@1, hits@3, hits@10] # auroc, spearmanr, mape are supported as well
+
+optimizer:
+ class: Adam
+ lr: 5.0e-4
+
+train:
+ gpus: {{ gpus }}
+ batch_size: {{ bs }}
+ num_epoch: {{ epochs }} # total number of optimization steps will be num_epochs * batch_per_epoch
+ batch_per_epoch: {{ bpe }} # number of batches to be considered as "one epoch"
+ log_interval: 100
+ fast_test: 1000 # UltraQuery is slower in inference, use this option for a random subsample of valid data
+
+ultra_ckpt: {{ ultra_ckpt }} # Ultra checkpoint pre-trained on simple link prediction
+ultraquery_ckpt: {{ qe_ckpt }} # UltraQuery checkpoint trained on complex queries
\ No newline at end of file
diff --git a/script/run_query.py b/script/run_query.py
new file mode 100644
index 0000000..9846e96
--- /dev/null
+++ b/script/run_query.py
@@ -0,0 +1,264 @@
+import os
+import sys
+import csv
+import math
+import time
+import pprint
+from itertools import islice
+from tqdm import tqdm
+
+import torch
+import torch_geometric as pyg
+from torch import optim
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import data as torch_data
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from ultra import datasets_query, tasks, util, query_utils
+from ultra.models import Ultra
+from ultra.ultraquery import UltraQuery
+from ultra.query_utils import batch_evaluate, evaluate, gather_results
+from ultra.variadic import variadic_softmax
+from timeit import default_timer as timer
+
+
+separator = ">" * 30
+line = "-" * 30
+
+def predict_and_target(model, graph, batch):
+ query = batch["query"]
+ type = batch["type"]
+ easy_answer = batch["easy_answer"]
+ hard_answer = batch["hard_answer"]
+
+ # turn off symbolic traversal at inference time
+ pred = model(graph, query, symbolic_traversal=model.training)
+ if not model.training:
+ # eval
+ target = (type, easy_answer, hard_answer)
+ restrict_nodes = getattr(graph, "restrict_nodes", None)
+ ranking, answer_ranking = batch_evaluate(pred, target, restrict_nodes)
+ # answer set cardinality prediction
+ prob = F.sigmoid(pred)
+ num_pred = (prob * (prob > 0.5)).sum(dim=-1)
+ num_easy = easy_answer.sum(dim=-1)
+ num_hard = hard_answer.sum(dim=-1)
+ return (ranking, num_pred), (type, answer_ranking, num_easy, num_hard)
+ else:
+ target = easy_answer.float()
+
+ return pred, target
+
+def train_and_validate(cfg, model, train_graph, train_data, valid_graph, valid_data, query_id2type, device, logger, batch_per_epoch=None):
+ if cfg.train.num_epoch == 0:
+ return
+
+ world_size = util.get_world_size()
+ rank = util.get_rank()
+
+ sampler = torch_data.DistributedSampler(train_data, world_size, rank)
+ train_loader = torch_data.DataLoader(train_data, cfg.train.batch_size, sampler=sampler)
+
+ batch_per_epoch = batch_per_epoch or len(train_loader)
+
+ cls = cfg.optimizer.pop("class")
+ optimizer = getattr(optim, cls)(model.parameters(), **cfg.optimizer)
+ num_params = sum(p.numel() for p in model.parameters())
+ logger.warning(line)
+ logger.warning(f"Number of parameters: {num_params}")
+
+ if world_size > 1:
+ parallel_model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
+ else:
+ parallel_model = model
+
+ step = math.ceil(cfg.train.num_epoch / 10)
+ best_result = float("-inf")
+ best_epoch = -1
+
+ batch_id = 0
+ for i in range(0, cfg.train.num_epoch, step):
+ parallel_model.train()
+ for epoch in range(i, min(cfg.train.num_epoch, i + step)):
+ if util.get_rank() == 0:
+ logger.warning(separator)
+ logger.warning("Epoch %d begin" % epoch)
+
+ losses = []
+ sampler.set_epoch(epoch)
+ for batch in islice(train_loader, batch_per_epoch):
+ if device.type == "cuda":
+ train_graph = train_graph.to(device)
+ batch = query_utils.cuda(batch, device=device)
+ pred, target = predict_and_target(parallel_model, train_graph, batch)
+
+ loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
+
+ is_positive = target > 0.5
+ is_negative = target <= 0.5
+ num_positive = is_positive.sum(dim=-1)
+ num_negative = is_negative.sum(dim=-1)
+
+ neg_weight = torch.zeros_like(pred)
+ neg_weight[is_positive] = (1 / num_positive.float()).repeat_interleave(num_positive)
+
+ if cfg.task.adversarial_temperature > 0:
+ with torch.no_grad():
+ logit = pred[is_negative] / cfg.task.adversarial_temperature
+ neg_weight[is_negative] = variadic_softmax(logit, num_negative)
+ #neg_weight[:, 1:] = F.softmax(pred[:, 1:] / cfg.task.adversarial_temperature, dim=-1)
+ else:
+ neg_weight[is_negative] = (1 / num_negative.float()).repeat_interleave(num_negative)
+ loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1)
+ loss = loss.mean()
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if util.get_rank() == 0 and batch_id % cfg.train.log_interval == 0:
+ logger.warning(separator)
+ logger.warning("binary cross entropy: %g" % loss)
+ losses.append(loss.item())
+ batch_id += 1
+
+ if util.get_rank() == 0:
+ avg_loss = sum(losses) / len(losses)
+ logger.warning(separator)
+ logger.warning("Epoch %d end" % epoch)
+ logger.warning(line)
+ logger.warning("average binary cross entropy: %g" % avg_loss)
+
+ epoch = min(cfg.train.num_epoch, i + step)
+ if rank == 0:
+ logger.warning("Save checkpoint to model_epoch_%d.pth" % epoch)
+ state = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict()
+ }
+ torch.save(state, "model_epoch_%d.pth" % epoch)
+ util.synchronize()
+
+ if rank == 0:
+ logger.warning(separator)
+ logger.warning("Evaluate on valid")
+ result = test(cfg, model, valid_graph, valid_data, query_id2type=query_id2type, device=device, logger=logger)
+ if result['mrr'] > best_result:
+ best_result = result['mrr']
+ best_epoch = epoch
+
+ if rank == 0:
+ logger.warning("Load checkpoint from model_epoch_%d.pth" % best_epoch)
+ state = torch.load("model_epoch_%d.pth" % best_epoch, map_location=device)
+ model.load_state_dict(state["model"])
+ util.synchronize()
+
+
+@torch.no_grad()
+def test(cfg, model, test_graph, test_data, query_id2type, device, logger, return_metrics=False):
+ world_size = util.get_world_size()
+ rank = util.get_rank()
+
+ sampler = torch_data.DistributedSampler(test_data, world_size, rank)
+ test_loader = torch_data.DataLoader(test_data, cfg.train.batch_size, sampler=sampler)
+
+ model.eval()
+ preds, targets = [], []
+ for batch in tqdm(test_loader):
+ if device.type == "cuda":
+ test_graph = test_graph.to(device)
+ batch = query_utils.cuda(batch, device=device)
+
+ predictions, target = predict_and_target(model, test_graph, batch)
+ preds.append(predictions)
+ targets.append(target)
+
+ pred = query_utils.cat(preds)
+ target = query_utils.cat(targets)
+
+ pred, target = gather_results(pred, target, rank, world_size, device)
+
+ metrics = {}
+ if rank == 0:
+ metrics = evaluate(pred, target, cfg.task.metric, query_id2type)
+ query_utils.print_metrics(metrics, logger)
+ else:
+ metrics['mrr'] = (1 / pred[0].float()).mean().item()
+ util.synchronize()
+ return metrics
+
+
+if __name__ == "__main__":
+ args, vars = util.parse_args()
+ cfg = util.load_config(args.config, context=vars)
+ working_dir = util.create_working_directory(cfg)
+
+ torch.manual_seed(args.seed + util.get_rank())
+
+ logger = util.get_root_logger()
+ if util.get_rank() == 0:
+ logger.warning("Random seed: %d" % args.seed)
+ logger.warning("Config file: %s" % args.config)
+ logger.warning(pprint.pformat(cfg))
+
+ task_name = cfg.task["name"]
+ dataset = query_utils.build_query_dataset(cfg)
+ device = util.get_device(cfg)
+ path = os.path.dirname(os.path.expanduser(__file__))
+ results_file = os.path.join(path, f"ultraquery_results_{time.strftime('%Y-%m-%d-%H-%M-%S')}.csv")
+
+ train_data, valid_data, test_data = dataset.split()
+ train_graph, valid_graph, test_graph = dataset.train_graph, dataset.valid_graph, dataset.test_graph
+
+ model = UltraQuery(
+ model=Ultra(
+ rel_model_cfg=cfg.model.model.relation_model,
+ entity_model_cfg=cfg.model.model.entity_model,
+ ),
+ logic=cfg.model.logic,
+ dropout_ratio=cfg.model.dropout_ratio,
+ threshold=cfg.model.threshold,
+ more_dropout=cfg.model.get('more_dropout', 0.0),
+ )
+
+ # initialize with pre-trained ultra for link prediction
+ if "ultra_ckpt" in cfg and cfg.ultra_ckpt is not None:
+ state = torch.load(cfg.ultra_ckpt, map_location="cpu")
+ model.model.model.load_state_dict(state["model"])
+
+ # initialize with a pre-trained ultraquery model for query answering
+ if "ultraquery_ckpt" in cfg and cfg.ultraquery_ckpt is not None:
+ state = torch.load(cfg.ultraquery_ckpt, map_location="cpu")
+ model.load_state_dict(state["model"])
+
+ if "fast_test" in cfg.train:
+ if util.get_rank() == 0:
+ logger.warning("Quick test mode on. Only evaluate on %d samples for valid" % cfg.train.fast_test)
+ g = torch.Generator()
+ g.manual_seed(1024)
+ valid_data = torch_data.random_split(valid_data, [cfg.train.fast_test, len(valid_data) - cfg.train.fast_test], generator=g)[0]
+
+
+ model = model.to(device)
+
+ train_and_validate(cfg, model, train_graph, train_data, valid_graph, valid_data, query_id2type=dataset.id2type, device=device, batch_per_epoch=cfg.train.batch_per_epoch, logger=logger)
+ if util.get_rank() == 0:
+ logger.warning(separator)
+ logger.warning("Evaluate on valid")
+ start = timer()
+ val_metrics = test(cfg, model, valid_graph, valid_data, query_id2type=dataset.id2type, device=device, logger=logger)
+ end = timer()
+ # write to the log file
+ # val_metrics['dataset'] = str(dataset)
+ # util.print_metrics_to_file(val_metrics, results_file)
+ logger.warning(f"Valid time: {end - start}")
+ if util.get_rank() == 0:
+ logger.warning(separator)
+ logger.warning("Evaluate on test")
+ metrics = test(cfg, model, test_graph, test_data, query_id2type=dataset.id2type, device=device, logger=logger)
+
+ # write to the log file
+ if util.get_rank() == 0:
+ metrics['dataset'] = str(dataset)
+ query_utils.print_metrics_to_file(metrics, results_file)
diff --git a/script/run_query_many.py b/script/run_query_many.py
new file mode 100644
index 0000000..677eb0b
--- /dev/null
+++ b/script/run_query_many.py
@@ -0,0 +1,141 @@
+import os
+import sys
+import time
+import random
+import pprint
+import argparse
+
+import torch
+import torch_geometric as pyg
+from torch.utils import data as torch_data
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from ultra import datasets_query, tasks, util, query_utils
+from ultra.models import Ultra
+from ultra.ultraquery import UltraQuery
+from timeit import default_timer as timer
+from script.run_query import train_and_validate, test
+
+separator = ">" * 30
+line = "-" * 30
+
+def set_seed(seed):
+ random.seed(seed + util.get_rank())
+ # np.random.seed(seed + util.get_rank())
+ torch.manual_seed(seed + util.get_rank())
+ torch.cuda.manual_seed(seed + util.get_rank())
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+if __name__ == "__main__":
+ seeds = [1024, 42, 1337, 512, 256]
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-c", "--config", help="yaml configuration file", required=True)
+ parser.add_argument("-d", "--datasets", help="target datasets", default='InductiveFB15k237Query:550,InductiveFB15k237Query:300', type=str, required=True)
+ parser.add_argument("-reps", "--repeats", help="number of times to repeat each exp", default=1, type=int)
+ parser.add_argument("-ft", "--finetune", help="finetune the checkpoint on the specified datasets", action='store_true')
+ args, unparsed = parser.parse_known_args()
+
+ datasets = args.datasets.split(",")
+ path = os.path.dirname(os.path.expanduser(__file__))
+ results_file = os.path.join(path, f"ultraquery_results_{time.strftime('%Y-%m-%d-%H-%M-%S')}.csv")
+
+ for graph in datasets:
+ ds, version = graph.split(":") if ":" in graph else (graph, None)
+ for i in range(args.repeats):
+ seed = seeds[i] if i < len(seeds) else random.randint(0, 10000)
+ print(f"Running on {graph}, iteration {i+1} / {args.repeats}, seed: {seed}")
+
+ # get dynamic arguments defined in the config file
+ vars = util.detect_variables(args.config)
+ parser = argparse.ArgumentParser()
+ for var in vars:
+ parser.add_argument("--%s" % var)
+ vars = parser.parse_known_args(unparsed)[0]
+ vars = {k: util.literal_eval(v) for k, v in vars._get_kwargs()}
+
+ if args.finetune:
+ epochs, batch_per_epoch = 1, 1000
+ else:
+ epochs, batch_per_epoch = 0, 'null'
+ vars['epochs'] = epochs
+ vars['bpe'] = batch_per_epoch
+ vars['dataset'] = ds
+ if version is not None:
+ vars['version'] = version
+
+ #args, vars = util.parse_args()
+ cfg = util.load_config(args.config, context=vars)
+ root_dir = os.path.expanduser(cfg.output_dir) # resetting the path to avoid inf nesting
+ os.chdir(root_dir)
+ working_dir = util.create_working_directory(cfg)
+ set_seed(seed)
+
+ logger = util.get_root_logger()
+ if util.get_rank() == 0:
+ logger.warning("Random seed: %d" % seed)
+ logger.warning("Config file: %s" % args.config)
+ logger.warning(pprint.pformat(cfg))
+
+ task_name = cfg.task["name"]
+ dataset = query_utils.build_query_dataset(cfg)
+ device = util.get_device(cfg)
+
+
+ train_data, valid_data, test_data = dataset.split()
+ train_graph, valid_graph, test_graph = dataset.train_graph, dataset.valid_graph, dataset.test_graph
+
+ model = UltraQuery(
+ model=Ultra(
+ rel_model_cfg=cfg.model.model.relation_model,
+ entity_model_cfg=cfg.model.model.entity_model,
+ ),
+ logic=cfg.model.logic,
+ dropout_ratio=cfg.model.dropout_ratio,
+ threshold=cfg.model.threshold,
+ more_dropout=cfg.model.get('more_dropout', 0.0),
+ )
+
+ # initialize with pre-trained ultra for link prediction
+ if "ultra_ckpt" in cfg and cfg.ultra_ckpt is not None:
+ state = torch.load(cfg.ultra_ckpt, map_location="cpu")
+ model.model.model.load_state_dict(state["model"])
+
+ # initialize with a pre-trained ultraquery model for query answering
+ if "ultraquery_ckpt" in cfg and cfg.ultraquery_ckpt is not None:
+ state = torch.load(cfg.ultraquery_ckpt, map_location="cpu")
+ model.load_state_dict(state["model"])
+
+ if "fast_test" in cfg.train:
+ if util.get_rank() == 0:
+ logger.warning("Quick test mode on. Only evaluate on %d samples for valid" % cfg.train.fast_test)
+ g = torch.Generator()
+ g.manual_seed(1024)
+ valid_data = torch_data.random_split(valid_data, [cfg.train.fast_test, len(valid_data) - cfg.train.fast_test], generator=g)[0]
+
+
+ #model = pyg.compile(model, dynamic=True)
+ model = model.to(device)
+
+ train_and_validate(cfg, model, train_graph, train_data, valid_graph, valid_data, query_id2type=dataset.id2type, device=device, batch_per_epoch=cfg.train.batch_per_epoch, logger=logger)
+ if util.get_rank() == 0:
+ logger.warning(separator)
+ logger.warning("Evaluate on valid")
+ start = timer()
+ val_metrics = test(cfg, model, valid_graph, valid_data, query_id2type=dataset.id2type, device=device, logger=logger)
+ end = timer()
+ # write to the log file
+ # val_metrics['dataset'] = str(dataset)
+ # util.print_metrics_to_file(val_metrics, results_file)
+ logger.warning(f"Valid time: {end - start}")
+ if util.get_rank() == 0:
+ logger.warning(separator)
+ logger.warning("Evaluate on test")
+ metrics = test(cfg, model, test_graph, test_data, query_id2type=dataset.id2type, device=device, logger=logger)
+
+ # write to the log file
+ if util.get_rank() == 0:
+ metrics['dataset'] = str(dataset)
+ query_utils.print_metrics_to_file(metrics, results_file)
diff --git a/ultra/datasets_query.py b/ultra/datasets_query.py
new file mode 100644
index 0000000..659e6c3
--- /dev/null
+++ b/ultra/datasets_query.py
@@ -0,0 +1,711 @@
+import os
+import pickle
+from collections import defaultdict
+from tqdm import tqdm
+
+import torch
+import numpy as np
+from torch.nn import functional as F
+from torch.utils import data as torch_data
+from functools import partial
+
+from torch_scatter import scatter_add
+from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
+
+from ultra.query_utils import Query
+from ultra.tasks import build_relation_graph
+from ultra.base_nbfnet import index_to_mask
+
+
+class LogicalQueryDataset(InMemoryDataset):
+ """Logical query dataset."""
+
+ struct2type = {
+ ("e", ("r",)): "1p",
+ ("e", ("r", "r")): "2p",
+ ("e", ("r", "r", "r")): "3p",
+ (("e", ("r",)), ("e", ("r",))): "2i",
+ (("e", ("r",)), ("e", ("r",)), ("e", ("r",))): "3i",
+ ((("e", ("r",)), ("e", ("r",))), ("r",)): "ip",
+ (("e", ("r", "r")), ("e", ("r",))): "pi",
+ (("e", ("r",)), ("e", ("r", "n"))): "2in",
+ (("e", ("r",)), ("e", ("r",)), ("e", ("r", "n"))): "3in",
+ ((("e", ("r",)), ("e", ("r", "n"))), ("r",)): "inp",
+ (("e", ("r", "r")), ("e", ("r", "n"))): "pin",
+ (("e", ("r", "r", "n")), ("e", ("r",))): "pni",
+ (("e", ("r",)), ("e", ("r",)), ("u",)): "2u-DNF",
+ ((("e", ("r",)), ("e", ("r",)), ("u",)), ("r",)): "up-DNF",
+ ((("e", ("r", "n")), ("e", ("r", "n"))), ("n",)): "2u-DM",
+ ((("e", ("r", "n")), ("e", ("r", "n"))), ("n", "r")): "up-DM",
+ }
+
+ def __init__(self, root, transform=None, pre_transform=build_relation_graph,
+ query_types=None, union_type="DNF", train_patterns = None, **kwargs):
+
+ self.query_types = query_types
+ self.union_type = union_type
+ self.train_patterns = train_patterns
+ super().__init__(root, transform, pre_transform)
+ # self.data, self.slices = torch.load(self.processed_paths[0])
+
+ @property
+ def raw_file_names(self):
+ return ["train.txt", "valid.txt", "test.txt"]
+
+ def download(self):
+ download_path = download_url(self.url, self.root)
+ extract_zip(download_path, self.root)
+ # os.unlink(download_path)
+
+ def set_query_types(self):
+ query_types = self.query_types or self.struct2type.values()
+ new_query_types = []
+ for query_type in query_types:
+ if "u" in query_type:
+ if "-" not in query_type:
+ query_type = "%s-%s" % (query_type, self.union_type)
+ elif query_type[query_type.find("-") + 1:] != self.union_type:
+ continue
+ new_query_types.append(query_type)
+ self.id2type = sorted(new_query_types)
+ self.type2id = {t: i for i, t in enumerate(self.id2type)}
+
+ def process(self):
+ """
+ Load the dataset from pickle dumps (BetaE format).
+
+ Parameters:
+ path (str): path to pickle dumps
+ query_types (list of str, optional): query types to load.
+ By default, load all query types.
+ union_type (str, optional): which union type to use, ``DNF`` or ``DM``
+ verbose (int, optional): output verbose level
+ """
+ self.set_query_types()
+ path = self.raw_dir
+
+ with open(os.path.join(path, "id2ent.pkl"), "rb") as fin:
+ entity_vocab = pickle.load(fin)
+ with open(os.path.join(path, "id2rel.pkl"), "rb") as fin:
+ relation_vocab = pickle.load(fin)
+ triplets = []
+ num_samples = []
+ for split in ["train", "valid", "test"]:
+ triplet_file = os.path.join(path, "%s.txt" % split)
+ with open(triplet_file) as fin:
+ num_sample = 0
+ for line in fin:
+ h, r, t = [int(x) for x in line.split()]
+ triplets.append((h, t, r))
+ num_sample += 1
+ num_samples.append(num_sample)
+
+ train_edges = torch.tensor([[t[0], t[1]] for t in triplets[:num_samples[0]]], dtype=torch.long).t()
+ train_edge_types = torch.tensor([t[2] for t in triplets[:num_samples[0]]], dtype=torch.long)
+
+ # The 'inverse_rel_plus_one' property is needed for traversal dropout to determine the way of deriving inverse edges
+ # In BetaE datasets, inv_rel = direct_rel + 1, but in inducitve datasets it is inv_rel = direct_rel + num_relations
+ self.train_graph = Data(edge_index=train_edges, edge_type=train_edge_types,
+ num_nodes=len(entity_vocab), num_relations=len(relation_vocab), inverse_rel_plus_one=True)
+ self.valid_graph = self.train_graph
+ self.test_graph = self.train_graph
+
+ if self.pre_transform is not None:
+ self.train_graph = self.pre_transform(self.train_graph)
+
+ # loading queries
+ queries = []
+ types = []
+ easy_answers = []
+ hard_answers = []
+ num_samples = []
+ max_query_length = 0
+
+ for split in ["train", "valid", "test"]:
+
+ pbar = tqdm(desc="Loading %s-*.pkl" % split, total=3)
+ with open(os.path.join(path, "%s-queries.pkl" % split), "rb") as fin:
+ struct2queries = pickle.load(fin)
+ pbar.update(1)
+ type2queries = {self.struct2type[k]: v for k, v in struct2queries.items()}
+ type2queries = {k: v for k, v in type2queries.items() if k in self.type2id}
+ if split == "train":
+ with open(os.path.join(path, "%s-answers.pkl" % split), "rb") as fin:
+ query2easy_answers = pickle.load(fin)
+ query2hard_answers = defaultdict(set)
+ pbar.update(2)
+ else:
+ with open(os.path.join(path, "%s-easy-answers.pkl" % split), "rb") as fin:
+ query2easy_answers = pickle.load(fin)
+ pbar.update(1)
+ with open(os.path.join(path, "%s-hard-answers.pkl" % split), "rb") as fin:
+ query2hard_answers = pickle.load(fin)
+ pbar.update(1)
+
+ num_sample = sum([len(q) for t, q in type2queries.items()])
+ pbar = tqdm(desc="Processing %s queries" % split, total=num_sample)
+ for type in type2queries:
+ struct_queries = sorted(type2queries[type])
+ for query in struct_queries:
+ easy_answers.append(query2easy_answers[query])
+ hard_answers.append(query2hard_answers[query])
+ query = Query.from_nested(query)
+ queries.append(query)
+ max_query_length = max(max_query_length, len(query))
+ types.append(self.type2id[type])
+ pbar.update(1)
+ num_samples.append(num_sample)
+
+ self.queries = queries
+ self.types = types
+ self.easy_answers = easy_answers
+ self.hard_answers = hard_answers
+ self.num_samples = num_samples
+ self.max_query_length = max_query_length
+
+ def __getitem__(self, index):
+ query = self.queries[index]
+ easy_answer = torch.tensor(list(self.easy_answers[index]), dtype=torch.long)
+ hard_answer = torch.tensor(list(self.hard_answers[index]), dtype=torch.long)
+ return {
+ "query": F.pad(query, (0, self.max_query_length - len(query)), value=query.stop),
+ "type": self.types[index],
+ "easy_answer": index_to_mask(easy_answer, self.train_graph.num_nodes),
+ "hard_answer": index_to_mask(hard_answer, self.train_graph.num_nodes),
+ }
+
+ def __len__(self):
+ return len(self.queries)
+
+ def split(self):
+ offset = 0
+ splits = []
+ for num_sample in self.num_samples:
+ split = torch_data.Subset(self, range(offset, offset + num_sample))
+ splits.append(split)
+ offset += num_sample
+ return splits
+
+ def __repr__(self):
+ return "%s()" % (self.name)
+
+ @property
+ def num_relations(self):
+ return int(self.train_graph.num_relations)
+
+ @property
+ def raw_dir(self):
+ return os.path.join(self.root, self.name) # +raw
+
+ @property
+ def processed_dir(self):
+ return os.path.join(self.root, self.name) # + processed
+
+ @property
+ def processed_file_names(self):
+ return "data.pt"
+
+
+class FB15kLogicalQuery(LogicalQueryDataset):
+
+ name = "FB15k-betae"
+ url = "http://snap.stanford.edu/betae/KG_data.zip"
+ md5 = "d54f92e2e6a64d7f525b8fe366ab3f50"
+
+
+class FB15k237LogicalQuery(LogicalQueryDataset):
+
+ name = "FB15k-237-betae"
+ url = "http://snap.stanford.edu/betae/KG_data.zip"
+ md5 = "d54f92e2e6a64d7f525b8fe366ab3f50"
+
+
+class NELL995LogicalQuery(LogicalQueryDataset):
+
+ name = "NELL-betae"
+ url = "http://snap.stanford.edu/betae/KG_data.zip"
+ md5 = "d54f92e2e6a64d7f525b8fe366ab3f50"
+
+
+class InductiveFB15k237Query(LogicalQueryDataset):
+
+ url = "https://zenodo.org/record/7306046/files/%s.zip"
+
+ md5 = {
+ 550: "e78bb9a7de9bd55813bb17f57941303c",
+ 300: "4db5c172acf83f676c9cf6589e033d7e",
+ 217: "9fde4563c619dc4d2b81af200cf7bc6b",
+ 175: "29ee1dbed7662740a2f001a0c6df8911",
+ 150: "61b545de8e5cdb04832f27842d8c0175",
+ 134: "cd8028c9674dc81f38cd17b03af43fe1",
+ 122: "272d2cc1e3f98f76d02daaf066f9d653",
+ 113: "e4ea60448e918c62779cfa757a096aa9",
+ 106: "6f9a1dcf22108074fb94a05b8377a173",
+ "wikikg": "fa30b189436ab46a2ff083dd6a5e6e0b"
+ }
+
+ @property
+ def raw_file_names(self):
+ return [f"train_queries.pkl", "valid_queries.pkl", "test_queries.pkl"]
+
+ def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, query_types=None, union_type="DNF",
+ train_patterns=('1p', '2p', '3p', '2i', '3i', '2in', '3in', 'inp', 'pni', 'pin'), **kwargs):
+ self.version = version
+ super().__init__(root, transform, pre_transform, query_types=query_types, union_type=union_type, train_patterns=train_patterns)
+
+ def download(self):
+ download_path = download_url(self.url % self.version, self.root)
+ extract_zip(download_path, self.root)
+ # os.unlink(download_path)
+
+ def process(self):
+
+ self.set_query_types()
+ path = self.raw_dir
+
+ # Space of entities 0 ... N is split into 3 sets
+ # Train node IDs: 0 ... K
+ # Val inference ids: K ... K+M
+ # Test inference ids: K+M .... N
+ try:
+ train_triplets = self.load_file(os.path.join(path, "train_graph.txt"))
+ val_inference = self.load_file(os.path.join(path, "val_inference.txt"))
+ test_inference = self.load_file(os.path.join(path, "test_inference.txt"))
+ except FileNotFoundError:
+ print("Loading .pt files")
+ train_triplets = self.load_pt(os.path.join(path, "train_graph.pt"))
+ val_inference = self.load_pt(os.path.join(path, "val_inference.pt"))
+ test_inference = self.load_pt(os.path.join(path, "test_inference.pt"))
+
+ entity_vocab, relation_vocab, inv_ent_vocab, inv_rel_vocab, \
+ tr_nodes, vl_nodes, ts_nodes = self.build_vocab(train_triplets, val_inference, test_inference)
+
+ num_node = len(entity_vocab) if entity_vocab else None
+ num_relation = len(relation_vocab) if relation_vocab else None
+
+ # Training graph: only training triples
+ self.train_graph = Data(edge_index=torch.LongTensor(train_triplets)[:, :2].t(),
+ edge_type=torch.LongTensor(train_triplets)[:, 2],
+ num_nodes=len(tr_nodes), num_relations=num_relation)
+
+ # Validation graph: training triples (0..K) + new validation inference triples (K+1...K+M)
+ self.valid_graph = Data(edge_index=torch.LongTensor(train_triplets + val_inference)[:, :2].t(),
+ edge_type=torch.LongTensor(train_triplets + val_inference)[:, 2],
+ num_nodes=num_node, num_relations=num_relation,
+ restrict_nodes=torch.LongTensor(vl_nodes) # need those for evaluation
+ )
+
+ # Test graph: training triples (0..K) + new test inference triples (K+M+1... N)
+ self.test_graph = Data(edge_index=torch.LongTensor(train_triplets + test_inference)[:, :2].t(),
+ edge_type=torch.LongTensor(train_triplets + test_inference)[:, 2],
+ num_nodes=num_node, num_relations=num_relation,
+ restrict_nodes=torch.LongTensor(ts_nodes), # need those for evaluation
+ )
+
+ if self.pre_transform:
+ self.train_graph = self.pre_transform(self.train_graph)
+ self.valid_graph = self.pre_transform(self.valid_graph)
+ self.test_graph = self.pre_transform(self.test_graph)
+
+ # Full graph (aux purposes)
+ self.graph = Data(edge_index=torch.LongTensor(train_triplets + val_inference + test_inference)[:, :2].t(),
+ edge_type=torch.LongTensor(train_triplets + val_inference + test_inference)[:, 2],
+ num_nodes=num_node, num_relations=num_relation)
+ self.entity_vocab = entity_vocab
+ self.relation_vocab = relation_vocab
+ self.inv_entity_vocab = inv_ent_vocab
+ self.inv_relation_vocab = inv_rel_vocab
+
+ # Need those for evaluation
+ self.valid_nodes = torch.LongTensor(vl_nodes)
+ self.test_nodes = torch.LongTensor(ts_nodes)
+
+ self.load_queries(path=path)
+
+ def load_queries(self, path):
+
+ queries = []
+ type_ids = []
+ easy_answers = []
+ hard_answers = []
+ num_samples = []
+ num_entity_for_sample = []
+ max_query_length = 0
+
+ type2struct = {v: k for k, v in self.struct2type.items()}
+ filtered_training_structs = tuple([type2struct[x] for x in self.train_patterns])
+ for split in ["train", "valid", "test"]:
+ with open(os.path.join(path, "%s_queries.pkl" % split), "rb") as fin:
+ struct2queries = pickle.load(fin)
+ if split == "train":
+ query2hard_answers = defaultdict(lambda: defaultdict(set))
+ with open(os.path.join(path, "%s_answers_hard.pkl" % split), "rb") as fin:
+ query2easy_answers = pickle.load(fin)
+ else:
+ with open(os.path.join(path, "%s_answers_easy.pkl" % split), "rb") as fin:
+ query2easy_answers = pickle.load(fin)
+ with open(os.path.join(path, "%s_answers_hard.pkl" % split), "rb") as fin:
+ query2hard_answers = pickle.load(fin)
+ num_sample = 0
+ structs = sorted(struct2queries.keys(), key=lambda s: self.struct2type[s])
+ structs = tqdm(structs, "Loading %s queries" % split)
+ for struct in structs:
+ query_type = self.struct2type[struct]
+ if query_type not in self.type2id:
+ continue
+ # filter complex patterns ip, pi, 2u, up from training queries - those will be eval only
+ if split == "train" and struct not in filtered_training_structs:
+ print(f"Skipping {query_type} - this will be used in evaluation")
+ continue
+ struct_queries = sorted(struct2queries[struct])
+ for query in struct_queries:
+ # The dataset format is slightly different from BetaE's
+ easy_answers.append(query2easy_answers[struct][query])
+ hard_answers.append(query2hard_answers[struct][query])
+ query = Query.from_nested(query)
+ #query = self.to_postfix_notation(query)
+ max_query_length = max(max_query_length, len(query))
+ queries.append(query)
+ type_ids.append(self.type2id[query_type])
+ num_sample += len(struct_queries)
+ num_entity_for_sample += [getattr(self, "%s_graph" % split).num_nodes] * num_sample
+ num_samples.append(num_sample)
+
+ self.queries = queries
+ self.types = type_ids
+ self.easy_answers = easy_answers
+ self.hard_answers = hard_answers
+ self.num_samples = num_samples
+ self.num_entity_for_sample = num_entity_for_sample
+ self.max_query_length = max_query_length
+
+ def load_file(self, path):
+ triplets = []
+ with open(path) as fin:
+ for line in fin:
+ h, r, t = [int(x) for x in line.split()]
+ triplets.append((h, t, r))
+
+ return triplets
+
+ def load_pt(self, path):
+ triplets = torch.load(path, map_location="cpu")
+ return triplets[:, [0, 2, 1]].tolist()
+
+ def build_vocab(self, train_triples, val_triples, test_triples):
+ # datasets are already shipped with contiguous node IDs from 0 to N, so the total num ents is N+1
+ all_triples = np.array(train_triples+val_triples+test_triples)
+ train_nodes = np.unique(np.array(train_triples)[:, [0, 1]])
+ val_nodes = np.unique(np.array(train_triples + val_triples)[:, [0, 1]])
+ test_nodes = np.unique(np.array(train_triples + test_triples)[:, [0, 1]])
+ num_entities = np.max(all_triples[:, [0, 1]]) + 1
+ num_relations = np.max(all_triples[:, 2]) + 1
+
+ ent_vocab = {i: i for i in range(num_entities)}
+ rel_vocab = {i: i for i in range(num_relations)}
+ inv_ent_vocab = {v:k for k,v in ent_vocab.items()}
+ inv_rel_vocab = {v:k for k,v in rel_vocab.items()}
+
+ return ent_vocab, rel_vocab, inv_ent_vocab, inv_rel_vocab, train_nodes, val_nodes, test_nodes
+
+ def __getitem__(self, index):
+ query = self.queries[index]
+ easy_answer = torch.tensor(list(self.easy_answers[index]), dtype=torch.long)
+ hard_answer = torch.tensor(list(self.hard_answers[index]), dtype=torch.long)
+ # num_entity in the inductive setup is different for different splits, take it from the relevant graph
+ num_entity = self.num_entity_for_sample[index]
+ return {
+ "query": F.pad(query, (0, self.max_query_length - len(query)), value=query.stop),
+ "type": self.types[index],
+ "easy_answer": index_to_mask(easy_answer, num_entity),
+ "hard_answer": index_to_mask(hard_answer, num_entity),
+ }
+
+ @property
+ def name(self):
+ return f"{self.version}"
+
+ def __repr__(self):
+ return f"fb_{self.version}"
+
+
+class WikiTopicsQuery(InductiveFB15k237Query):
+
+ url = "https://reltrans.s3.us-east-2.amazonaws.com/WikiTopics_QE.zip"
+ md5 = None
+
+ @property
+ def raw_file_names(self):
+ return [f"train_queries.pkl", "valid_queries.pkl", "test_queries.pkl"]
+
+ def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, query_types=None, union_type="DNF",
+ train_patterns=('1p', '2p', '3p', '2i', '3i', '2in', '3in', 'inp', 'pni', 'pin'), **kwargs):
+ #self.version = version
+ super().__init__(root, version, transform, pre_transform, query_types=query_types, union_type=union_type, train_patterns=train_patterns)
+
+ def download(self):
+ download_path = download_url(self.url, self.root)
+ extract_zip(download_path, self.root)
+ # os.unlink(download_path)
+
+ def process(self):
+
+ self.set_query_types()
+ # Download data if it's not there -> Add wt prefix?
+ path = self.raw_dir
+
+ # WikiTopics are standard inductive datasets: train/valid graph is separated from the test
+ # Space of entities 0 ... N is split into 3 sets
+ # Train node IDs: 0 ... K
+ # Val inference ids: 0 ..., K
+ # Test inference ids: O ..., M
+ try:
+ train_triplets = self.load_file(os.path.join(path, "train_graph.txt"))
+ val_inference = self.load_file(os.path.join(path, "val_inference.txt"))
+ test_inference = self.load_file(os.path.join(path, "test_inference.txt"))
+ except FileNotFoundError:
+ print("Loading .pt files")
+ train_triplets = self.load_pt(os.path.join(path, "train_graph.pt"))
+ val_inference = self.load_pt(os.path.join(path, "val_inference.pt"))
+ test_inference = self.load_pt(os.path.join(path, "test_inference.pt"))
+
+ train_entity_vocab, train_rel_vocab, test_ent_vocab, test_rel_vocab, \
+ tr_nodes, vl_nodes, ts_nodes = self.build_vocab(train_triplets, val_inference, test_inference)
+
+ # Training graph: only training triples
+ self.train_graph = Data(edge_index=torch.LongTensor(train_triplets)[:, :2].t(),
+ edge_type=torch.LongTensor(train_triplets)[:, 2],
+ num_nodes=len(tr_nodes), num_relations=len(train_rel_vocab))
+
+ # Validation graph: the same as training
+ self.valid_graph = Data(edge_index=torch.LongTensor(train_triplets)[:, :2].t(),
+ edge_type=torch.LongTensor(train_triplets)[:, 2],
+ num_nodes=len(tr_nodes), num_relations=len(train_rel_vocab))
+
+ # Test graph: a new graph with new entities/relations
+ self.test_graph = Data(edge_index=torch.LongTensor(test_inference)[:, :2].t(),
+ edge_type=torch.LongTensor(test_inference)[:, 2],
+ num_nodes=len(ts_nodes), num_relations=len(test_rel_vocab))
+ if self.pre_transform:
+ self.train_graph = self.pre_transform(self.train_graph)
+ self.valid_graph = self.pre_transform(self.valid_graph)
+ self.test_graph = self.pre_transform(self.test_graph)
+
+ # dummy graph for compatibility purposes
+ self.graph = self.test_graph
+
+ # Full graph (aux purposes)
+ # self.full_graph_valid = data.Graph(train_triplets + val_inference + test_inference,
+ # num_node=num_node, num_relation=num_relation)
+ self.train_entity_vocab, self.inv_train_ent_vocab = train_entity_vocab, {v:k for k,v in train_entity_vocab.items()}
+ self.train_relation_vocab, self.inv_train_rel_vocab = train_rel_vocab, {v:k for k,v in train_rel_vocab.items()}
+ self.test_entity_vocab, self.inv_test_ent_vocab = test_ent_vocab, {v:k for k,v in test_ent_vocab.items()}
+ self.test_relation_vocab, self.inv_test_rel_vocab = test_rel_vocab, {v:k for k,v in test_rel_vocab.items()}
+
+ # Need those for evaluation
+ self.valid_nodes = torch.tensor(vl_nodes, dtype=torch.long)
+ self.test_nodes = torch.tensor(ts_nodes, dtype=torch.long)
+
+ self.load_queries(path)
+
+ def build_vocab(self, train_triples, val_triples, test_triples):
+ # In WikiTopics, validation graph is the same as train, but test is different
+ train_triples, test_triples = np.array(train_triples), np.array(test_triples)
+ train_nodes = np.unique(train_triples[:, [0, 1]])
+ #val_nodes = np.unique(np.array(train_triples + val_triples)[:, [0, 1]])
+ test_nodes = np.unique(test_triples[:, [0, 1]])
+
+ num_train_entities = len(train_nodes)
+ num_test_entities = len(test_nodes)
+ num_train_relations = np.max(train_triples[:, 2]) + 1
+ num_test_relations = np.max(test_triples[:, 2]) + 1
+
+ train_ent_vocab = {i: i for i in range(num_train_entities)}
+ train_rel_vocab = {i: i for i in range(num_train_relations)}
+ test_ent_vocab = {i: i for i in range(num_test_entities)}
+ test_rel_vocab = {i: i for i in range(num_test_relations)}
+
+ return train_ent_vocab, train_rel_vocab, test_ent_vocab, test_rel_vocab, train_nodes, train_nodes, test_nodes
+
+ def __getitem__(self, index):
+ query = self.queries[index]
+ easy_answer = torch.tensor(list(self.easy_answers[index]), dtype=torch.long)
+ hard_answer = torch.tensor(list(self.hard_answers[index]), dtype=torch.long)
+ # num_entity in the inductive setup is different for different splits, take it from the relevant graph
+ num_entity = self.num_entity_for_sample[index]
+ return {
+ "query": F.pad(query, (0, self.max_query_length - len(query)), value=query.stop),
+ "type": self.types[index],
+ "easy_answer": index_to_mask(easy_answer, num_entity),
+ "hard_answer": index_to_mask(hard_answer, num_entity),
+ }
+
+ @property
+ def raw_dir(self):
+ return os.path.join(self.root, "WikiTopics_QE", self.name) # +raw
+
+ @property
+ def processed_dir(self):
+ return os.path.join(self.root, "WikiTopics_QE", self.name) # + processed
+
+ @property
+ def name(self):
+ return f"{self.version}"
+
+ def __repr__(self):
+ return f"wikitopics_{self.version}"
+
+
+class InductiveFB15k237QueryExtendedEval(InductiveFB15k237Query):
+
+ """
+ This dataset is almost equivalent to the original InductiveComp except that
+ validation and test sets are training queries with a new (possibly larger) answer set
+ being executed on a bigger validation or test graph
+
+ We will load only the train_queries file and 3 different answer sets:
+ 1. train_queries_hard - original answers
+ 2. train_queries_val - answers to train queries over the validation graph (train + new val nodes and edges)
+ 3. train_queries_test - answers to train queries over the test graph (train + new test nodes and edges)
+
+ The dataset is supposed to be used for evaluation/inference only,
+ so make sure num_epochs is set to 0 in the config yaml file
+ """
+
+ def load_queries(self, path):
+ easy_answers = []
+ hard_answers = []
+ queries = []
+ type_ids = []
+ num_samples = []
+ num_entity_for_sample = []
+ max_query_length = 0
+
+ # in this setup, we evaluate train queries on extended validation/test graphs
+ # in extended graphs, training queries now have more answers
+ # conceptually, all answers are "easy", but for eval purposes we load them as hard
+ with open(os.path.join(path, "train_queries.pkl"), "rb") as fin:
+ struct2queries = pickle.load(fin)
+
+ #type2struct = {v: k for k, v in self.struct2type.items()}
+ #filtered_training_structs = tuple([type2struct[x] for x in train_patterns])
+ for split in ["train", "valid", "test"]:
+ if split == "train":
+ with open(os.path.join(path, "train_answers_hard.pkl"), "rb") as fin:
+ query2hard_answers = pickle.load(fin)
+ else:
+ # load new answers
+ with open(os.path.join(path, "train_answers_%s.pkl" % split), "rb") as fin:
+ query2hard_answers = pickle.load(fin)
+
+ query2easy_answers = defaultdict(lambda: defaultdict(set))
+
+ num_sample = 0
+ structs = sorted(struct2queries.keys(), key=lambda s: self.struct2type[s])
+ structs = tqdm(structs, "Loading %s queries" % split)
+ for struct in structs:
+ query_type = self.struct2type[struct]
+ if query_type not in self.type2id:
+ continue
+
+ struct_queries = struct2queries[struct]
+ for i, query in enumerate(struct_queries):
+ # The dataset format is slightly different from BetaE's
+ #easy_answers.append(query2easy_answers[struct][i])
+ q_index = i if split != "train" else query
+ hard_answers.append(query2hard_answers[struct][q_index])
+ query = Query.from_nested(query)
+ max_query_length = max(max_query_length, len(query))
+ queries.append(query)
+ type_ids.append(self.type2id[query_type])
+ num_sample += len(struct_queries)
+
+ num_entity_for_sample += [getattr(self, "%s_graph" % split).num_nodes] * num_sample
+ num_samples.append(num_sample)
+
+ self.queries = queries
+ self.types = type_ids
+
+ self.hard_answers = hard_answers
+ self.easy_answers = [[] for _ in range(len(hard_answers))]
+ self.num_samples = num_samples
+ self.num_entity_for_sample = num_entity_for_sample
+ self.max_query_length = max_query_length
+
+
+
+class JointDataset(LogicalQueryDataset):
+
+ datasets_map = {
+ 'FB15k237': FB15k237LogicalQuery,
+ 'FB15k': FB15kLogicalQuery,
+ 'NELL995': NELL995LogicalQuery,
+ # TODO
+ 'FB_550': partial(InductiveFB15k237Query, version=550),
+ 'FB_300': partial(InductiveFB15k237Query, version=300),
+ 'FB_217': partial(InductiveFB15k237Query, version=217),
+ 'FB_175': partial(InductiveFB15k237Query, version=175),
+ 'FB_150': partial(InductiveFB15k237Query, version=150),
+ 'FB_134': partial(InductiveFB15k237Query, version=134),
+ 'FB_122': partial(InductiveFB15k237Query, version=122),
+ 'FB_113': partial(InductiveFB15k237Query, version=113),
+ 'FB_106': partial(InductiveFB15k237Query, version=106),
+ # WikiTopics
+ 'WT_art': partial(WikiTopicsQuery, version="art"),
+ 'WT_award': partial(WikiTopicsQuery, version="award"),
+ 'WT_edu': partial(WikiTopicsQuery, version="edu"),
+ 'WT_health': partial(WikiTopicsQuery, version="health"),
+ 'WT_infra': partial(WikiTopicsQuery, version="infra"),
+ 'WT_loc': partial(WikiTopicsQuery, version="loc"),
+ 'WT_org': partial(WikiTopicsQuery, version="org"),
+ 'WT_people': partial(WikiTopicsQuery, version="people"),
+ 'WT_sci': partial(WikiTopicsQuery, version="sci"),
+ 'WT_sport': partial(WikiTopicsQuery, version="sport"),
+ 'WT_tax': partial(WikiTopicsQuery, version="tax"),
+ }
+
+ def __init__(self, path, graphs, query_types=None, union_type="DNF"):
+ # super(JointDataset, self).__init__(*args, **kwargs)
+
+ # Initialize all specified datasets
+ self.graphs = [self.datasets_map[dataset](path=path, query_types=query_types, union_type=union_type) for dataset in graphs.split(',')]
+ self.graph_names = graphs
+
+ # Total number of samples obtained from iterating over all graphs
+ self.num_samples = [sum(k) for k in zip(*[graph.num_samples for graph in self.graphs])]
+ self.valid_samples = [torch.cumsum(torch.tensor(k).flatten(), dim=0) for k in zip([graph.num_samples for graph in self.graphs])]
+
+ def __getitem__(self, index):
+ # send a dummy entry, we'll be sampling edges in the collator function
+ return torch.zeros(1,1)
+
+ def __len__(self):
+ return sum([graph.queries for graph in self.graphs])
+
+ def split(self):
+ splits = [[],[],[]]
+ for graph in self.graphs:
+ offset = 0
+ for i, num_sample in enumerate(graph.num_samples):
+ split = torch_data.Subset(graph, range(offset, offset + num_sample))
+ splits[i].append(split)
+ offset += num_sample
+ return splits
+
+ @property
+ def num_nodes(self):
+ """Number of entities in the joint graph"""
+ return sum(graph.train_graph.num_nodes for graph in self.graphs)
+
+ @property
+ def num_edges(self):
+ """Number of edges in the joint graph"""
+ return sum(graph.train_graph.num_edges for graph in self.graphs)
+
+ @property
+ def num_relations(self):
+ """Number of relations in the joint graph"""
+ return sum(graph.train_graph.num_relations for graph in self.graphs)
+
+
+
diff --git a/ultra/layers.py b/ultra/layers.py
index 32dca7d..2de7677 100644
--- a/ultra/layers.py
+++ b/ultra/layers.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F
from torch_scatter import scatter
+import torch_geometric
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import degree
from typing import Tuple
@@ -104,7 +105,12 @@ def propagate(self, edge_index, size=None, **kwargs):
size = self._check_input(edge_index, size)
coll_dict = self._collect(self._fused_user_args, edge_index, size, kwargs)
- msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
+ # TODO: use from packaging.version import parse as parse_version as by default 2.4 > 2.14 which is wrong
+ # Let's collectively hope there will be PyG 3.0 after 2.9 and not 2.10
+ pyg_version = [int(i) for i in torch_geometric.__version__.split(".")]
+ col_fn = self.inspector.distribute if pyg_version[1] <= 4 else self.inspector.collect_param_data
+
+ msg_aggr_kwargs = col_fn("message_and_aggregate", coll_dict)
for hook in self._message_and_aggregate_forward_pre_hooks.values():
res = hook(self, (edge_index, msg_aggr_kwargs))
if res is not None:
@@ -115,7 +121,8 @@ def propagate(self, edge_index, size=None, **kwargs):
if res is not None:
out = res
- update_kwargs = self.inspector.distribute("update", coll_dict)
+ # PyG 2.5+ distribute -> collect_param_data
+ update_kwargs = col_fn("update", coll_dict)
out = self.update(out, **update_kwargs)
for hook in self._propagate_forward_hooks.values():
diff --git a/ultra/models.py b/ultra/models.py
index 525975f..8ddf63f 100644
--- a/ultra/models.py
+++ b/ultra/models.py
@@ -10,8 +10,9 @@ def __init__(self, rel_model_cfg, entity_model_cfg):
# kept that because super Ultra sounds cool
super(Ultra, self).__init__()
- self.relation_model = RelNBFNet(**rel_model_cfg)
- self.entity_model = EntityNBFNet(**entity_model_cfg)
+ # adding a bit more flexibility to initializing proper rel/ent classes from the configs
+ self.relation_model = globals()[rel_model_cfg.pop('class')](**rel_model_cfg)
+ self.entity_model = globals()[entity_model_cfg.pop('class')](**entity_model_cfg)
def forward(self, data, batch):
@@ -207,7 +208,71 @@ def forward(self, data, relation_representations, batch):
score = self.mlp(feature).squeeze(-1)
return score.view(shape)
+
+class QueryNBFNet(EntityNBFNet):
+ """
+ The entity-level reasoner for UltraQuery-like complex query answering pipelines
+ Almost the same as EntityNBFNet except that
+ (1) we already get the initial node features at the forward pass time
+ and don't have to read the triples batch
+ (2) we get `query` from the outer loop
+ (3) we return a distribution over all nodes (assuming t_index = all nodes)
+ """
+ def bellmanford(self, data, node_features, query, separate_grad=False):
+
+ size = (data.num_nodes, data.num_nodes)
+ edge_weight = torch.ones(data.num_edges, device=query.device)
+
+ hiddens = []
+ edge_weights = []
+ layer_input = node_features
+
+ for layer in self.layers:
+
+ # for visualization
+ if separate_grad:
+ edge_weight = edge_weight.clone().requires_grad_()
+
+ # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
+ hidden = layer(layer_input, query, node_features, data.edge_index, data.edge_type, size, edge_weight)
+ if self.short_cut and hidden.shape == layer_input.shape:
+ # residual connection here
+ hidden = hidden + layer_input
+ hiddens.append(hidden)
+ edge_weights.append(edge_weight)
+ layer_input = hidden
+
+ # original query (relation type) embeddings
+ node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
+ if self.concat_hidden:
+ output = torch.cat(hiddens + [node_query], dim=-1)
+ else:
+ output = torch.cat([hiddens[-1], node_query], dim=-1)
+
+ return {
+ "node_feature": output,
+ "edge_weights": edge_weights,
+ }
+
+ def forward(self, data, node_features, relation_representations, query):
+
+ # initialize relations in each NBFNet layer (with uinque projection internally)
+ for layer in self.layers:
+ layer.relation = relation_representations
+
+ # we already did traversal_dropout in the outer loop of UltraQuery
+ # if self.training:
+ # # Edge dropout in the training mode
+ # # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
+ # # to make NBFNet iteration learn non-trivial paths
+ # data = self.remove_easy_edges(data, h_index, t_index, r_index)
+
+ # node features arrive in shape (bs, num_nodes, dim)
+ # NBFNet needs batch size on the first place
+ output = self.bellmanford(data, node_features, query) # (num_nodes, batch_size, feature_dim)
+ score = self.mlp(output["node_feature"]).squeeze(-1) # (bs, num_nodes)
+ return score
diff --git a/ultra/query_utils.py b/ultra/query_utils.py
new file mode 100644
index 0000000..5c62951
--- /dev/null
+++ b/ultra/query_utils.py
@@ -0,0 +1,521 @@
+import csv
+import copy
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch import distributed as dist
+
+from torch_scatter import scatter_add, scatter_mean, scatter_max
+
+from ultra import variadic, datasets_query
+
+
+class Query(torch.Tensor):
+ """Tensor storage of logical queries in postfix notations."""
+
+ projection = 1 << 58
+ intersection = 1 << 59
+ union = 1 << 60
+ negation = 1 << 61
+ stop = 1 << 62
+ operation = projection | intersection | union | negation | stop
+
+ stack_size = 2
+
+ def __new__(cls, data, device=None):
+ query = torch.as_tensor(data, dtype=torch.long, device=device)
+ query = torch.Tensor._make_subclass(cls, query)
+ return query
+
+ @classmethod
+ def from_nested(cls, nested, binary_op=True):
+ """Construct a logical query from nested tuples (BetaE format)."""
+ if not binary_op:
+ raise ValueError("The current implementation doesn't support nary operations")
+ query = cls.nested_to_postfix(nested, binary_op=binary_op)
+ query.append(cls.stop)
+ return cls(query)
+
+ @classmethod
+ def nested_to_postfix(cls, nested, binary_op=True):
+ """Recursively convert nested tuples into a postfix notation."""
+ query = []
+
+ if len(nested) == 2 and isinstance(nested[-1][-1], int):
+ var, unary_ops = nested
+ if isinstance(var, tuple):
+ query += cls.nested_to_postfix(var, binary_op=binary_op)
+ else:
+ query.append(var)
+ for op in unary_ops:
+ if op == -2:
+ query.append(cls.negation)
+ else:
+ query.append(cls.projection | op)
+ else:
+ if len(nested[-1]) > 1:
+ vars, nary_op = nested, cls.intersection
+ else:
+ vars, nary_op = nested[:-1], cls.union
+ num_args = 2 if binary_op else len(vars)
+ op = nary_op | num_args
+ for i, var in enumerate(vars):
+ query += cls.nested_to_postfix(var)
+ if i + 1 >= num_args:
+ query.append(op)
+
+ return query
+
+ def to_readable(self):
+ """Convert this logical query to a human readable string."""
+ if self.ndim > 1:
+ raise ValueError("readable() can only be called for a single query")
+
+ num_variable = 0
+ stack = []
+ lines = []
+ for op in self:
+ if op.is_operand():
+ entity = op.get_operand().item()
+ stack.append(str(entity))
+ else:
+ var = chr(ord("A") + num_variable)
+ if op.is_projection():
+ relation = op.get_operand().item()
+ line = "%s <- projection_%d(%s)" % (var, relation, stack.pop())
+ elif op.is_intersection():
+ num_args = op.get_operand()
+ args = stack[-num_args:]
+ stack = stack[:-num_args]
+ line = "%s <- intersection(%s)" % (var, ", ".join(args))
+ elif op.is_union():
+ num_args = op.get_operand().item()
+ args = stack[-num_args:]
+ stack = stack[:-num_args]
+ line = "%s <- union(%s, %s)" % (var, ", ".join(args))
+ elif op.is_negation():
+ line = "%s <- negation(%s)" % (var, stack.pop())
+ elif op.is_stop():
+ break
+ else:
+ raise ValueError("Unknown operator `%d`" % op)
+ lines.append(line)
+ stack.append(var)
+ num_variable += 1
+
+ if len(stack) > 1:
+ raise ValueError("Invalid query. More operands than expected")
+ line = "\n".join(lines)
+ return line
+
+ def computation_graph(self):
+ """Get the computation graph of logical queries. Used for visualization."""
+ query = self.view(-1, self.shape[-1])
+ stack = Stack(len(query), self.stack_size, dtype=torch.long, device=query.device)
+ # pointer to the next operator that consumes the output of this operator
+ pointer = -torch.ones(query.shape, dtype=torch.long, device=query.device)
+ # depth of each operator in the computation graph
+ depth = -torch.ones(query.shape, dtype=torch.long, device=query.device)
+ # width of the substree covered by each operator
+ width = -torch.ones(query.shape, dtype=torch.long, device=query.device)
+
+ for i, op in enumerate(query.t()):
+ is_operand = op.is_operand()
+ is_unary = op.is_projection() | op.is_negation()
+ is_binary = op.is_intersection() | op.is_union()
+ is_stop = op.is_stop()
+ if is_operand.any():
+ stack.push(is_operand, i)
+ depth[is_operand, i] = 0
+ width[is_operand, i] = 1
+ if is_unary.any():
+ prev = stack.pop(is_unary)
+ pointer[is_unary, prev] = i
+ depth[is_unary, i] = depth[is_unary, prev] + 1
+ width[is_unary, i] = width[is_unary, prev]
+ stack.push(is_unary, i)
+ if is_binary.any():
+ prev_y = stack.pop(is_binary)
+ prev_x = stack.pop(is_binary)
+ pointer[is_binary, prev_y] = i
+ pointer[is_binary, prev_x] = i
+ depth[is_binary, i] = torch.max(depth[is_binary, prev_x], depth[is_binary, prev_y]) + 1
+ width[is_binary, i] = width[is_binary, prev_x] + width[is_binary, prev_y]
+ stack.push(is_binary, i)
+ if is_stop.all():
+ break
+
+ # each operator covers leaf nodes [left, right)
+ left = torch.where(depth > 0, 0, -1)
+ right = torch.where(depth > 0, width.max(), -1)
+ # backtrack to update left and right
+ for i in reversed(range(query.shape[-1])):
+ has_pointer = pointer[:, i] != -1
+ ptr = pointer[has_pointer, i]
+ depth[has_pointer, i] = depth[has_pointer, ptr] - 1
+ left[has_pointer, i] = left[has_pointer, ptr] + width[has_pointer, ptr] - width[has_pointer, i]
+ right[has_pointer, i] = left[has_pointer, i] + width[has_pointer, i]
+ width[has_pointer, ptr] -= width[has_pointer, i]
+
+ pointer = pointer.view_as(self)
+ depth = depth.view_as(self)
+ left = left.view_as(self)
+ right = right.view_as(self)
+ return pointer, depth, left, right
+
+ def is_operation(self):
+ return (self & self.operation > 0)
+
+ def is_operand(self):
+ return ~(self & self.operation > 0)
+
+ def is_projection(self):
+ return self & self.projection > 0
+
+ def is_intersection(self):
+ return self & self.intersection > 0
+
+ def is_union(self):
+ return self & self.union > 0
+
+ def is_negation(self):
+ return self & self.negation > 0
+
+ def is_stop(self):
+ return self & self.stop > 0
+
+ def get_operation(self):
+ return self & self.operation
+
+ def get_operand(self):
+ return self & ~self.operation
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield self[i]
+
+
+class Stack(object):
+ """
+ Batch of stacks implemented in PyTorch.
+
+ Parameters:
+ batch_size (int): batch size
+ stack_size (int): max stack size
+ shape (tuple of int, optional): shape of each element in the stack
+ dtype (torch.dtype): dtype
+ device (torch.device): device
+ """
+
+ def __init__(self, batch_size, stack_size, *shape, dtype=None, device=None):
+ self.stack = torch.zeros(batch_size, stack_size, *shape, dtype=dtype, device=device)
+ self.SP = torch.zeros(batch_size, dtype=torch.long, device=device)
+ self.batch_size = batch_size
+ self.stack_size = stack_size
+
+ def push(self, mask, value):
+ if (self.SP[mask] >= self.stack_size).any():
+ raise ValueError("Stack overflow")
+ self.stack[mask, self.SP[mask]] = value
+ self.SP[mask] += 1
+
+ def pop(self, mask=None):
+ if (self.SP[mask] < 1).any():
+ raise ValueError("Stack underflow")
+ if mask is None:
+ mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.stack.device)
+ self.SP[mask] -= 1
+ return self.stack[mask, self.SP[mask]]
+
+ def top(self, mask=None):
+ if (self.SP < 1).any():
+ raise ValueError("Stack is empty")
+ if mask is None:
+ mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.stack.device)
+ return self.stack[mask, self.SP[mask] - 1]
+
+
+def gather_results(pred, target, rank, world_size, device):
+ # for multi-gpu setups: join results together
+ # for single-gpu setups: doesn't do anything special
+ ranking, num_pred = pred
+ type, answer_ranking, num_easy, num_hard = target
+
+ all_size_r = torch.zeros(world_size, dtype=torch.long, device=device)
+ all_size_ar = torch.zeros(world_size, dtype=torch.long, device=device)
+ all_size_p = torch.zeros(world_size, dtype=torch.long, device=device)
+ all_size_r[rank] = len(ranking)
+ all_size_ar[rank] = len(answer_ranking)
+ all_size_p[rank] = len(num_pred)
+ if world_size > 1:
+ dist.all_reduce(all_size_r, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_size_ar, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_size_p, op=dist.ReduceOp.SUM)
+
+ # obtaining all ranks
+ cum_size_r = all_size_r.cumsum(0)
+ cum_size_ar = all_size_ar.cumsum(0)
+ cum_size_p = all_size_p.cumsum(0)
+
+ all_ranking = torch.zeros(all_size_r.sum(), dtype=torch.long, device=device)
+ all_num_pred = torch.zeros(all_size_p.sum(), dtype=torch.long, device=device)
+ all_types = torch.zeros(all_size_p.sum(), dtype=torch.long, device=device)
+ all_answer_ranking = torch.zeros(all_size_ar.sum(), dtype=torch.long, device=device)
+ all_num_easy = torch.zeros(all_size_p.sum(), dtype=torch.long, device=device)
+ all_num_hard = torch.zeros(all_size_p.sum(), dtype=torch.long, device=device)
+
+ all_ranking[cum_size_r[rank] - all_size_r[rank]: cum_size_r[rank]] = ranking
+ all_num_pred[cum_size_p[rank] - all_size_p[rank]: cum_size_p[rank]] = num_pred
+ all_types[cum_size_p[rank] - all_size_p[rank]: cum_size_p[rank]] = type
+ all_answer_ranking[cum_size_ar[rank] - all_size_ar[rank]: cum_size_ar[rank]] = answer_ranking
+ all_num_easy[cum_size_p[rank] - all_size_p[rank]: cum_size_p[rank]] = num_easy
+ all_num_hard[cum_size_p[rank] - all_size_p[rank]: cum_size_p[rank]] = num_hard
+
+ if world_size > 1:
+ dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_num_pred, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_types, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_answer_ranking, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_num_easy, op=dist.ReduceOp.SUM)
+ dist.all_reduce(all_num_hard, op=dist.ReduceOp.SUM)
+
+ return (all_ranking.cpu(), all_num_pred.cpu()), (all_types.cpu(), all_answer_ranking.cpu(), all_num_easy.cpu(), all_num_hard.cpu())
+
+def batch_evaluate(pred, target, limit_nodes=None):
+ type, easy_answer, hard_answer = target
+
+ num_easy = easy_answer.sum(dim=-1)
+ num_hard = hard_answer.sum(dim=-1)
+ num_answer = num_easy + num_hard
+ # answer2query = functional._size_to_index(num_answer)
+ answer2query = torch.repeat_interleave(num_answer)
+
+ num_entity = pred.shape[-1]
+
+ # in inductive (e) fb_ datasets, the number of nodes in the graph structure might exceed
+ # the actual number of nodes in the graph, so we'll mask unused nodes
+ if limit_nodes is not None:
+ # print(f"Keeping only {len(limit_nodes)} nodes out of {num_entity}")
+ keep_mask = torch.zeros(num_entity, dtype=torch.bool, device=limit_nodes.device)
+ keep_mask[limit_nodes] = 1
+ #keep_mask = F.one_hot(limit_nodes, num_entity)
+ pred[:, ~keep_mask] = float('-inf')
+
+ order = pred.argsort(dim=-1, descending=True)
+
+ range = torch.arange(num_entity, device=pred.device)
+ ranking = scatter_add(range.expand_as(order), order, dim=-1)
+
+ easy_ranking = ranking[easy_answer]
+ hard_ranking = ranking[hard_answer]
+ # unfiltered rankings of all answers
+ answer_ranking = variadic._extend(easy_ranking, num_easy, hard_ranking, num_hard)[0]
+ order_among_answer = variadic.variadic_sort(answer_ranking, num_answer)[1]
+ order_among_answer = order_among_answer + (num_answer.cumsum(0) - num_answer)[answer2query]
+ ranking_among_answer = scatter_add(variadic.variadic_arange(num_answer), order_among_answer)
+
+ # filtered rankings of all answers
+ ranking = answer_ranking - ranking_among_answer + 1
+ ends = num_answer.cumsum(0)
+ starts = ends - num_hard
+ hard_mask = variadic.multi_slice_mask(starts, ends, ends[-1])
+ # filtered rankings of hard answers
+ ranking = ranking[hard_mask]
+
+ return ranking, answer_ranking
+
+def evaluate(pred, target, metrics, id2type):
+ ranking, num_pred = pred
+ type, answer_ranking, num_easy, num_hard = target
+
+ metric = {}
+ for _metric in metrics:
+ if _metric == "mrr":
+ answer_score = 1 / ranking.float()
+ query_score = variadic.variadic_mean(answer_score, num_hard)
+ type_score = scatter_mean(query_score, type, dim_size=len(id2type))
+ elif _metric.startswith("hits@"):
+ threshold = int(_metric[5:])
+ answer_score = (ranking <= threshold).float()
+ query_score = variadic.variadic_mean(answer_score, num_hard)
+ type_score = scatter_mean(query_score, type, dim_size=len(id2type))
+ elif _metric == "mape":
+ query_score = (num_pred - num_easy - num_hard).abs() / (num_easy + num_hard).float()
+ type_score = scatter_mean(query_score, type, dim_size=len(id2type))
+ elif _metric == "spearmanr":
+ type_score = []
+ for i in range(len(id2type)):
+ mask = type == i
+ score = spearmanr(num_pred[mask], num_easy[mask] + num_hard[mask])
+ type_score.append(score)
+ type_score = torch.stack(type_score)
+ elif _metric == "auroc":
+ ends = (num_easy + num_hard).cumsum(0)
+ starts = ends - num_hard
+ target = variadic.multi_slice_mask(starts, ends, len(answer_ranking)).float()
+ answer_score = variadic_area_under_roc(answer_ranking, target, num_easy + num_hard)
+ mask = (num_easy > 0) & (num_hard > 0)
+ query_score = answer_score[mask]
+ type_score = scatter_mean(query_score, type[mask], dim_size=len(id2type))
+ else:
+ raise ValueError("Unknown metric `%s`" % _metric)
+
+ score = type_score.mean()
+ is_neg = torch.tensor(["n" in t for t in id2type], device=ranking.device)
+ is_epfo = ~is_neg
+ name = _metric
+ for i, query_type in enumerate(id2type):
+ metric["[%s] %s" % (query_type, name)] = type_score[i].item()
+ if is_epfo.any():
+ epfo_score = variadic.masked_mean(type_score, is_epfo)
+ metric["[EPFO] %s" % name] = epfo_score.item()
+ if is_neg.any():
+ neg_score = variadic.masked_mean(type_score, is_neg)
+ metric["[negation] %s" % name] = neg_score.item()
+ metric[name] = score.item()
+
+ return metric
+
+def variadic_area_under_roc(pred, target, size):
+ """
+ Area under receiver operating characteristic curve (ROC) for sets with variadic sizes.
+
+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+ Parameters:
+ pred (Tensor): prediction of shape :math:`(B,)`
+ target (Tensor): target of shape :math:`(B,)`.
+ size (Tensor): size of sets of shape :math:`(N,)`
+ """
+
+ index2graph = torch.repeat_interleave(size)
+ _, order = variadic.variadic_sort(pred, size, descending=True)
+ cum_size = (size.cumsum(0) - size)[index2graph]
+ target = target[order + cum_size]
+ total_hit = variadic.variadic_sum(target, size)
+ total_hit = total_hit.cumsum(0) - total_hit
+ hit = target.cumsum(0) - total_hit[index2graph]
+ hit = torch.where(target == 0, hit, torch.zeros_like(hit))
+ all = variadic.variadic_sum((target == 0).float(), size) * \
+ variadic.variadic_sum((target == 1).float(), size)
+ auroc = variadic.variadic_sum(hit, size) / (all + 1e-10)
+ return auroc
+
+def spearmanr(pred, target):
+ """
+ Spearman correlation between prediction and target.
+
+ Parameters:
+ pred (Tensor): prediction of shape :math: `(N,)`
+ target (Tensor): target of shape :math: `(N,)`
+ """
+
+ def get_ranking(input):
+ input_set, input_inverse = input.unique(return_inverse=True)
+ order = input_inverse.argsort()
+ ranking = torch.zeros(len(input_inverse), device=input.device)
+ ranking[order] = torch.arange(1, len(input) + 1, dtype=torch.float, device=input.device)
+
+ # for elements that have the same value, replace their rankings with the mean of their rankings
+ mean_ranking = scatter_mean(ranking, input_inverse, dim=0, dim_size=len(input_set))
+ ranking = mean_ranking[input_inverse]
+ return ranking
+
+ pred = get_ranking(pred)
+ target = get_ranking(target)
+ covariance = (pred * target).mean() - pred.mean() * target.mean()
+ pred_std = pred.std(unbiased=False)
+ target_std = target.std(unbiased=False)
+ spearmanr = covariance / (pred_std * target_std + 1e-10)
+ return spearmanr
+
+
+def spmm_max(index: Tensor, value: Tensor, m: int, n: int,
+ matrix: Tensor) -> Tensor:
+ """
+ The same spmm kernel from torch_sparse
+ https://github.com/rusty1s/pytorch_sparse/blob/master/torch_sparse/spmm.py#L29
+
+ with the only change that instead of scatter_add aggregation
+ we keep scatter_max
+
+ Matrix product of sparse matrix with dense matrix.
+
+ Args:
+ index (:class:`LongTensor`): The index tensor of sparse matrix.
+ value (:class:`Tensor`): The value tensor of sparse matrix, either of
+ floating-point or integer type. Does not work for boolean and
+ complex number data types.
+ m (int): The first dimension of sparse matrix.
+ n (int): The second dimension of sparse matrix.
+ matrix (:class:`Tensor`): The dense matrix of same type as
+ :obj:`value`.
+
+ :rtype: :class:`Tensor`
+ """
+
+ assert n == matrix.size(-2)
+
+ row, col = index[0], index[1]
+ matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
+
+ out = matrix.index_select(-2, col)
+ out = out * value.unsqueeze(-1)
+ out = scatter_max(out, row, dim=-2, dim_size=m)[0]
+
+ return out
+
+def build_query_dataset(cfg):
+ data_config = copy.deepcopy(cfg.dataset)
+ cls = data_config.pop("class")
+
+ ds_cls = getattr(datasets_query, cls)
+ dataset = ds_cls(**data_config)
+
+ return dataset
+
+def cat(objs, *args, **kwargs):
+ """
+ Concatenate a list of nested containers with the same structure.
+ """
+ obj = objs[0]
+ if isinstance(obj, torch.Tensor):
+ return torch.cat(objs, *args, **kwargs)
+ elif isinstance(obj, dict):
+ return {k: cat([x[k] for x in objs], *args, **kwargs) for k in obj}
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(cat(xs, *args, **kwargs) for xs in zip(*objs))
+
+ raise TypeError("Can't perform concatenation over object type `%s`" % type(obj))
+
+def print_metrics(metrics, logger, roundto=4):
+ order = sorted(list(metrics.keys()))
+ for key in order:
+ logger.warning(f"{key}: {round(metrics[key], roundto)}")
+
+def print_metrics_to_file(metrics, results_file, roundto=4):
+ # round up all values in the metrics dict
+ metrics = {k: round(v,roundto) if type(v).__name__ != "str" else v for k,v in metrics.items()}
+ with open(results_file, "a", newline='') as csv_file:
+ fieldnames = sorted(list(metrics.keys()))
+ fieldnames.remove("dataset")
+ fieldnames = ['dataset']+fieldnames
+ writer = csv.DictWriter(csv_file, fieldnames=fieldnames, delimiter=',')
+ if csv_file.tell() == 0:
+ writer.writeheader()
+ writer.writerow(metrics)
+
+def cuda(obj, *args, **kwargs):
+ """
+ Transfer any nested container of tensors to CUDA.
+ """
+ if hasattr(obj, "cuda"):
+ return obj.cuda(*args, **kwargs)
+ elif isinstance(obj, (str, bytes)):
+ return obj
+ elif isinstance(obj, dict):
+ return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()})
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(cuda(x, *args, **kwargs) for x in obj)
+
+ raise TypeError("Can't transfer object type `%s`" % type(obj))
\ No newline at end of file
diff --git a/ultra/ultraquery.py b/ultra/ultraquery.py
new file mode 100644
index 0000000..b981f90
--- /dev/null
+++ b/ultra/ultraquery.py
@@ -0,0 +1,298 @@
+import copy
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ultra.query_utils import Stack, spmm_max
+from ultra.tasks import build_relation_graph, edge_match
+from ultra.base_nbfnet import index_to_mask
+from torch_geometric.data import Data,Batch
+
+
+class UltraQuery(nn.Module):
+ """
+ Query executor for answering multi-hop logical queries.
+
+ Parameters:
+ model (nn.Module): GNN model (Ultra) that returns a distribution of scores over nodes
+ logic (str, optional): which fuzzy logic system to use, ``godel``, ``product`` or ``lukasiewicz``
+ dropout_ratio (float, optional): ratio for traversal dropout
+ threshold (float, optional): a score threshold for inductive models pre-trained only on 1-hop link prediction
+ more_dropout (float, optional): even more edge dropout because we like to regularize (who doesn't?)
+ """
+
+ stack_size = 2
+
+ def __init__(self, model, logic="product", dropout_ratio=0.25, threshold=0.0, more_dropout=0.0):
+ super(UltraQuery, self).__init__()
+ self.model = RelationProjection(model, threshold)
+ self.symbolic_model = SymbolicTraversal()
+ self.logic = logic
+ self.dropout_ratio = dropout_ratio
+ self.more_dropout = more_dropout
+
+ def traversal_dropout(self, graph, h_prob, r_index):
+ """Dropout edges that can be directly traversed to create an incomplete graph."""
+ sample, h_index = h_prob.nonzero().t()
+ r_index = r_index[sample]
+
+ # p1: find all tails
+ direct_ei = torch.vstack([graph.edge_index[0], graph.edge_type])
+ direct_query = torch.vstack([h_index, r_index])
+ direct_mask = edge_match(direct_ei, direct_query)[0]
+ # p2: find heads with inverses
+ # CAUTION: in some datasets, inverse edge type is rel+1, in some it is rel + num_rel/2
+ inverse_ei = torch.vstack([graph.edge_type, graph.edge_index[1]])
+ inverse_rel_plus_one = getattr(graph, 'inverse_rel_plus_one', False)
+ if inverse_rel_plus_one:
+ inverse_r_index = r_index ^ 1
+ else:
+ inverse_r_index = torch.where(r_index >= graph.num_relations // 2, r_index - graph.num_relations // 2, r_index + graph.num_relations // 2)
+ inv_query = torch.vstack([inverse_r_index, h_index])
+ inverse_mask = edge_match(inverse_ei, inv_query)[0]
+
+ edge_index = torch.cat([direct_mask, inverse_mask])
+
+ # don't remove edges that break the graph into separate connected components
+ h_index, t_index = graph.edge_index
+ degree_h = h_index.bincount()
+ degree_t = t_index.bincount()
+ h_index, t_index = graph.edge_index[:, edge_index]
+ must_keep = (degree_h[h_index] <= 1) | (degree_t[t_index] <= 1)
+ edge_index = edge_index[~must_keep]
+
+ is_sampled = torch.rand(len(edge_index), device=graph.edge_index.device) <= self.dropout_ratio
+ edge_index = edge_index[is_sampled]
+
+ if self.more_dropout > 0.0:
+ # More general edge dropout
+ more_drop_mask = torch.rand(graph.edge_index.shape[1], device=graph.edge_index.device) <= self.more_dropout
+ more_drop_edges = more_drop_mask.nonzero().squeeze(1)
+ h_index, t_index = graph.edge_index[:, more_drop_edges] # maybe add the main edge_index here as well
+ must_keep = (degree_h[h_index] <= 1) | (degree_t[t_index] <= 1)
+ more_drop_edges = more_drop_edges[~must_keep]
+ # Add to the main edge dropout
+ edge_index = torch.cat([edge_index, more_drop_edges]).unique()
+
+ mask = ~index_to_mask(edge_index, graph.num_edges)
+
+ graph = copy.copy(graph)
+ graph.edge_index = graph.edge_index[:, mask]
+ graph.edge_type = graph.edge_type[mask]
+
+ return graph
+
+ def execute(self, graph, query, symbolic_traversal):
+ """Execute queries on the graph.
+ symbolic_traversal is needed at training time for dropout
+ and can be turned off for inference
+ """
+ self.symbolic_traversal = symbolic_traversal
+ if self.training:
+ assert self.symbolic_traversal is True, "symbolic_traversal is needed at train time for dropout"
+
+ # we use stacks to execute postfix notations
+ # check out this tutorial if you are not familiar with the algorithm
+ # https://www.andrew.cmu.edu/course/15-121/lectures/Stacks%20and%20Queues/Stacks%20and%20Queues.html
+ batch_size = len(query)
+ # we execute a neural model and a symbolic model at the same time
+ # the symbolic model is used for traversal dropout at training time
+ # the elements in a stack are fuzzy sets
+ self.stack = Stack(batch_size, self.stack_size, graph.num_nodes, device=query.device)
+ self.var = Stack(batch_size, query.shape[1], graph.num_nodes, device=query.device)
+
+ if self.symbolic_traversal:
+ self.symbolic_stack = Stack(batch_size, self.stack_size, graph.num_nodes, device=query.device)
+ self.symbolic_var = Stack(batch_size, query.shape[1], graph.num_nodes, device=query.device)
+
+ # instruction pointer
+ self.IP = torch.zeros(batch_size, dtype=torch.long, device=query.device)
+ all_sample = torch.ones(batch_size, dtype=torch.bool, device=query.device)
+ op = query[all_sample, self.IP]
+
+ while not op.is_stop().all():
+ is_operand = op.is_operand()
+ is_intersection = op.is_intersection()
+ is_union = op.is_union()
+ is_negation = op.is_negation()
+ is_projection = op.is_projection()
+ if is_operand.any():
+ h_index = op[is_operand].get_operand()
+ self.apply_operand(is_operand, h_index, graph.num_nodes)
+ if is_intersection.any():
+ self.apply_intersection(is_intersection)
+ if is_union.any():
+ self.apply_union(is_union)
+ if is_negation.any():
+ self.apply_negation(is_negation)
+ # only execute projection when there are no other operations
+ # since projection is the most expensive and we want to maximize the parallelism
+ if not (is_operand | is_negation | is_intersection | is_union).any() and is_projection.any():
+ r_index = op[is_projection].get_operand()
+ self.apply_projection(is_projection, graph, r_index)
+ op = query[all_sample, self.IP]
+
+ if (self.stack.SP > 1).any():
+ raise ValueError("More operands than expected")
+
+ def forward(self, graph, query, symbolic_traversal=True):
+ self.execute(graph, query, symbolic_traversal)
+
+ # convert probability to logit for compatibility reasons
+ t_prob = self.stack.pop()
+ t_logit = ((t_prob + 1e-10) / (1 - t_prob + 1e-10)).log()
+ return t_logit
+
+
+ def apply_operand(self, mask, h_index, num_node):
+ h_prob = F.one_hot(h_index, num_node).float()
+ self.stack.push(mask, h_prob)
+ self.var.push(mask, h_prob)
+ if self.symbolic_traversal:
+ self.symbolic_stack.push(mask, h_prob)
+ self.symbolic_var.push(mask, h_prob)
+ self.IP[mask] += 1
+
+ def apply_intersection(self, mask):
+ y_prob = self.stack.pop(mask)
+ x_prob = self.stack.pop(mask)
+ z_prob = self.conjunction(x_prob, y_prob)
+ self.stack.push(mask, z_prob)
+ self.var.push(mask, z_prob)
+ if self.symbolic_traversal:
+ sym_y_prob = self.symbolic_stack.pop(mask)
+ sym_x_prob = self.symbolic_stack.pop(mask)
+ sym_z_prob = self.conjunction(sym_x_prob, sym_y_prob)
+ self.symbolic_stack.push(mask, sym_z_prob)
+ self.symbolic_var.push(mask, sym_z_prob)
+ self.IP[mask] += 1
+
+ def apply_union(self, mask):
+ y_prob = self.stack.pop(mask)
+ x_prob = self.stack.pop(mask)
+ z_prob = self.disjunction(x_prob, y_prob)
+ self.stack.push(mask, z_prob)
+ self.var.push(mask, z_prob)
+ if self.symbolic_traversal:
+ sym_y_prob = self.symbolic_stack.pop(mask)
+ sym_x_prob = self.symbolic_stack.pop(mask)
+ sym_z_prob = self.disjunction(sym_x_prob, sym_y_prob)
+ self.symbolic_stack.push(mask, sym_z_prob)
+ self.symbolic_var.push(mask, sym_z_prob)
+ self.IP[mask] += 1
+
+ def apply_negation(self, mask):
+ x_prob = self.stack.pop(mask)
+ y_prob = self.negation(x_prob)
+ self.stack.push(mask, y_prob)
+ self.var.push(mask, y_prob)
+ if self.symbolic_traversal:
+ sym_x_prob = self.symbolic_stack.pop(mask)
+ sym_y_prob = self.negation(sym_x_prob)
+ self.symbolic_stack.push(mask, sym_y_prob)
+ self.symbolic_var.push(mask, sym_y_prob)
+ self.IP[mask] += 1
+
+ def apply_projection(self, mask, graph, r_index):
+ h_prob = self.stack.pop(mask)
+ if self.training:
+ sym_h_prob = self.symbolic_stack.pop(mask)
+ # apply traversal dropout based on the output of the symbolic model
+ graph = self.traversal_dropout(graph, sym_h_prob, r_index)
+ # also changing the relation graph because of the changed main graph
+ graph = build_relation_graph(graph)
+ else:
+ if self.symbolic_traversal:
+ sym_h_prob = self.symbolic_stack.pop(mask)
+
+ # detach the variable to stabilize training
+ h_prob = h_prob.detach()
+ t_prob = self.model(graph, h_prob, r_index)
+ self.stack.push(mask, t_prob)
+ self.var.push(mask, t_prob)
+
+ if self.symbolic_traversal:
+ sym_t_prob = self.symbolic_model(graph, sym_h_prob, r_index)
+ self.symbolic_stack.push(mask, sym_t_prob)
+ self.symbolic_var.push(mask, sym_t_prob)
+
+ self.IP[mask] += 1
+
+ def conjunction(self, x, y):
+ if self.logic == "godel":
+ return torch.min(x, y)
+ elif self.logic == "product":
+ return x * y
+ elif self.logic == "lukasiewicz":
+ return (x + y - 1).clamp(min=0)
+ else:
+ raise ValueError("Unknown fuzzy logic `%s`" % self.logic)
+
+ def disjunction(self, x, y):
+ if self.logic == "godel":
+ return torch.max(x, y)
+ elif self.logic == "product":
+ return x + y - x * y
+ elif self.logic == "lukasiewicz":
+ return (x + y).clamp(max=1)
+ else:
+ raise ValueError("Unknown fuzzy logic `%s`" % self.logic)
+
+ def negation(self, x):
+ return 1 - x
+
+
+class RelationProjection(nn.Module):
+ """Wrap a GNN model for relation projection."""
+
+ def __init__(self, model, threshold=0.0):
+ super(RelationProjection, self).__init__()
+ self.model = model
+ self.threshold = threshold
+
+ def forward(self, graph, h_prob, r_index):
+
+ bs = r_index.shape[0]
+
+ # GNN model: get relation representations conditioned on the query r_index
+ rel_reprs = self.model.relation_model(graph.relation_graph, query=r_index) # (bs, num_rel, dim)
+ query = rel_reprs[torch.arange(bs, device=r_index.device), r_index] # (bs, dim)
+
+ # initialize the input with the fuzzy set and query relation
+ input = torch.einsum("bn, bd -> bnd", h_prob, query)
+
+ # GNNs trained on link prediction exhibit the multi-source propagation issue (see the paper)
+ # We can partly alleviate it by thresholding intermediate scores
+ if self.threshold > 0.0:
+ temp_prob = h_prob.clone()
+ # if self.threshold > 0.0:
+ temp_prob[temp_prob <= self.threshold] = 0.0
+ input = torch.einsum("bn, bd -> bnd", temp_prob, query)
+
+ # GNN model: run the entity-level reasoner to get a scalar distribution over nodes
+ output = self.model.entity_model(graph, input, rel_reprs, query)
+ # Turn into probs
+ t_prob = F.sigmoid(output)
+
+ return t_prob
+
+
+
+class SymbolicTraversal(nn.Module):
+ """Symbolic traversal algorithm."""
+
+ def forward(self, graph, h_prob, r_index):
+ batch_size = len(h_prob)
+
+ # For each query relation in the batch of r_index, we need to extract subgraphs induced by those edge types
+ # OG torchdrug uses perfect hashing for the matching process, here we'll use vmap over the batch dim
+ # Still not the most efficient method though, suggestions are welcome
+ mask = torch.vmap(lambda t1, t2: t1 == t2 )(graph.edge_type.repeat(batch_size,1), r_index.unsqueeze(1))
+ # Creating one big graph from all the subgraphs for one spmm_max function call
+ graph = Batch.from_data_list([
+ Data(edge_index=graph.edge_index[:, mask[i]], num_nodes=graph.num_nodes, device=graph.edge_index.device) for i in range(batch_size)
+ ])
+
+ t_prob = spmm_max(graph.edge_index.flip(0), torch.ones(graph.num_edges, device=h_prob.device), graph.num_nodes, graph.num_nodes, h_prob.view(-1, 1)).clamp(min=0)
+
+ return t_prob.view_as(h_prob)
diff --git a/ultra/variadic.py b/ultra/variadic.py
new file mode 100644
index 0000000..4635b38
--- /dev/null
+++ b/ultra/variadic.py
@@ -0,0 +1,364 @@
+import torch
+from torch_scatter import scatter_add, scatter_mean, scatter_max
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+from torch.nn import functional as F
+
+"""
+Some variadic functions adopted from TorchDrug
+https://github.com/DeepGraphLearning/torchdrug/blob/master/torchdrug/layers/functional/functional.py
+"""
+
+def masked_mean(input, mask, dim=None, keepdim=False):
+ """
+ Masked mean of a tensor.
+
+ Parameters:
+ input (Tensor): input tensor
+ mask (BoolTensor): mask tensor
+ dim (int or tuple of int, optional): dimension to reduce
+ keepdim (bool, optional): whether retain ``dim`` or not
+ """
+ input = input.masked_scatter(~mask, torch.zeros_like(input)) # safe with nan
+ if dim is None:
+ return input.sum() / mask.sum().clamp(1)
+ return input.sum(dim, keepdim=keepdim) / mask.sum(dim, keepdim=keepdim).clamp(1)
+
+
+def mean_with_nan(input, dim=None, keepdim=False):
+ """
+ Mean of a tensor. Ignore all nan values.
+
+ Parameters:
+ input (Tensor): input tensor
+ dim (int or tuple of int, optional): dimension to reduce
+ keepdim (bool, optional): whether retain ``dim`` or not
+ """
+ mask = ~torch.isnan(input)
+ return masked_mean(input, mask, dim, keepdim)
+
+
+def multi_slice(starts, ends):
+ """
+ Compute the union of indexes in multiple slices.
+
+ Example::
+
+ >>> mask = multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
+ >>> assert (mask == torch.tensor([0, 1, 2, 4, 5]).all()
+
+ Parameters:
+ starts (LongTensor): start indexes of slices
+ ends (LongTensor): end indexes of slices
+ """
+ values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
+ slices = torch.cat([starts, ends])
+ slices, order = slices.sort()
+ values = values[order]
+ depth = values.cumsum(0)
+ valid = ((values == 1) & (depth == 1)) | ((values == -1) & (depth == 0))
+ slices = slices[valid]
+
+ starts, ends = slices.view(-1, 2).t()
+ size = ends - starts
+ indexes = variadic_arange(size)
+ indexes = indexes + starts.repeat_interleave(size)
+ return indexes
+
+
+def multi_slice_mask(starts, ends, length):
+ """
+ Compute the union of multiple slices into a binary mask.
+
+ Example::
+
+ >>> mask = multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
+ >>> assert (mask == torch.tensor([1, 1, 1, 0, 1, 1])).all()
+
+ Parameters:
+ starts (LongTensor): start indexes of slices
+ ends (LongTensor): end indexes of slices
+ length (int): length of mask
+ """
+ values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
+ slices = torch.cat([starts, ends])
+ if slices.numel():
+ assert slices.min() >= 0 and slices.max() <= length
+ mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
+ mask = mask.cumsum(0).bool()
+ return mask
+
+
+def _extend(data, size, input, input_size):
+ """
+ Extend variadic-sized data with variadic-sized input.
+ This is a variadic variant of ``torch.cat([data, input], dim=-1)``.
+
+ Example::
+
+ >>> data = torch.tensor([0, 1, 2, 3, 4])
+ >>> size = torch.tensor([3, 2])
+ >>> input = torch.tensor([-1, -2, -3])
+ >>> input_size = torch.tensor([1, 2])
+ >>> new_data, new_size = _extend(data, size, input, input_size)
+ >>> assert (new_data == torch.tensor([0, 1, 2, -1, 3, 4, -2, -3])).all()
+ >>> assert (new_size == torch.tensor([4, 4])).all()
+
+ Parameters:
+ data (Tensor): variadic data
+ size (LongTensor): size of data
+ input (Tensor): variadic input
+ input_size (LongTensor): size of input
+
+ Returns:
+ (Tensor, LongTensor): output data, output size
+ """
+ new_size = size + input_size
+ new_cum_size = new_size.cumsum(0)
+ new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
+ starts = new_cum_size - new_size
+ ends = starts + size
+ index = multi_slice_mask(starts, ends, new_cum_size[-1])
+ new_data[index] = data
+ new_data[~index] = input
+ return new_data, new_size
+
+
+def variadic_sum(input, size):
+ """
+ Compute sum over sets with variadic sizes.
+
+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): size of sets of shape :math:`(N,)`
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+ index2sample = index2sample.expand_as(input)
+
+ value = scatter_add(input, index2sample, dim=0)
+ return value
+
+
+def variadic_mean(input, size):
+ """
+ Compute mean over sets with variadic sizes.
+
+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): size of sets of shape :math:`(N,)`
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+ index2sample = index2sample.expand_as(input)
+
+ value = scatter_mean(input, index2sample, dim=0)
+ return value
+
+
+def variadic_max(input, size):
+ """
+ Compute max over sets with variadic sizes.
+
+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): size of sets of shape :math:`(N,)`
+
+ Returns
+ (Tensor, LongTensor): max values and indexes
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+ index2sample = index2sample.expand_as(input)
+
+ value, index = scatter_max(input, index2sample, dim=0)
+ index = index + (size - size.cumsum(0)).view([-1] + [1] * (index.ndim - 1))
+ return value, index
+
+
+def variadic_log_softmax(input, size):
+ """
+ Compute log softmax over categories with variadic sizes.
+
+ Suppose there are :math:`N` samples, and the numbers of categories in all samples are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): number of categories of shape :math:`(N,)`
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+ index2sample = index2sample.expand_as(input)
+
+ log_likelihood = scatter_log_softmax(input, index2sample, dim=0)
+ return log_likelihood
+
+
+def variadic_softmax(input, size):
+ """
+ Compute softmax over categories with variadic sizes.
+
+ Suppose there are :math:`N` samples, and the numbers of categories in all samples are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): number of categories of shape :math:`(N,)`
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+ index2sample = index2sample.expand_as(input)
+
+ log_likelihood = scatter_softmax(input, index2sample, dim=0)
+ return log_likelihood
+
+
+def variadic_cross_entropy(input, target, size, reduction="mean"):
+ """
+ Compute cross entropy loss over categories with variadic sizes.
+
+ Suppose there are :math:`N` samples, and the numbers of categories in all samples are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): prediction of shape :math:`(B, ...)`
+ target (Tensor): target of shape :math:`(N, ...)`. Each target is a relative index in a sample.
+ size (LongTensor): number of categories of shape :math:`(N,)`
+ reduction (string, optional): reduction to apply to the output.
+ Available reductions are ``none``, ``sum`` and ``mean``.
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+ index2sample = index2sample.expand_as(input)
+
+ log_likelihood = scatter_log_softmax(input, index2sample, dim=0)
+ size = size.view([-1] + [1] * (input.ndim - 1))
+ assert (target >= 0).all() and (target < size).all()
+ target_index = target + size.cumsum(0) - size
+ loss = -log_likelihood.gather(0, target_index)
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ elif reduction == "none":
+ return loss
+ else:
+ raise ValueError("Unknown reduction `%s`" % reduction)
+
+
+def variadic_topk(input, size, k, largest=True):
+ """
+ Compute the :math:`k` largest elements over sets with variadic sizes.
+
+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+ If any set has less than than :math:`k` elements, the size-th largest element will be
+ repeated to pad the output to :math:`k`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): size of sets of shape :math:`(N,)`
+ k (int or LongTensor): the k in "top-k". Can be a fixed value for all sets,
+ or different values for different sets of shape :math:`(N,)`.
+ largest (bool, optional): return largest or smallest elements
+
+ Returns
+ (Tensor, LongTensor): top-k values and indexes
+ """
+ index2graph = torch.repeat_interleave(size)
+ index2graph = index2graph.view([-1] + [1] * (input.ndim - 1))
+
+ mask = ~torch.isinf(input)
+ max = input[mask].max().item()
+ min = input[mask].min().item()
+ abs_max = input[mask].abs().max().item()
+ # special case: max = min
+ gap = max - min + abs_max * 1e-6
+ safe_input = input.clamp(min - gap, max + gap)
+ offset = gap * 4
+ if largest:
+ offset = -offset
+ input_ext = safe_input + offset * index2graph
+ index_ext = input_ext.argsort(dim=0, descending=largest)
+ if isinstance(k, torch.Tensor) and k.shape == size.shape:
+ num_actual = torch.min(size, k)
+ else:
+ num_actual = size.clamp(max=k)
+ num_padding = k - num_actual
+ starts = size.cumsum(0) - size
+ ends = starts + num_actual
+ mask = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten()
+
+ if (num_padding > 0).any():
+ # special case: size < k, pad with the last valid index
+ padding = ends - 1
+ padding2graph = torch.repeat_interleave(num_padding)
+ mask = _extend(mask, num_actual, padding[padding2graph], num_padding)[0]
+
+ index = index_ext[mask] # (N * k, ...)
+ value = input.gather(0, index)
+ if isinstance(k, torch.Tensor) and k.shape == size.shape:
+ value = value.view(-1, *input.shape[1:])
+ index = index.view(-1, *input.shape[1:])
+ index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
+ else:
+ value = value.view(-1, k, *input.shape[1:])
+ index = index.view(-1, k, *input.shape[1:])
+ index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
+
+ return value, index
+
+
+def variadic_sort(input, size, descending=False):
+ """
+ Sort elements in sets with variadic sizes.
+
+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+ Parameters:
+ input (Tensor): input of shape :math:`(B, ...)`
+ size (LongTensor): size of sets of shape :math:`(N,)`
+ descending (bool, optional): return ascending or descending order
+
+ Returns
+ (Tensor, LongTensor): sorted values and indexes
+ """
+ index2sample = torch.repeat_interleave(size)
+ index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+
+ mask = ~torch.isinf(input)
+ max = input[mask].max().item()
+ min = input[mask].min().item()
+ abs_max = input[mask].abs().max().item()
+ # special case: max = min
+ gap = max - min + abs_max * 1e-6
+ safe_input = input.clamp(min - gap, max + gap)
+ offset = gap * 4
+ if descending:
+ offset = -offset
+ input_ext = safe_input + offset * index2sample
+ index = input_ext.argsort(dim=0, descending=descending)
+ value = input.gather(0, index)
+ index = index - (size.cumsum(0) - size)[index2sample]
+ return value, index
+
+
+def variadic_arange(size):
+ """
+ Return a 1-D tensor that contains integer intervals of variadic sizes.
+ This is a variadic variant of ``torch.arange(stop).expand(batch_size, -1)``.
+
+ Suppose there are :math:`N` intervals.
+
+ Parameters:
+ size (LongTensor): size of intervals of shape :math:`(N,)`
+ """
+ starts = size.cumsum(0) - size
+
+ range = torch.arange(size.sum(), device=size.device)
+ range = range - starts.repeat_interleave(size)
+ return range
+