diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 3a9ba6ec..ed0fbd84 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -23,6 +23,7 @@ from .crossformer import Crossformer from .informer import Informer from .autoformer import Autoformer +from .reformer import Reformer from .dlinear import DLinear from .patchtst import PatchTST from .usgan import USGAN @@ -54,6 +55,7 @@ "DLinear", "Informer", "Autoformer", + "Reformer", "NonstationaryTransformer", "Pyraformer", "BRITS", diff --git a/pypots/imputation/reformer/__init__.py b/pypots/imputation/reformer/__init__.py new file mode 100644 index 00000000..ddd255bf --- /dev/null +++ b/pypots/imputation/reformer/__init__.py @@ -0,0 +1,25 @@ +""" +The package of the partially-observed time-series imputation model Reformer. + +Refer to the paper +`Kitaev, Nikita, Ɓukasz Kaiser, and Anselm Levskaya. +Reformer: The Efficient Transformer. +International Conference on Learning Representations, 2020. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/google/trax/tree/master/trax/models/reformer and +https://github.com/lucidrains/reformer-pytorch + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import Reformer + +__all__ = [ + "Reformer", +] diff --git a/pypots/imputation/reformer/core.py b/pypots/imputation/reformer/core.py new file mode 100644 index 00000000..c1c70fe4 --- /dev/null +++ b/pypots/imputation/reformer/core.py @@ -0,0 +1,88 @@ +""" +The core wrapper assembles the submodules of Reformer imputation model +and takes over the forward progress of the algorithm. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.modules.reformer import ReformerEncoder +from ...nn.modules.saits import SaitsLoss, SaitsEmbedding + + +class _Reformer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_layers, + d_model, + n_heads, + bucket_size, + n_hashes, + causal, + d_ffn, + dropout, + ORT_weight: float = 1, + MIT_weight: float = 1, + ): + super().__init__() + + self.n_steps = n_steps + + self.saits_embedding = SaitsEmbedding( + n_features * 2, + d_model, + with_pos=False, + dropout=dropout, + ) + self.encoder = ReformerEncoder( + n_steps, + n_layers, + d_model, + n_heads, + bucket_size, + n_hashes, + causal, + d_ffn, + dropout, + ) + + # for the imputation task, the output dim is the same as input dim + self.output_projection = nn.Linear(d_model, n_features) + self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + # WDU: the original Reformer paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the + # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. 
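+        # Illustrative shape flow for this design (B is a hypothetical batch size):
+        #   X, missing_mask:               [B, n_steps, n_features]
+        #   concat([X, missing_mask], -1):  [B, n_steps, n_features * 2]
+        #   SaitsEmbedding projection:      [B, n_steps, d_model]
+        #   ReformerEncoder output:         [B, n_steps, d_model]
+        #   output_projection:              [B, n_steps, n_features]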
+ enc_out = self.saits_embedding(X, missing_mask) + + # Reformer encoder processing + enc_out = self.encoder(enc_out) + # project back the original data space + reconstruction = self.output_projection(enc_out) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] + loss, ORT_loss, MIT_loss = self.saits_loss_func( + reconstruction, X_ori, missing_mask, indicating_mask + ) + results["ORT_loss"] = ORT_loss + results["MIT_loss"] = MIT_loss + # `loss` is always the item for backward propagating to update the model + results["loss"] = loss + + return results diff --git a/pypots/imputation/reformer/data.py b/pypots/imputation/reformer/data.py new file mode 100644 index 00000000..63f29969 --- /dev/null +++ b/pypots/imputation/reformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for Reformer. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForReformer(DatasetForSAITS): + """Actually Reformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/reformer/model.py b/pypots/imputation/reformer/model.py new file mode 100644 index 00000000..47c21664 --- /dev/null +++ b/pypots/imputation/reformer/model.py @@ -0,0 +1,331 @@ +""" +The implementation of Reformer for the partially-observed time-series imputation task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _Reformer +from .data import DatasetForReformer +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class Reformer(BaseNNImputer): + """The PyTorch implementation of the Reformer model. + Reformer is originally proposed by Kitaev et al. in :cite:`kitaev2020reformer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the Reformer model. + + d_model : + The dimension of the model. + + n_heads : + The number of heads in each layer of Reformer. + + bucket_size : + Average size of qk per bucket, 64 was recommended in paper. + + n_hashes : + 4 is permissible per author, 8 is the best but slower. + + causal : + Auto-regressive or not. + + d_ffn : + The dimension of the feed-forward network. + The window size of moving average. + + dropout : + The dropout rate for the model. + + ORT_weight : + The weight for the ORT loss, the same as SAITS. + + MIT_weight : + The weight for the MIT loss, the same as SAITS. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. 
+ Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + d_model: int, + n_heads: int, + bucket_size: int, + n_hashes: int, + causal: bool, + d_ffn: int, + dropout: float = 0, + ORT_weight: float = 1, + MIT_weight: float = 1, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_heads = n_heads + self.n_layers = n_layers + self.d_model = d_model + self.bucket_size = bucket_size + self.n_hashes = n_hashes + self.causal = causal + self.d_ffn = d_ffn + self.dropout = dropout + self.ORT_weight = ORT_weight + self.MIT_weight = MIT_weight + + # set up the model + self.model = _Reformer( + self.n_steps, + self.n_features, + self.n_layers, + self.d_model, + self.n_heads, + self.bucket_size, + self.n_hashes, + self.causal, + self.d_ffn, + self.dropout, + self.ORT_weight, + self.MIT_weight, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def 
_assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForReformer( + train_set, return_X_ori=False, return_y=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForReformer( + val_set, return_X_ori=True, return_y=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + file_type : + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. 
+ + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] diff --git a/pypots/nn/modules/reformer/__init__.py b/pypots/nn/modules/reformer/__init__.py new file mode 100644 index 00000000..84963555 --- /dev/null +++ b/pypots/nn/modules/reformer/__init__.py @@ -0,0 +1,25 @@ +""" +The package including the modules of Reformer. + +Refer to the paper +`Nikita Kitaev, Lukasz Kaiser, and Anselm Levskaya. +"Reformer: The Efficient Transformer". +In International Conference on Learning Representations, 2020. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/google/trax/tree/master/trax/models/reformer and +https://github.com/lucidrains/reformer-pytorch + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .autoencoder import ReformerEncoder + +__all__ = [ + "ReformerEncoder", +] diff --git a/pypots/nn/modules/reformer/autoencoder.py b/pypots/nn/modules/reformer/autoencoder.py new file mode 100644 index 00000000..956b8b13 --- /dev/null +++ b/pypots/nn/modules/reformer/autoencoder.py @@ -0,0 +1,54 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from .layers import ReformerLayer + + +class ReformerEncoder(nn.Module): + def __init__( + self, + n_steps, + n_layers, + d_model, + n_heads, + bucket_size, + n_hashes, + causal, + d_ffn, + dropout, + ): + super().__init__() + + assert ( + n_steps % (bucket_size * 2) == 0 + ), f"Sequence length ({n_steps}) needs to be divisible by target bucket size x 2 - {bucket_size * 2}" + + self.enc_layer_stack = nn.ModuleList( + [ + ReformerLayer( + d_model, + n_heads, + bucket_size, + n_hashes, + causal, + d_ffn, + dropout, + ) + for _ in range(n_layers) + ] + ) + + def forward(self, x: torch.Tensor): + enc_output = x + + for layer in self.enc_layer_stack: + enc_output = layer(enc_output) + + return enc_output diff --git a/pypots/nn/modules/reformer/layers.py b/pypots/nn/modules/reformer/layers.py new file mode 100644 index 00000000..f7b0fdf4 --- /dev/null +++ b/pypots/nn/modules/reformer/layers.py @@ -0,0 +1,52 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from .lsh_attention import LSHSelfAttention +from ..transformer import PositionWiseFeedForward + + +class ReformerLayer(nn.Module): + def __init__( + self, + d_model, + n_heads, + bucket_size, + n_hashes, + causal, + d_ffn, + dropout, + ): + super().__init__() + self.attn = LSHSelfAttention( + dim=d_model, + heads=n_heads, + bucket_size=bucket_size, + n_hashes=n_hashes, + causal=causal, + ) + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout) + + def forward( + self, + enc_input: torch.Tensor, + ): + enc_output = self.attn(enc_input) + + # apply dropout and residual connection + enc_output = self.dropout(enc_output) + enc_output += enc_input + + # apply layer-norm + enc_output = self.layer_norm(enc_output) + + enc_output = self.pos_ffn(enc_output) + return enc_output diff --git 
a/pypots/nn/modules/reformer/local_attention.py b/pypots/nn/modules/reformer/local_attention.py new file mode 100644 index 00000000..86388b7c --- /dev/null +++ b/pypots/nn/modules/reformer/local_attention.py @@ -0,0 +1,339 @@ +""" +Local attention from https://github.com/lucidrains/local-attention +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import math + +import torch +import torch.nn.functional as F +from einops import rearrange +from einops import repeat, pack, unpack +from torch import nn, einsum +from torch.cuda.amp import autocast + +TOKEN_SELF_ATTN_VALUE = -5e4 + + +def exists(val): + return val is not None + + +def rotate_half(x): + x = rearrange(x, "b ... (r d) -> b ... r d", r=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +@autocast(enabled=False) +def apply_rotary_pos_emb(q, k, freqs, scale=1): + q_len = q.shape[-2] + q_freqs = freqs[..., -q_len:, :] + + inv_scale = scale**-1 + + if scale.ndim == 2: + scale = scale[-q_len:, :] + + q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale) + k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale) + return q, k + + +def exists(val): + return val is not None + + +def default(value, d): + return d if not exists(value) else value + + +def to(t): + return {"device": t.device, "dtype": t.dtype} + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def l2norm(tensor): + dtype = tensor.dtype + normed = F.normalize(tensor, dim=-1) + return normed.type(dtype) + + +def pad_to_multiple(tensor, multiple, dim=-1, value=0): + seqlen = tensor.shape[dim] + m = seqlen / multiple + if m.is_integer(): + return False, tensor + remainder = math.ceil(m) * multiple - seqlen + pad_offset = (0,) * (-1 - dim) * 2 + return True, F.pad(tensor, (*pad_offset, 0, remainder), value=value) + + +def look_around(x, backward=1, forward=0, pad_value=-1, dim=2): + t = x.shape[1] + dims = (len(x.shape) - dim) * (0, 0) + padded_x = F.pad(x, (*dims, backward, forward), value=pad_value) + tensors = [ + padded_x[:, ind : (ind + t), ...] 
for ind in range(forward + backward + 1) + ] + return torch.cat(tensors, dim=dim) + + +class SinusoidalEmbeddings(nn.Module): + def __init__(self, dim, scale_base=None, use_xpos=False): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + # xpos related + + self.use_xpos = use_xpos + self.scale_base = scale_base + + assert not ( + use_xpos and not exists(scale_base) + ), "scale base must be defined if using xpos" + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.register_buffer("scale", scale, persistent=False) + + @autocast(enabled=False) + def forward(self, x): + seq_len, device = x.shape[-2], x.device + + t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) + + if not self.use_xpos: + return freqs, torch.ones(1, device=device) + + power = (t - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, "n -> n 1") + scale = torch.cat((scale, scale), dim=-1) + + return freqs, scale + + +class LocalAttention(nn.Module): + def __init__( + self, + window_size, + causal=False, + look_backward=1, + look_forward=None, + dropout=0.0, + shared_qk=False, + rel_pos_emb_config=None, + dim=None, + autopad=False, + exact_windowsize=False, + scale=None, + use_rotary_pos_emb=True, + use_xpos=False, + xpos_scale_base=None, + ): + super().__init__() + look_forward = default(look_forward, 0 if causal else 1) + assert not (causal and look_forward > 0), "you cannot look forward if causal" + + self.scale = scale + + self.window_size = window_size + self.autopad = autopad + self.exact_windowsize = exact_windowsize + + self.causal = causal + + self.look_backward = look_backward + self.look_forward = look_forward + + self.dropout = nn.Dropout(dropout) + + self.shared_qk = shared_qk + + # relative positions + + self.rel_pos = None + self.use_xpos = use_xpos + + if use_rotary_pos_emb and ( + exists(rel_pos_emb_config) or exists(dim) + ): # backwards compatible with old `rel_pos_emb_config` deprecated argument + if exists(rel_pos_emb_config): + dim = rel_pos_emb_config[0] + + self.rel_pos = SinusoidalEmbeddings( + dim, + use_xpos=use_xpos, + scale_base=default(xpos_scale_base, window_size // 2), + ) + + def forward( + self, q, k, v, mask=None, input_mask=None, attn_bias=None, window_size=None + ): + + mask = default(mask, input_mask) + + assert not ( + exists(window_size) and not self.use_xpos + ), "cannot perform window size extrapolation if xpos is not turned on" + + ( + shape, + autopad, + pad_value, + window_size, + causal, + look_backward, + look_forward, + shared_qk, + ) = ( + q.shape, + self.autopad, + -1, + default(window_size, self.window_size), + self.causal, + self.look_backward, + self.look_forward, + self.shared_qk, + ) + + # https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb + (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], "* n d"), (q, k, v)) + + # auto padding + + if autopad: + orig_seq_len = q.shape[1] + (needed_pad, q), (_, k), (_, v) = map( + lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v) + ) + + b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype + + scale = default(self.scale, dim_head**-0.5) + + assert ( + n % window_size + ) == 0, f"sequence length {n} must be divisible by window size {window_size} for local attention" + + windows = n // window_size + + if shared_qk: + k = l2norm(k) + + seq 
= torch.arange(n, device=device) + b_t = rearrange(seq, "(w n) -> 1 w n", w=windows, n=window_size) + + # bucketing + + bq, bk, bv = map( + lambda t: rearrange(t, "b (w n) d -> b w n d", w=windows), (q, k, v) + ) + + bq = bq * scale + + look_around_kwargs = dict( + backward=look_backward, forward=look_forward, pad_value=pad_value + ) + + bk = look_around(bk, **look_around_kwargs) + bv = look_around(bv, **look_around_kwargs) + + # rotary embeddings + + if exists(self.rel_pos): + pos_emb, xpos_scale = self.rel_pos(bk) + bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale=xpos_scale) + + # calculate positions for masking + + bq_t = b_t + bq_k = look_around(b_t, **look_around_kwargs) + + bq_t = rearrange(bq_t, "... i -> ... i 1") + bq_k = rearrange(bq_k, "... j -> ... 1 j") + + pad_mask = bq_k == pad_value + + sim = einsum("b h i e, b h j e -> b h i j", bq, bk) + + if exists(attn_bias): + heads = attn_bias.shape[0] + assert (b % heads) == 0 + + attn_bias = repeat(attn_bias, "h i j -> (b h) 1 i j", b=b // heads) + sim = sim + attn_bias + + mask_value = max_neg_value(sim) + + if shared_qk: + self_mask = bq_t == bq_k + sim = sim.masked_fill(self_mask, TOKEN_SELF_ATTN_VALUE) + del self_mask + + if causal: + causal_mask = bq_t < bq_k + + if self.exact_windowsize: + max_causal_window_size = self.window_size * self.look_backward + causal_mask = causal_mask | (bq_t > (bq_k + max_causal_window_size)) + + sim = sim.masked_fill(causal_mask, mask_value) + del causal_mask + + # masking out for exact window size for non-causal + # as well as masking out for padding value + + if not causal and self.exact_windowsize: + max_backward_window_size = self.window_size * self.look_backward + max_forward_window_size = self.window_size * self.look_forward + window_mask = ( + ((bq_k - max_forward_window_size) > bq_t) + | (bq_t > (bq_k + max_backward_window_size)) + | pad_mask + ) + sim = sim.masked_fill(window_mask, mask_value) + else: + sim = sim.masked_fill(pad_mask, mask_value) + + # take care of key padding mask passed in + + if exists(mask): + batch = mask.shape[0] + assert (b % batch) == 0 + + h = b // mask.shape[0] + + if autopad: + _, mask = pad_to_multiple(mask, window_size, dim=-1, value=False) + + mask = rearrange(mask, "... (w n) -> (...) w n", w=windows, n=window_size) + mask = look_around(mask, **{**look_around_kwargs, "pad_value": False}) + mask = rearrange(mask, "... j -> ... 1 j") + mask = repeat(mask, "b ... -> (b h) ...", h=h) + sim = sim.masked_fill(~mask, mask_value) + del mask + + # attention + + attn = sim.softmax(dim=-1) + attn = self.dropout(attn) + + # aggregation + + out = einsum("b h i j, b h j e -> b h i e", attn, bv) + out = rearrange(out, "b w n d -> b (w n) d") + + if autopad: + out = out[:, :orig_seq_len, :] + + out, *_ = unpack(out, packed_shape, "* n d") + return out diff --git a/pypots/nn/modules/reformer/lsh_attention.py b/pypots/nn/modules/reformer/lsh_attention.py new file mode 100644 index 00000000..3f6f2980 --- /dev/null +++ b/pypots/nn/modules/reformer/lsh_attention.py @@ -0,0 +1,686 @@ +""" +Locality-Sensitive Hashing (LSH) Attention from https://github.com/lucidrains/reformer-pytorch +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from functools import partial, wraps, reduce +from operator import mul + +import torch +import torch.fft +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from .local_attention import LocalAttention, TOKEN_SELF_ATTN_VALUE + + +def rotate_every_two(x): + x = rearrange(x, "... (d j) -> ... 
d j", j=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d j -> ... (d j)") + + +def apply_rotary_pos_emb(qk, sinu_pos): + sinu_pos = sinu_pos.type(qk.dtype) + sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2) + sin, cos = sinu_pos.unbind(dim=-2) + sin, cos = map(lambda t: repeat(t, "n d -> n (d j)", j=2), (sin, cos)) + seq_len = sin.shape[0] + qk, qk_pass = qk[:, :seq_len], qk[:, seq_len:] + qk = (qk * cos) + (rotate_every_two(qk) * sin) + return torch.cat((qk, qk_pass), dim=1) + + +def exists(val): + return val is not None + + +def sort_key_val(t1, t2, dim=-1): + values, indices = t1.sort(dim=dim) + t2 = t2.expand_as(t1) + return values, t2.gather(dim, indices) + + +def batched_index_select(values, indices): + last_dim = values.shape[-1] + return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim)) + + +def process_inputs_chunk(fn, chunks=1, dim=0): + def inner_fn(*args, **kwargs): + keys, values, len_args = kwargs.keys(), kwargs.values(), len(args) + chunked_args = list( + zip(*map(lambda x: x.chunk(chunks, dim=dim), list(args) + list(values))) + ) + all_args = map( + lambda x: (x[:len_args], dict(zip(keys, x[len_args:]))), chunked_args + ) + outputs = [fn(*c_args, **c_kwargs) for c_args, c_kwargs in all_args] + return tuple(map(lambda x: torch.cat(x, dim=dim), zip(*outputs))) + + return inner_fn + + +def chunked_sum(tensor, chunks=1): + *orig_size, last_dim = tensor.shape + tensor = tensor.reshape(-1, last_dim) + summed_tensors = [c.sum(dim=-1) for c in tensor.chunk(chunks, dim=0)] + return torch.cat(summed_tensors, dim=0).reshape(orig_size) + + +def default(val, default_val): + return default_val if val is None else val + + +def cast_tuple(x): + return x if isinstance(x, tuple) else (x,) + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def cache_fn(f): + cache = None + + @wraps(f) + def cached_fn(*args, **kwargs): + nonlocal cache + if cache is not None: + return cache + cache = f(*args, **kwargs) + return cache + + return cached_fn + + +def cache_method_decorator(cache_attr, cache_namespace, reexecute=False): + def inner_fn(fn): + @wraps(fn) + def wrapper( + self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs + ): + namespace_str = str(default(key_namespace, "")) + _cache = getattr(self, cache_attr) + _keyname = f"{cache_namespace}:{namespace_str}" + + if fetch: + val = _cache[_keyname] + if reexecute: + fn(self, *args, **kwargs) + else: + val = fn(self, *args, **kwargs) + if set_cache: + setattr(self, cache_attr, {**_cache, **{_keyname: val}}) + return val + + return wrapper + + return inner_fn + + +def expand_dim(dim, k, t): + t = t.unsqueeze(dim) + expand_shape = [-1] * len(t.shape) + expand_shape[dim] = k + return t.expand(*expand_shape) + + +def merge_dims(ind_from, ind_to, tensor): + shape = list(tensor.shape) + arr_slice = slice(ind_from, ind_to + 1) + shape[arr_slice] = [reduce(mul, shape[arr_slice])] + return tensor.reshape(*shape) + + +def split_at_index(dim, index, t): + pre_slices = (slice(None),) * dim + l = (*pre_slices, slice(None, index)) + r = (*pre_slices, slice(index, None)) + return t[l], t[r] + + +class FullQKAttention(nn.Module): + def __init__(self, causal=False, dropout=0.0): + super().__init__() + self.causal = causal + self.dropout = nn.Dropout(dropout) + + def forward( + self, qk, v, query_len=None, input_mask=None, input_attn_mask=None, **kwargs + ): + b, seq_len, dim = qk.shape + query_len = default(query_len, seq_len) + t = query_len + + q = 
qk[:, 0:query_len] + qk = F.normalize(qk, 2, dim=-1).type_as(q) + + dot = torch.einsum("bie,bje->bij", q, qk) * (dim**-0.5) + + # qk attention requires tokens not attend to self + i = torch.arange(t) + dot[:, i, i] = TOKEN_SELF_ATTN_VALUE + masked_value = max_neg_value(dot) + + # Input mask for padding in variable lengthed sequences + if input_mask is not None: + mask = input_mask[:, 0:query_len, None] * input_mask[:, None, :] + mask = F.pad(mask, (0, seq_len - mask.shape[-1]), value=True) + dot.masked_fill_(~mask, masked_value) + + # Mask for post qk attention logits of the input sequence + if input_attn_mask is not None: + input_attn_mask = F.pad( + input_attn_mask, (0, seq_len - input_attn_mask.shape[-1]), value=True + ) + dot.masked_fill_(~input_attn_mask, masked_value) + + if self.causal: + i, j = torch.triu_indices(t, t, 1) + dot[:, i, j] = masked_value + + dot = dot.softmax(dim=-1) + dot = self.dropout(dot) + + out = torch.einsum("bij,bje->bie", dot, v) + + return out, dot, torch.empty(0) + + +class LSHAttention(nn.Module): + def __init__( + self, + dropout=0.0, + bucket_size=64, + n_hashes=8, + causal=False, + allow_duplicate_attention=True, + attend_across_buckets=True, + rehash_each_round=True, + drop_for_hash_rate=0.0, + random_rotations_per_head=False, + return_attn=False, + ): + super().__init__() + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + + self.dropout = nn.Dropout(dropout) + self.dropout_for_hash = nn.Dropout(drop_for_hash_rate) + + assert rehash_each_round or allow_duplicate_attention, ( + "The setting {allow_duplicate_attention=False, rehash_each_round=False}" + " is not implemented." + ) + + self.causal = causal + self.bucket_size = bucket_size + + self.n_hashes = n_hashes + + self._allow_duplicate_attention = allow_duplicate_attention + self._attend_across_buckets = attend_across_buckets + self._rehash_each_round = rehash_each_round + self._random_rotations_per_head = random_rotations_per_head + + # will expend extra computation to return attention matrix + self._return_attn = return_attn + + # cache buckets for reversible network, reported by authors to make Reformer work at depth + self._cache = {} + + @cache_method_decorator("_cache", "buckets", reexecute=True) + def hash_vectors(self, n_buckets, vecs): + batch_size = vecs.shape[0] + device = vecs.device + + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. 
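+        # Sketch of the angular (cross-polytope) LSH carried out below: each vector x
+        # is projected with a random rotation R, the logits [xR, -xR] cover the
+        # n_buckets directions, and the bucket id is the argmax over those logits,
+        # so vectors with high cosine similarity tend to land in the same bucket.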
+ assert n_buckets % 2 == 0 + + rot_size = n_buckets + + rotations_shape = ( + batch_size if self._random_rotations_per_head else 1, + vecs.shape[-1], + self.n_hashes if self._rehash_each_round else 1, + rot_size // 2, + ) + + random_rotations = torch.randn( + rotations_shape, dtype=vecs.dtype, device=device + ).expand(batch_size, -1, -1, -1) + + dropped_vecs = self.dropout_for_hash(vecs) + rotated_vecs = torch.einsum("btf,bfhi->bhti", dropped_vecs, random_rotations) + + if self._rehash_each_round: + # rotated_vectors size [batch,n_hash,seq_len,buckets] + rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) + buckets = torch.argmax(rotated_vecs, dim=-1) + else: + rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) + # In this configuration, we map each item to the top self.n_hashes buckets + rotated_vecs = torch.squeeze(rotated_vecs, 1) + bucket_range = torch.arange(rotated_vecs.shape[-1], device=device) + bucket_range = torch.reshape(bucket_range, (1, -1)) + bucket_range = bucket_range.expand_as(rotated_vecs) + + _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1) + # buckets size [batch size, seq_len, buckets] + buckets = buckets[..., -self.n_hashes :].transpose(1, 2) + + # buckets is now (self.n_hashes, seq_len). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. + offsets = torch.arange(self.n_hashes, device=device) + offsets = torch.reshape(offsets * n_buckets, (1, -1, 1)) + buckets = torch.reshape( + buckets + offsets, + ( + batch_size, + -1, + ), + ) + return buckets + + def forward( + self, + qk, + v, + query_len=None, + input_mask=None, + input_attn_mask=None, + pos_emb=None, + **kwargs, + ): + batch_size, seqlen, dim, device = *qk.shape, qk.device + + query_len = default(query_len, seqlen) + is_reverse = kwargs.pop("_reverse", False) + depth = kwargs.pop("_depth", None) + + assert ( + seqlen % (self.bucket_size * 2) == 0 + ), f"Sequence length ({seqlen}) needs to be divisible by target bucket size x 2 - {self.bucket_size * 2}" + + n_buckets = seqlen // self.bucket_size + buckets = self.hash_vectors( + n_buckets, + qk, + key_namespace=depth, + fetch=is_reverse, + set_cache=self.training, + ) + + # We use the same vector as both a query and a key. + assert int(buckets.shape[1]) == self.n_hashes * seqlen + + total_hashes = self.n_hashes + + ticker = ( + torch.arange(total_hashes * seqlen, device=device) + .unsqueeze(0) + .expand_as(buckets) + ) + buckets_and_t = seqlen * buckets + (ticker % seqlen) + buckets_and_t = buckets_and_t.detach() + + # Hash-based sort ("s" at the start of variable names means "sorted") + sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1) + _, undo_sort = sticker.sort(dim=-1) + del ticker + + sbuckets_and_t = sbuckets_and_t.detach() + sticker = sticker.detach() + undo_sort = undo_sort.detach() + + if exists(pos_emb): + qk = apply_rotary_pos_emb(qk, pos_emb) + + st = sticker % seqlen + sqk = batched_index_select(qk, st) + sv = batched_index_select(v, st) + + # Split off a "bin" axis so that attention only occurs within chunks. + chunk_size = total_hashes * n_buckets + bq_t = bkv_t = torch.reshape(st, (batch_size, chunk_size, -1)) + bqk = torch.reshape(sqk, (batch_size, chunk_size, -1, dim)) + bv = torch.reshape(sv, (batch_size, chunk_size, -1, dim)) + + # Hashing operates on unit-length vectors. 
Unnormalized query vectors are + # fine because they effectively provide a learnable temperature for the + # attention softmax, but normalizing keys is needed so that similarity for + # the purposes of attention correctly corresponds to hash locality. + bq = bqk + bk = F.normalize(bqk, p=2, dim=-1).type_as(bq) + + # Allow each chunk to attend within itself, and also one chunk back. Chunk + # boundaries might occur in the middle of a sequence of items from the + # same bucket, so this increases the chances of attending to relevant items. + def look_one_back(x): + x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1) + return torch.cat([x, x_extra], dim=2) + + bk = look_one_back(bk) + bv = look_one_back(bv) + bkv_t = look_one_back(bkv_t) + + # Dot-product attention. + dots = torch.einsum("bhie,bhje->bhij", bq, bk) * (dim**-0.5) + masked_value = max_neg_value(dots) + + # Mask for post qk attention logits of the input sequence + if input_attn_mask is not None: + input_attn_mask = F.pad( + input_attn_mask, + ( + 0, + seqlen - input_attn_mask.shape[-1], + 0, + seqlen - input_attn_mask.shape[-2], + ), + value=True, + ) + dot_attn_indices = (bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :] + input_attn_mask = input_attn_mask.reshape(batch_size, -1) + dot_attn_indices = dot_attn_indices.reshape(batch_size, -1) + mask = input_attn_mask.gather(1, dot_attn_indices).reshape_as(dots) + dots.masked_fill_(~mask, masked_value) + del mask + + # Input mask for padding in variable lengthed sequences + if input_mask is not None: + input_mask = F.pad( + input_mask, (0, seqlen - input_mask.shape[1]), value=True + ) + mq = input_mask.gather(1, st).reshape((batch_size, chunk_size, -1)) + mkv = look_one_back(mq) + mask = mq[:, :, :, None] * mkv[:, :, None, :] + dots.masked_fill_(~mask, masked_value) + del mask + + # Causal masking + if self.causal: + mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :] + if seqlen > query_len: + mask = mask & (bkv_t[:, :, None, :] < query_len) + dots.masked_fill_(mask, masked_value) + del mask + + # Mask out attention to self except when no other targets are available. + self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :] + dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE) + del self_mask + + # Mask out attention to other hash buckets. + if not self._attend_across_buckets: + bq_buckets = bkv_buckets = torch.reshape( + sbuckets_and_t // seqlen, (batch_size, chunk_size, -1) + ) + bkv_buckets = look_one_back(bkv_buckets) + bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :] + dots.masked_fill_(bucket_mask, masked_value) + del bucket_mask + + # Don't double-count query-key pairs across multiple rounds of hashing. + # There are two possible strategies here. (1) The default is to count how + # many times a query-key pair is repeated, and to lower its log-prob + # correspondingly at each repetition. (2) When hard_k is set, the code + # instead masks all but the first occurence of each query-key pair. 
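+        # Concretely for strategy (1): subtracting log(dup_counts) from the raw scores
+        # below divides each repeated pair's softmax weight by the number of hashing
+        # rounds in which the pair co-occurs, so duplicates are not over-weighted when
+        # the rounds are combined.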
+ if not self._allow_duplicate_attention: + locs1 = undo_sort // bq_t.shape[-1] + locs2 = (locs1 + 1) % chunk_size + if not self._attend_across_buckets: + locs1 = buckets * chunk_size + locs1 + locs2 = buckets * chunk_size + locs2 + locs = torch.cat( + [ + torch.reshape(locs1, (batch_size, total_hashes, seqlen)), + torch.reshape(locs2, (batch_size, total_hashes, seqlen)), + ], + 1, + ).permute((0, 2, 1)) + + slocs = batched_index_select(locs, st) + b_locs = torch.reshape( + slocs, (batch_size, chunk_size, -1, 2 * total_hashes) + ) + + b_locs1 = b_locs[:, :, :, None, :total_hashes] + + bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, total_hashes)) + bq_locs = torch.reshape(bq_locs, b_locs.shape) + bkv_locs = look_one_back(b_locs) + + dup_counts = bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :] + # for memory considerations, chunk summation of last dimension for counting duplicates + dup_counts = chunked_sum(dup_counts, chunks=(total_hashes * batch_size)) + dup_counts = dup_counts.detach() + assert dup_counts.shape == dots.shape + dots = dots - torch.log(dup_counts + 1e-9) + del dup_counts + + # Softmax. + dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True) + dots = torch.exp(dots - dots_logsumexp).type_as(dots) + dropped_dots = self.dropout(dots) + + bo = torch.einsum("buij,buje->buie", dropped_dots, bv) + so = torch.reshape(bo, (batch_size, -1, dim)) + slogits = torch.reshape( + dots_logsumexp, + ( + batch_size, + -1, + ), + ) + + # unsort logits + o = batched_index_select(so, undo_sort) + logits = slogits.gather(1, undo_sort) + + o = torch.reshape(o, (batch_size, total_hashes, seqlen, dim)) + logits = torch.reshape(logits, (batch_size, total_hashes, seqlen, 1)) + + if query_len != seqlen: + query_slice = (slice(None), slice(None), slice(0, query_len)) + o, logits = o[query_slice], logits[query_slice] + + probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True)) + out = torch.sum(o * probs, dim=1) + + attn = torch.empty(0, device=device) + + # return unsorted attention weights + if self._return_attn: + attn_unsort = (bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :] + attn_unsort = attn_unsort.view(batch_size * total_hashes, -1).long() + unsorted_dots = torch.zeros( + batch_size * total_hashes, seqlen * seqlen, device=device + ) + unsorted_dots.scatter_add_(1, attn_unsort, dots.view_as(attn_unsort)) + del attn_unsort + unsorted_dots = unsorted_dots.reshape( + batch_size, total_hashes, seqlen, seqlen + ) + attn = torch.sum(unsorted_dots[:, :, 0:query_len, :] * probs, dim=1) + + # return output, attention matrix, and bucket distribution + return out, attn, buckets + + +class LSHSelfAttention(nn.Module): + def __init__( + self, + dim, + heads=8, + bucket_size=64, + n_hashes=8, + causal=False, + dim_head=None, + attn_chunks=1, + random_rotations_per_head=False, + attend_across_buckets=True, + allow_duplicate_attention=True, + num_mem_kv=0, + one_value_head=False, + use_full_attn=False, + full_attn_thres=None, + return_attn=False, + post_attn_dropout=0.0, + dropout=0.0, + n_local_attn_heads=0, + **kwargs, + ): + super().__init__() + assert ( + dim_head or (dim % heads) == 0 + ), "dimensions must be divisible by number of heads" + assert ( + n_local_attn_heads < heads + ), "local attention heads must be less than number of heads" + + dim_head = default(dim_head, dim // heads) + dim_heads = dim_head * heads + + self.dim = dim + self.heads = heads + self.dim_head = dim_head + self.attn_chunks = default(attn_chunks, 1) + + self.v_head_repeats = heads if 
one_value_head else 1 + v_dim = dim_heads // self.v_head_repeats + + self.toqk = nn.Linear(dim, dim_heads, bias=False) + self.tov = nn.Linear(dim, v_dim, bias=False) + self.to_out = nn.Linear(dim_heads, dim) + + self.bucket_size = bucket_size + self.lsh_attn = LSHAttention( + bucket_size=bucket_size, + n_hashes=n_hashes, + causal=causal, + random_rotations_per_head=random_rotations_per_head, + attend_across_buckets=attend_across_buckets, + allow_duplicate_attention=allow_duplicate_attention, + return_attn=return_attn, + dropout=dropout, + **kwargs, + ) + self.full_attn = FullQKAttention(causal=causal, dropout=dropout) + self.post_attn_dropout = nn.Dropout(post_attn_dropout) + + self.use_full_attn = use_full_attn + self.full_attn_thres = default(full_attn_thres, bucket_size) + + self.num_mem_kv = num_mem_kv + self.mem_kv = ( + nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad=True)) + if num_mem_kv > 0 + else None + ) + + self.n_local_attn_heads = n_local_attn_heads + self.local_attn = LocalAttention( + window_size=bucket_size * 2, + causal=causal, + dropout=dropout, + shared_qk=True, + look_forward=(1 if not causal else 0), + ) + + self.callback = None + + def forward( + self, + x, + keys=None, + input_mask=None, + input_attn_mask=None, + context_mask=None, + pos_emb=None, + **kwargs, + ): + device, dtype = x.device, x.dtype + b, t, e, h, dh, m, l_h = ( + *x.shape, + self.heads, + self.dim_head, + self.num_mem_kv, + self.n_local_attn_heads, + ) + + mem_kv = default(self.mem_kv, torch.empty(b, 0, e, dtype=dtype, device=device)) + mem = mem_kv.expand(b, m, -1) + + keys = default(keys, torch.empty(b, 0, e, dtype=dtype, device=device)) + c = keys.shape[1] + + kv_len = t + m + c + use_full_attn = self.use_full_attn or kv_len <= self.full_attn_thres + + x = torch.cat((x, mem, keys), dim=1) + qk = self.toqk(x) + v = self.tov(x) + v = v.repeat(1, 1, self.v_head_repeats) + + def merge_heads(v): + return v.view(b, kv_len, h, -1).transpose(1, 2) + + def split_heads(v): + return v.view(b, h, t, -1).transpose(1, 2).contiguous() + + merge_batch_and_heads = partial(merge_dims, 0, 1) + + qk, v = map(merge_heads, (qk, v)) + + has_local = l_h > 0 + lsh_h = h - l_h + + split_index_fn = partial(split_at_index, 1, l_h) + (lqk, qk), (lv, v) = map(split_index_fn, (qk, v)) + lqk, qk, lv, v = map(merge_batch_and_heads, (lqk, qk, lv, v)) + + masks = {} + if input_mask is not None or context_mask is not None: + default_mask = torch.tensor([True], device=device) + i_mask = default(input_mask, default_mask.expand(b, t)) + m_mask = default_mask.expand(b, m) + c_mask = default(context_mask, default_mask.expand(b, c)) + mask = torch.cat((i_mask, m_mask, c_mask), dim=1) + mask = merge_batch_and_heads(expand_dim(1, lsh_h, mask)) + masks["input_mask"] = mask + + if input_attn_mask is not None: + input_attn_mask = merge_batch_and_heads( + expand_dim(1, lsh_h, input_attn_mask) + ) + masks["input_attn_mask"] = input_attn_mask + + attn_fn = self.lsh_attn if not use_full_attn else self.full_attn + partial_attn_fn = partial(attn_fn, query_len=t, pos_emb=pos_emb, **kwargs) + attn_fn_in_chunks = process_inputs_chunk( + partial_attn_fn, chunks=self.attn_chunks + ) + + out, attn, buckets = attn_fn_in_chunks(qk, v, **masks) + + if self.callback is not None: + self.callback(attn.reshape(b, lsh_h, t, -1), buckets.reshape(b, lsh_h, -1)) + + if has_local: + lqk, lv = lqk[:, :t], lv[:, :t] + local_out = self.local_attn(lqk, lqk, lv, input_mask=input_mask) + local_out = local_out.reshape(b, l_h, t, -1) + out = out.reshape(b, lsh_h, 
t, -1) + out = torch.cat((local_out, out), dim=1) + + out = split_heads(out).view(b, t, -1) + out = self.to_out(out) + return self.post_attn_dropout(out) diff --git a/tests/imputation/reformer.py b/tests/imputation/reformer.py new file mode 100644 index 00000000..15b3d749 --- /dev/null +++ b/tests/imputation/reformer.py @@ -0,0 +1,130 @@ +""" +Test cases for Reformer imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import Reformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestReformer(unittest.TestCase): + logger.info("Running tests for an imputation model Reformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Reformer") + model_save_name = "saved_reformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a Reformer model + reformer = Reformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + d_model=32, + n_heads=2, + bucket_size=4, + n_hashes=4, + causal=True, + d_ffn=32, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-reformer") + def test_0_fit(self): + self.reformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-reformer") + def test_1_impute(self): + imputation_results = self.reformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." 
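+        # the MSE below is computed only on the artificially-masked entries
+        # (where test_X_indicating_mask is 1), i.e. it measures imputation quality
+        # on values whose ground truth is known but was hidden from the model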
+ + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Reformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-reformer") + def test_2_parameters(self): + assert hasattr(self.reformer, "model") and self.reformer.model is not None + + assert ( + hasattr(self.reformer, "optimizer") and self.reformer.optimizer is not None + ) + + assert hasattr(self.reformer, "best_loss") + self.assertNotEqual(self.reformer.best_loss, float("inf")) + + assert ( + hasattr(self.reformer, "best_model_dict") + and self.reformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-reformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.reformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.reformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.reformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-reformer") + def test_4_lazy_loading(self): + self.reformer.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) + imputation_results = self.reformer.predict(GENERAL_H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading Reformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main()
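
For reference, a minimal usage sketch of the imputer added by this diff (not part of the patch itself). All class names and arguments come from the code above; the toy data shapes, missing rate, and hyperparameter values are illustrative. Note that n_steps must be divisible by bucket_size * 2, which ReformerEncoder asserts.

import numpy as np
from pypots.imputation import Reformer

# toy data: 100 samples, 48 time steps, 7 features, roughly 20% of values missing
X = np.random.randn(100, 48, 7)
X[np.random.rand(*X.shape) < 0.2] = np.nan

model = Reformer(
    n_steps=48,      # divisible by bucket_size * 2 = 8
    n_features=7,
    n_layers=2,
    d_model=64,
    n_heads=4,
    bucket_size=4,
    n_hashes=4,
    causal=False,
    d_ffn=128,
    dropout=0.1,
    epochs=10,
)
model.fit(train_set={"X": X})
imputation = model.impute({"X": X})  # ndarray of shape (100, 48, 7) with no NaNs left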