diff --git a/.gitignore b/.gitignore index 45eacb95..524eff5d 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ MANIFEST /wandb /models /.ipynb_checkpoints +*.*~ diff --git a/environment.yml b/environment.yml index 62767e2e..4a304e72 100644 --- a/environment.yml +++ b/environment.yml @@ -1,16 +1,183 @@ name: fpt +channels: + - anaconda + - pytorch + - conda-forge + - defaults dependencies: - - python=3.7 - - pip: - - boto3==1.17.102 - - einops==0.3.0 - - matplotlib==3.2.1 - - numpy==1.18.3 - - tape-proteins==0.4 - - tensorflow==2.3.0 - - tensorflow-datasets==4.0.1 - - torch==1.7.1 - - torchvision==0.8.2 - - transformers==4.1.1 - - tqdm==4.46.0 - - wandb==0.9.1 + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - _pytorch_select=0.1=cpu_0 + - _tflow_select=2.3.0=mkl + - absl-py=0.13.0=py37h06a4308_0 + - aiohttp=3.8.1=py37h7f8727e_0 + - aiosignal=1.2.0=pyhd3eb1b0_0 + - astor=0.8.1=py37h06a4308_0 + - astunparse=1.6.3=py_0 + - async-timeout=4.0.1=pyhd3eb1b0_0 + - asynctest=0.13.0=py_0 + - attrs=21.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - blinker=1.4=py37h06a4308_0 + - boto3=1.20.16=pyhd8ed1ab_0 + - botocore=1.23.16=pyhd8ed1ab_0 + - bottleneck=1.3.2=py37heb32a55_1 + - brotli=1.0.9=he6710b0_2 + - brotlipy=0.7.0=py37h27cfd23_1003 + - c-ares=1.17.1=h27cfd23_0 + - ca-certificates=2021.10.26=h06a4308_2 + - cachetools=4.2.2=pyhd3eb1b0_0 + - certifi=2021.10.8=py37h06a4308_0 + - cffi=1.14.6=py37h400218f_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - click=8.0.3=pyhd3eb1b0_0 + - cryptography=3.4.8=py37hd23ed53_0 + - cudatoolkit=10.2.89=hfd86e86_1 + - cycler=0.11.0=pyhd3eb1b0_0 + - dataclasses=0.8=pyh6d0b6a4_7 + - dbus=1.13.18=hb2f20db_0 + - dill=0.3.4=pyhd3eb1b0_0 + - einops=0.3.2=pyhd8ed1ab_0 + - expat=2.4.1=h2531618_2 + - fontconfig=2.13.1=h6c09931_0 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.11.0=h70c0345_0 + - frozenlist=1.2.0=py37h7f8727e_0 + - future=0.18.2=py37_1 + - gast=0.4.0=pyhd3eb1b0_0 + - giflib=5.2.1=h7b6447c_0 + - glib=2.69.1=h5202010_0 + - 
google-auth=1.33.0=pyhd3eb1b0_0 + - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 + - google-pasta=0.2.0=pyhd3eb1b0_0 + - googleapis-common-protos=1.53.0=py37h06a4308_0 + - grpcio=1.42.0=py37hce63b2e_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - h5py=2.10.0=py37hd6299e0_1 + - hdf5=1.10.6=hb1b8bf9_0 + - icu=58.2=he6710b0_3 + - idna=3.3=pyhd3eb1b0_0 + - importlib-metadata=4.8.1=py37h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jmespath=0.10.0=pyhd3eb1b0_0 + - joblib=1.1.0=pyhd3eb1b0_0 + - jpeg=9d=h7f8727e_0 + - keras-preprocessing=1.1.2=pyhd3eb1b0_0 + - kiwisolver=1.3.1=py37h2531618_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libgomp=9.3.0=h5101ec6_17 + - libpng=1.6.37=hbc83047_0 + - libprotobuf=3.17.2=h4ff587b_1 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtiff=4.2.0=h85742a9_0 + - libuuid=1.0.3=h7f8727e_2 + - libuv=1.40.0=h7b6447c_0 + - libwebp=1.2.0=h89dd481_0 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.12=h03d6c58_0 + - lz4-c=1.9.3=h295c915_1 + - markdown=3.3.4=py37h06a4308_0 + - matplotlib=3.4.3=py37h06a4308_0 + - matplotlib-base=3.4.3=py37hbbc1b5f_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py37h7f8727e_0 + - mkl_fft=1.3.1=py37hd3c417c_0 + - mkl_random=1.2.2=py37h51133e4_0 + - multidict=5.1.0=py37h27cfd23_2 + - munkres=1.1.4=py_0 + - ncurses=6.3=h7f8727e_2 + - ninja=1.10.2=py37hd09550d_3 + - numexpr=2.7.3=py37h22e1b3c_1 + - numpy=1.21.2=py37h20f2e39_0 + - numpy-base=1.21.2=py37h79a1101_0 + - oauthlib=3.1.1=pyhd3eb1b0_0 + - olefile=0.46=py37_0 + - openssl=1.1.1l=h7f8727e_0 + - opt_einsum=3.3.0=pyhd3eb1b0_1 + - pandas=1.3.4=py37h8c16a72_0 + - pcre=8.45=h295c915_0 + - pillow=8.4.0=py37h5aabda8_0 + - pip=21.2.2=py37h06a4308_0 + - promise=2.3=py37h06a4308_0 + - psutil=5.8.0=py37h27cfd23_1 + - pyasn1=0.4.8=pyhd3eb1b0_0 + - 
pyasn1-modules=0.2.8=py_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyjwt=2.1.0=py37h06a4308_0 + - pyopenssl=21.0.0=pyhd3eb1b0_1 + - pyparsing=3.0.4=pyhd3eb1b0_0 + - pyqt=5.9.2=py37h05f1152_2 + - pysocks=1.7.1=py37_1 + - python=3.7.11=h12debd9_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - python-flatbuffers=2.0=pyhd3eb1b0_0 + - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0 + - pytz=2021.3=pyhd3eb1b0_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - requests=2.26.0=pyhd3eb1b0_0 + - requests-oauthlib=1.3.0=py_0 + - rsa=4.7.2=pyhd3eb1b0_1 + - s3transfer=0.5.0=pyhd3eb1b0_0 + - scikit-learn=0.23.2=py37h0573a6f_0 + - setuptools=58.0.4=py37h06a4308_0 + - sip=4.19.8=py37hf484d3e_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - tensorboard=2.4.0=pyhc547734_0 + - tensorboard-plugin-wit=1.6.0=py_0 + - tensorflow=2.4.1=mkl_py37h2d14ff2_0 + - tensorflow-base=2.4.1=mkl_py37h43e0292_0 + - tensorflow-datasets=1.2.0=py37_0 + - tensorflow-estimator=2.6.0=pyh7b7c402_0 + - tensorflow-metadata=0.14.0=pyhe6710b0_1 + - termcolor=1.1.0=py37h06a4308_1 + - threadpoolctl=2.2.0=pyh0d69192_0 + - tk=8.6.11=h1ccaba5_0 + - torchvision=0.8.2=cpu_py37ha229d99_0 + - tornado=6.1=py37h27cfd23_0 + - tqdm=4.62.3=pyhd3eb1b0_1 + - typing-extensions=3.10.0.2=hd3eb1b0_0 + - typing_extensions=3.10.0.2=pyh06a4308_0 + - urllib3=1.26.7=pyhd3eb1b0_0 + - werkzeug=2.0.2=pyhd3eb1b0_0 + - wheel=0.37.0=pyhd3eb1b0_1 + - wrapt=1.13.3=py37h7f8727e_2 + - xz=5.2.5=h7b6447c_0 + - yarl=1.6.3=py37h27cfd23_0 + - zipp=3.6.0=pyhd3eb1b0_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - biopython==1.79 + - configparser==5.1.0 + - docker-pycreds==0.4.0 + - filelock==3.4.0 + - gitdb==4.0.9 + - gitpython==3.1.24 + - lmdb==1.2.1 + - packaging==21.3 + - pathtools==0.1.2 + - protobuf==3.19.1 + - pyyaml==6.0 + - regex==2021.11.10 + - sacremoses==0.0.46 + - scipy==1.7.3 + - sentry-sdk==1.5.0 + - shortuuid==1.0.8 + - smmap==5.0.0 + - subprocess32==3.5.4 + - tape-proteins==0.5 + - tensorboardx==2.4.1 + - 
timm==0.4.12 + - tokenizers==0.9.4 + - transformers==4.1.1 + - wandb==0.12.7 + - yaspin==2.1.0 diff --git a/scripts/run.py b/scripts/run.py index c5c98f8a..3a7edf00 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -1,17 +1,24 @@ from universal_computation.experiment import run_experiment +from argparse import ArgumentParser +import sys if __name__ == '__main__': - - experiment_name = 'fpt' + parser = ArgumentParser(description='Pick the task to be run.') + parser.add_argument('name', help='the name of the experiment') + parser.add_argument('task', help='the name of the task to be run') + parser.add_argument('--model', default='gpt2', help='the model to use') + args = parser.parse_args() + + experiment_name = args.name experiment_params = dict( - task='bit-memory', + task=args.task, n=1000, # ignored if not a bit task num_patterns=5, # ignored if not a bit task - patch_size=50, + patch_size=16, - model_name='gpt2', + model_name=args.model, pretrained=True, # if vit this is forced to true, if lstm this is forced to false freeze_trans=True, # if False, we don't check arguments other than in and out @@ -31,4 +38,5 @@ orth_gain=1.41, # orthogonal initialization of input layer ) + sys.argv = [''] # clear args since run_experiment also has an argparser run_experiment(experiment_name, experiment_params) diff --git a/universal_computation/datasets/eurosat.py b/universal_computation/datasets/eurosat.py new file mode 100644 index 00000000..424ab845 --- /dev/null +++ b/universal_computation/datasets/eurosat.py @@ -0,0 +1,104 @@ +import os +import pathlib +from pathlib import Path +import pandas as pd +from einops import rearrange +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +import torch +from torch.utils.data import DataLoader +from PIL import Image +import torchvision.transforms as transforms + +from universal_computation.datasets.dataset import Dataset + +class EuroSatDatasetHelper(torch.utils.data.Dataset): + def __init__(self, img_dir, 
ann_file, transform=None, target_transform=None): + df = pd.read_csv(ann_file) + self.img_labels = df[['label','img_name', 'int_label']].reset_index(drop=True) + self.img_dir = img_dir + self.transform = transform + self.target_transform = target_transform + + def __len__(self): + return len(self.img_labels) + + def __getitem__(self, idx): + label = self.img_labels.iloc[idx, 0] + name = self.img_labels.iloc[idx, 1] + int_label = self.img_labels.iloc[idx, 2] + temp = os.path.join(self.img_dir, label) + img_path = os.path.join(temp,name) + img = Image.open(img_path) + if self.transform: + img = self.transform(img) + if self.target_transform: + int_label = self.target_transform(int_label) + return img, int_label + + +class EuroSatDataset(Dataset): + def __init__(self, batch_size, patch_size=None, data_aug=True, *args, **kwargs): + super(EuroSatDataset, self).__init__(*args, **kwargs) + + self.batch_size = batch_size + self.patch_size = patch_size + + if data_aug: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224,224), interpolation=3), + transforms.RandomApply([transforms.GaussianBlur(3)]), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + else: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224,224), interpolation=3), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + + val_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224,224), interpolation=3), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + train_test_dir = 'data/2750' + self.d_train = DataLoader( + EuroSatDatasetHelper(train_test_dir, os.path.join(train_test_dir, 'train.csv'), transform=transform), + batch_size=batch_size, drop_last=True, shuffle=True, + ) + self.d_test = DataLoader( + EuroSatDatasetHelper(train_test_dir, os.path.join(train_test_dir, 'test.csv'), transform=val_transform), + batch_size=batch_size, drop_last=True, 
shuffle=True, + ) + + self.train_enum = enumerate(self.d_train) + self.test_enum = enumerate(self.d_test) + + self.train_size = len(self.d_train) + self.test_size = len(self.d_test) + + def reset_test(self): + self.test_enum = enumerate(self.d_test) + + def get_batch(self, batch_size=None, train=True): + if train: + _, (x, y) = next(self.train_enum, (None, (None, None))) + if x is None: + self.train_enum = enumerate(self.d_train) + _, (x, y) = next(self.train_enum) + else: + _, (x, y) = next(self.test_enum, (None, (None, None))) + if x is None: + self.test_enum = enumerate(self.d_test) + _, (x, y) = next(self.test_enum) + + if self.patch_size is not None: + x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size) + + x = x.to(device=self.device) + y = y.to(device=self.device) + + self._ind += 1 + + return x, y diff --git a/universal_computation/datasets/helpers/annotations.py b/universal_computation/datasets/helpers/annotations.py new file mode 100644 index 00000000..b1889e9f --- /dev/null +++ b/universal_computation/datasets/helpers/annotations.py @@ -0,0 +1,23 @@ +import pandas as pd +import os +import pathlib +from pathlib import Path + +data = 'data/2750' + +df = pd.DataFrame(columns=['label', 'int_label', 'img_name']) +labels_dict = {} +counter = 0 +for subdir in os.listdir(data): + labels_dict[subdir] = counter + filepath = os.path.join(data,subdir) + if os.path.isdir(filepath): + for file in os.listdir(filepath): + dict = {'label': subdir, 'int_label': labels_dict[subdir], 'img_name': file} + df = df.append(dict, ignore_index = True) + counter += 1 +train = df.sample(frac=0.75,random_state=200) #random state is a seed value +test = df.drop(train.index) + +train.to_csv(f'{data}/train.csv') +test.to_csv(f'{data}/test.csv') diff --git a/universal_computation/datasets/helpers/datasetops.py b/universal_computation/datasets/helpers/datasetops.py new file mode 100644 index 00000000..3e5dcca1 --- /dev/null +++ 
b/universal_computation/datasets/helpers/datasetops.py @@ -0,0 +1,36 @@ +import pandas as pd + +def read_annotations(ann_file): + site_name = pd.read_csv(ann_file, nrows=1, header=None)[0].tolist()[0].split('# Site: ')[1] + labels_dict = {} + with open(ann_file, 'r') as f: + start_reading = False + for line in f: + if start_reading: + if line[0] != '#': + break + else: + int_label, str_label = line[1:].split('. ') + int_label = int(int_label) + str_label = str_label.strip() + labels_dict[str_label] = int_label + if line == '# Categories:\n': + start_reading = True + + df = pd.read_csv(ann_file, comment='#') + df.set_index('timestamp', inplace=True) + df.sort_index(inplace=True) + + df['label'] = df['label'].astype('category') + df['int_label'] = [labels_dict[x] for x in df['label']] + + img_name_col = [] + for ts in df.index: + year = ts[:4] + month = ts[5:7] + day = ts[8:10] + hms = ts.split(' ')[1].replace(':', '') + img_name_col.append(f'{site_name}_{year}_{month}_{day}_{hms}.jpg') + df['img_name'] = img_name_col + + return df diff --git a/universal_computation/datasets/phenocam.py b/universal_computation/datasets/phenocam.py new file mode 100644 index 00000000..ccd4707c --- /dev/null +++ b/universal_computation/datasets/phenocam.py @@ -0,0 +1,103 @@ +import os + +from einops import rearrange +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +import torch +from torch.utils.data import DataLoader +from PIL import Image +import torchvision.transforms as transforms + +from universal_computation.datasets.dataset import Dataset +from universal_computation.datasets.helpers.datasetops import read_annotations + +class PhenoCamDatasetHelper(torch.utils.data.Dataset): + def __init__(self, img_dir, ann_file, transform=None, target_transform=None): + df = read_annotations(ann_file) + self.img_labels = df[['img_name', 'int_label']].reset_index(drop=True) + self.img_dir = img_dir + self.transform = transform + self.target_transform = 
target_transform + + def __len__(self): + return len(self.img_labels) + + def __getitem__(self, idx): + img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) + img = Image.open(img_path) + label = self.img_labels.iloc[idx, 1] + if self.transform: + img = self.transform(img) + if self.target_transform: + label = self.target_transform(label) + return img, label + + +class PhenoCamDataset(Dataset): + def __init__(self, batch_size, patch_size=None, data_aug=True, *args, **kwargs): + site = kwargs.pop('site', 'canadaojp') + super(PhenoCamDataset, self).__init__(*args, **kwargs) + + self.batch_size = batch_size + self.patch_size = patch_size + + if data_aug: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224,224), interpolation=3), + transforms.RandomApply([transforms.GaussianBlur(3)]), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + else: + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224,224), interpolation=3), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + + val_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224,224), interpolation=3), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + + train_dir = f'data/phenocam/{site}_train' + test_dir = f'data/phenocam/{site}_test' + self.d_train = DataLoader( + PhenoCamDatasetHelper(train_dir, os.path.join(train_dir, 'annotations.csv'), transform=transform), + batch_size=batch_size, drop_last=True, shuffle=True, + ) + self.d_test = DataLoader( + PhenoCamDatasetHelper(test_dir, os.path.join(test_dir, 'annotations.csv'), transform=val_transform), + batch_size=batch_size, drop_last=True, shuffle=True, + ) + + self.train_enum = enumerate(self.d_train) + self.test_enum = enumerate(self.d_test) + + self.train_size = len(self.d_train) + self.test_size = len(self.d_test) + + def reset_test(self): + self.test_enum = enumerate(self.d_test) + + 
def get_batch(self, batch_size=None, train=True): + if train: + _, (x, y) = next(self.train_enum, (None, (None, None))) + if x is None: + self.train_enum = enumerate(self.d_train) + _, (x, y) = next(self.train_enum) + else: + _, (x, y) = next(self.test_enum, (None, (None, None))) + if x is None: + self.test_enum = enumerate(self.d_test) + _, (x, y) = next(self.test_enum) + + if self.patch_size is not None: + x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size) + + x = x.to(device=self.device) + y = y.to(device=self.device) + + self._ind += 1 + + return x, y diff --git a/universal_computation/experiment.py b/universal_computation/experiment.py index cadd3881..875273a9 100644 --- a/universal_computation/experiment.py +++ b/universal_computation/experiment.py @@ -87,6 +87,19 @@ def experiment( input_dim, output_dim = 30, 1200 use_embeddings = True experiment_type = 'classification' + elif task == 'eurosat': + from universal_computation.datasets.eurosat import EuroSatDataset + dataset = EuroSatDataset(batch_size=batch_size, patch_size=patch_size, device=device) + input_dim, output_dim = 3 * patch_size**2, 10 + use_embeddings = False + experiment_type = 'classification' + + elif task.split('-')[0] == 'phenocam': + from universal_computation.datasets.phenocam import PhenoCamDataset + dataset = PhenoCamDataset(batch_size=batch_size, patch_size=patch_size, device=device, site=task.split('-')[1]) + input_dim, output_dim = 3 * patch_size**2, 3 + use_embeddings = False + experiment_type = 'classification' else: raise NotImplementedError('dataset not implemented') diff --git a/universal_computation/models/__init__.py b/universal_computation/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/universal_computation/models/lstm.py b/universal_computation/models/lstm.py deleted file mode 100644 index a5f9712e..00000000 --- a/universal_computation/models/lstm.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -MIT 
License - -Copyright (c) 2018 Alex - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- -Code modified from Github repo: https://github.com/exe1023/LSTM_LN -""" - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter -from torch.autograd import Variable - - -use_cuda = torch.cuda.is_available() - - -class LNLSTM(nn.Module): - def __init__(self, - input_size, - hidden_size, - num_layers=1, - dropout=0., - bidirectional=1, - batch_first=False, - residual=False, - cln=False): - super(LNLSTM, self).__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.num_layers = num_layers - self.direction = bidirectional + 1 - self.batch_first = batch_first - self.residual = residual - - layers = [] - for i in range(num_layers): - for j in range(self.direction): - layer = LayerNormLSTM(input_size*self.direction, - hidden_size, - dropout=dropout, - cln=cln) - layers.append(layer) - input_size = hidden_size - self.layers = layers - self.params = nn.ModuleList(layers) - - def reset_parameters(self): - for l in self.layers: - l.reset_parameters() - - def init_hidden(self, batch_size): - # Uses Xavier init here. 
- hiddens = [] - for l in self.layers: - std = math.sqrt(2.0 / (l.input_size + l.hidden_size)) - h = Variable(Tensor(1, batch_size, l.hidden_size).normal_(0, std)) - c = Variable(Tensor(1, batch_size, l.hidden_size).normal_(0, std)) - if use_cuda: - hiddens.append((h.cuda(), c.cuda())) - else: - hiddens.append((h, c)) - return hiddens - - def layer_forward(self, l, xs, h, image_emb, reverse=False): - ''' - return: - xs: (seq_len, batch, hidden) - h: (1, batch, hidden) - ''' - if self.batch_first: - xs = xs.permute(1, 0, 2).contiguous() - ys = [] - for i in range(xs.size(0)): - if reverse: - x = xs.narrow(0, (xs.size(0)-1)-i, 1) - else: - x = xs.narrow(0, i, 1) - y, h = l(x, h, image_emb) - ys.append(y) - y = torch.cat(ys, 0) - if self.batch_first: - y = y.permute(1, 0, 2) - return y, h - - def forward(self, x, hiddens=None, image_emb=None): - if hiddens is None: - hiddens = self.init_hidden(x.shape[0]) - if self.direction > 1: - x = torch.cat((x, x), 2) - if type(hiddens) != list: - # when the hidden feed is (direction * num_layer, batch, hidden) - tmp = [] - for idx in range(hiddens[0].size(0)): - tmp.append((hiddens[0].narrow(0, idx, 1), - (hiddens[1].narrow(0, idx, 1)))) - hiddens = tmp - - new_hs = [] - new_cs = [] - for l_idx in range(0, len(self.layers), self.direction): - l, h = self.layers[l_idx], hiddens[l_idx] - f_x, f_h = self.layer_forward(l, x, h, image_emb) - if self.direction > 1: - l, h = self.layers[l_idx+1], hiddens[l_idx+1] - r_x, r_h = self.layer_forward(l, x, h, image_emb, reverse=True) - - x = torch.cat((f_x, r_x), 2) - h = torch.cat((f_h[0], r_h[0]), 0) - c = torch.cat((f_h[1], r_h[1]), 0) - else: - if self.residual: - x = x + f_x - else: - x = f_x - h, c = f_h - new_hs.append(h) - new_cs.append(c) - - h = torch.cat(new_hs, 0) - c = torch.cat(new_cs, 0) - - return x, (h, c) - - -class CLN(nn.Module): - """ - Conditioned Layer Normalization - """ - def __init__(self, input_size, image_size, epsilon=1e-6): - super(CLN, self).__init__() - 
self.input_size = input_size - self.image_size = image_size - self.alpha = Tensor(1, input_size).fill_(1) - self.beta = Tensor(1, input_size).fill_(0) - self.epsilon = epsilon - - self.alpha = Parameter(self.alpha) - self.beta = Parameter(self.beta) - - # MLP used to predict delta of alpha, beta - self.fc_alpha = nn.Linear(self.image_size, self.input_size) - self.fc_beta = nn.Linear(self.image_size, self.input_size) - - self.reset_parameters() - - def reset_parameters(self): - std = 1.0 / math.sqrt(self.input_size) - for w in self.parameters(): - w.data.uniform_(-std, std) - - def create_cln_input(self, image_emb): - delta_alpha = self.fc_alpha(image_emb) - delta_beta = self.fc_beta(image_emb) - return delta_alpha, delta_beta - - def forward(self, x, image_emb): - if image_emb is None: - return x - # x: (batch, input_size) - size = x.size() - x = x.view(x.size(0), -1) - x = (x - torch.mean(x, 1).unsqueeze(1).expand_as(x)) / torch.sqrt(torch.var(x, 1).unsqueeze(1).expand_as(x) + self.epsilon) - - delta_alpha, delta_beta = self.create_cln_input(image_emb) - alpha = self.alpha.expand_as(x) + delta_alpha - beta = self.beta.expand_as(x) + delta_beta - x = alpha * x + beta - return x.view(size) - - -class LayerNorm(nn.Module): - """ - Layer Normalization based on Ba & al.: - 'Layer Normalization' - https://arxiv.org/pdf/1607.06450.pdf - """ - - def __init__(self, input_size, learnable=True, epsilon=1e-6): - super(LayerNorm, self).__init__() - self.input_size = input_size - self.learnable = learnable - self.alpha = Tensor(1, input_size).fill_(1) - self.beta = Tensor(1, input_size).fill_(0) - self.epsilon = epsilon - # Wrap as parameters if necessary - if learnable: - W = Parameter - else: - W = Variable - self.alpha = W(self.alpha) - self.beta = W(self.beta) - self.reset_parameters() - - def reset_parameters(self): - std = 1.0 / math.sqrt(self.input_size) - for w in self.parameters(): - w.data.uniform_(-std, std) - - def forward(self, x): - size = x.size() - x = 
x.view(x.size(0), -1) - x = (x - torch.mean(x, 1).unsqueeze(1).expand_as(x)) / torch.sqrt(torch.var(x, 1).unsqueeze(1).expand_as(x) + self.epsilon) - if self.learnable: - x = self.alpha.expand_as(x) * x + self.beta.expand_as(x) - return x.view(size) - - -class LSTMcell(nn.Module): - - """ - An implementation of Hochreiter & Schmidhuber: - 'Long-Short Term Memory' - http://www.bioinf.jku.at/publications/older/2604.pdf - Special args: - dropout_method: one of - * pytorch: default dropout implementation - * gal: uses GalLSTM's dropout - * moon: uses MoonLSTM's dropout - * semeniuta: uses SemeniutaLSTM's dropout - """ - - def __init__(self, input_size, hidden_size, bias=True, dropout=0.0, dropout_method='pytorch'): - super(LSTMcell, self).__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.bias = bias - self.dropout = dropout - self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) - self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) - self.reset_parameters() - assert(dropout_method.lower() in ['pytorch', 'gal', 'moon', 'semeniuta']) - self.dropout_method = dropout_method - - def sample_mask(self): - keep = 1.0 - self.dropout - self.mask = Variable(torch.bernoulli(Tensor(1, self.hidden_size).fill_(keep))) - - def reset_parameters(self): - std = 1.0 / math.sqrt(self.hidden_size) - for w in self.parameters(): - w.data.uniform_(-std, std) - - def forward(self, x, hidden): - do_dropout = self.training and self.dropout > 0.0 - h, c = hidden - h = h.view(h.size(1), -1) - c = c.view(c.size(1), -1) - x = x.view(x.size(1), -1) - - # Linear mappings - preact = self.i2h(x) + self.h2h(h) - - # activations - gates = preact[:, :3 * self.hidden_size].sigmoid() - g_t = preact[:, 3 * self.hidden_size:].tanh() - i_t = gates[:, :self.hidden_size] - f_t = gates[:, self.hidden_size:2 * self.hidden_size] - o_t = gates[:, -self.hidden_size:] - - # cell computations - if do_dropout and self.dropout_method == 'semeniuta': - g_t = F.dropout(g_t, 
p=self.dropout, training=self.training) - - c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t) - - if do_dropout and self.dropout_method == 'moon': - c_t.data.set_(torch.mul(c_t, self.mask).data) - c_t.data *= 1.0/(1.0 - self.dropout) - - h_t = torch.mul(o_t, c_t.tanh()) - - # Reshape for compatibility - if do_dropout: - if self.dropout_method == 'pytorch': - F.dropout(h_t, p=self.dropout, training=self.training, inplace=True) - if self.dropout_method == 'gal': - h_t.data.set_(th.mul(h_t, self.mask).data) - h_t.data *= 1.0/(1.0 - self.dropout) - - h_t = h_t.view(1, h_t.size(0), -1) - c_t = c_t.view(1, c_t.size(0), -1) - return h_t, (h_t, c_t) - - -class LayerNormLSTM(LSTMcell): - - """ - Layer Normalization LSTM, based on Ba & al.: - 'Layer Normalization' - https://arxiv.org/pdf/1607.06450.pdf - Special args: - ln_preact: whether to Layer Normalize the pre-activations. - learnable: whether the LN alpha & gamma should be used. - """ - - def __init__(self, - input_size, - hidden_size, - bias=True, - dropout=0.0, - dropout_method='pytorch', - ln_preact=True, - learnable=True, - cln=True): - super(LayerNormLSTM, self).__init__(input_size=input_size, - hidden_size=hidden_size, - bias=bias, - dropout=dropout, - dropout_method=dropout_method) - self.cln = cln - if ln_preact: - if self.cln: - self.ln_i2h = CLN(4*hidden_size, 1024) - self.ln_h2h = CLN(4*hidden_size, 1024) - else: - self.ln_h2h = LayerNorm(4*hidden_size, learnable=learnable) - self.ln_i2h = LayerNorm(4*hidden_size, learnable=learnable) - self.ln_preact = ln_preact - if self.cln: - self.ln_cell = CLN(hidden_size, 1024) - else: - self.ln_cell = LayerNorm(hidden_size, learnable=learnable) - - def forward(self, x, hidden, image_emb=None): - do_dropout = self.training and self.dropout > 0.0 - h, c = hidden - h = h.view(h.size(1), -1) - c = c.view(c.size(1), -1) - x = x.view(x.size(1), -1) - - # Linear mappings - i2h = self.i2h(x) - h2h = self.h2h(h) - if self.ln_preact: - if self.cln: - i2h = self.ln_i2h(i2h, 
image_emb) - h2h = self.ln_h2h(h2h, image_emb) - else: - i2h = self.ln_i2h(i2h) - h2h = self.ln_h2h(h2h) - preact = i2h + h2h - - # activations - gates = preact[:, :3 * self.hidden_size].sigmoid() - g_t = preact[:, 3 * self.hidden_size:].tanh() - i_t = gates[:, :self.hidden_size] - f_t = gates[:, self.hidden_size:2 * self.hidden_size] - o_t = gates[:, -self.hidden_size:] - - # cell computations - if do_dropout and self.dropout_method == 'semeniuta': - g_t = F.dropout(g_t, p=self.dropout, training=self.training) - - c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t) - - if do_dropout and self.dropout_method == 'moon': - c_t.data.set_(torch.mul(c_t, self.mask).data) - c_t.data *= 1.0/(1.0 - self.dropout) - - if self.cln: - c_t = self.ln_cell(c_t, image_emb) - else: - c_t = self.ln_cell(c_t) - h_t = torch.mul(o_t, c_t.tanh()) - - # Reshape for compatibility - if do_dropout: - if self.dropout_method == 'pytorch': - F.dropout(h_t, p=self.dropout, training=self.training, inplace=True) - if self.dropout_method == 'gal': - h_t.data.set_(torch.mul(h_t, self.mask).data) - h_t.data *= 1.0/(1.0 - self.dropout) - - h_t = h_t.view(1, h_t.size(0), -1) - c_t = c_t.view(1, c_t.size(0), -1) - return h_t, (h_t, c_t)