diff --git a/configs/selfsup/_base_/datasets/gRNA/K562_pretrain.py b/configs/selfsup/_base_/datasets/gRNA/K562_pretrain.py new file mode 100644 index 0000000..4baf5bf --- /dev/null +++ b/configs/selfsup/_base_/datasets/gRNA/K562_pretrain.py @@ -0,0 +1,37 @@ +# dataset settings +data_root = 'data/on_target_K562/train/' +data_source_cfg = dict( + type='BioSeqDataset', + file_list=None, # use all splits + word_splitor="", data_splitor="\t", mapping_name="ACGT", # gRNA tokenize + has_labels=True, return_label=False, # pre-training + max_data_length=int(1e7), + data_type="regression", +) + +dataset_type = 'ExtractDataset' +sample_norm_cfg = dict(mean=[0,], std=[1,]) +train_pipeline = [ + dict(type='ToTensor'), +] +test_pipeline = [ + dict(type='ToTensor'), +] +# prefetch +prefetch = False + +data = dict( + samples_per_gpu=256, + workers_per_gpu=4, + drop_last=True, + train=dict( + type=dataset_type, + data_source=dict( + root=data_root, **data_source_cfg), + pipeline=train_pipeline, + prefetch=prefetch, + ), +) + +# checkpoint +checkpoint_config = dict(interval=200, max_keep_ckpts=1) diff --git a/configs/selfsup/_base_/datasets/gRNA/gRNA_pretrain.py b/configs/selfsup/_base_/datasets/gRNA/gRNA_pretrain.py index 8eb635d..389db5c 100644 --- a/configs/selfsup/_base_/datasets/gRNA/gRNA_pretrain.py +++ b/configs/selfsup/_base_/datasets/gRNA/gRNA_pretrain.py @@ -1,13 +1,15 @@ # dataset settings -data_root = 'data/on_target_K562/' +data_root = 'data/gRNA_pretrain/' data_source_cfg = dict( type='BioSeqDataset', file_list=None, # use all splits word_splitor="", data_splitor="\t", mapping_name="ACGT", # gRNA tokenize + has_labels=False, return_label=False, # pre-training + max_data_length=int(1e7), data_type="regression", ) -dataset_type = 'RegressionDataset' +dataset_type = 'ExtractDataset' sample_norm_cfg = dict(mean=[0,], std=[1,]) train_pipeline = [ dict(type='ToTensor'), @@ -25,8 +27,7 @@ train=dict( type=dataset_type, data_source=dict( - root=data_root+"train", - **data_source_cfg), + root=data_root, **data_source_cfg), pipeline=train_pipeline, prefetch=prefetch, ), diff --git a/configs/selfsup/gRNA/transformer/bert/layer4_spin_p2_h4_d64_init_bs256.py b/configs/selfsup/gRNA/transformer/bert/layer4_spin_p2_h4_d64_init_bs256.py new file mode 100644 index 0000000..92d29bd --- /dev/null +++ b/configs/selfsup/gRNA/transformer/bert/layer4_spin_p2_h4_d64_init_bs256.py @@ -0,0 +1,90 @@ +_base_ = [ + '../../../_base_/datasets/gRNA/gRNA_pretrain.py', + '../../../_base_/default_runtime.py', +] + +embed_dim = 64 +patch_size = 2 +seq_len = 63 + +# model settings +model = dict( + type='BERT', + pretrained=None, + mask_ratio=0.15, # BERT 15% + spin_stride=[1, 2, 4], + backbone=dict( + type='SimMIMTransformer', + arch=dict( + embed_dims=embed_dim, + num_layers=4, + num_heads=4, + feedforward_channels=embed_dim * 4, + ), + in_channels=4, + patch_size=patch_size, + seq_len=int(seq_len / patch_size) + bool(seq_len % patch_size != 0), + mask_layer=0, + mask_ratio=0.15, # BERT 15% + mask_token='learnable', + # mask_token='zero', + norm_cfg=dict(type='LN', eps=1e-6), + drop_rate=0., # no dropout for pre-training + drop_path_rate=0.1, + final_norm=True, + out_indices=-1, # last layer + with_cls_token=True, + output_cls_token=True, + ), + neck=dict( + type='BERTMLMNeck', feature_Nd="1d", + in_channels=embed_dim, out_channels=4, encoder_stride=patch_size), + head=dict( + type='MIMHead', + loss=dict(type='CrossEntropyLoss', + use_soft=True, use_sigmoid=False, loss_weight=1.0), + feature_Nd="1d", unmask_weight=0., 
encoder_in_channels=4, + ), + init_cfg=[ + dict(type='TruncNormal', layer=['Conv1d', 'Linear'], std=0.02, bias=0.), + dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.) + ], +) + +# dataset +data = dict(samples_per_gpu=256, workers_per_gpu=4) + +# optimizer +optimizer = dict( + type='AdamW', + lr=1e-3, + weight_decay=1e-2, eps=1e-8, betas=(0.9, 0.999), + paramwise_options={ + '(bn|ln|gn)(\d+)?.(weight|bias)': dict(weight_decay=0.), + 'norm': dict(weight_decay=0.), + 'bias': dict(weight_decay=0.), + 'cls_token': dict(weight_decay=0.), + 'pos_embed': dict(weight_decay=0.), + 'mask_token': dict(weight_decay=0.), + }) + +# apex +use_fp16 = False +fp16 = dict(type='mmcv', loss_scale=dict(mode='dynamic')) +optimizer_config = dict( + grad_clip=dict(max_norm=1000.0), update_interval=1) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + by_epoch=False, min_lr=1e-5, + warmup='linear', + warmup_iters=5, warmup_by_epoch=True, + warmup_ratio=1e-5, +) + +# checkpoint +checkpoint_config = dict(interval=200, max_keep_ckpts=1) + +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=100) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 1434aaf..ce6eb4e 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -4,10 +4,11 @@ #### Highlight * Support various popular backbones (ConvNets and ViTs), various image datasets, popular mixup methods, and benchmarks for supervised learning. Config files are available. -* Support popular self-supervised methods (e.g., BYOL, MoCo.V3, MAE) on both large-scale and small-scale datasets, and self-supervised benchmarks (merged from MMSelfSup). Config files are available. +* Support popular self-supervised methods (e.g., BYOL, MoCo.V3, MAE) on both large-scale and small-scale datasets, and self-supervised benchmarks (merged from MMSelfSup). Config files are available. Support BERT pre-training method and update config files. * Support analyzing tools for self-supervised learning (kNN/SVM/linear metrics and t-SNE/UMAP visualization). * Convenient usage of configs: fast configs generation by 'auto_train.py' and configs inheriting (MMCV). * Support mixed-precision training (NVIDIA Apex or MMCV Apex). +* Refactor `openbioseq.core` and support Adan optimizer. #### Bug Fixes * Done code refactoring follows MMSelfSup and MMClassification. diff --git a/openbioseq/core/optimizer/__init__.py b/openbioseq/core/optimizer/__init__.py index b4e8752..e9c5778 100644 --- a/openbioseq/core/optimizer/__init__.py +++ b/openbioseq/core/optimizer/__init__.py @@ -1,8 +1,10 @@ +from .adan import Adan from .builder import build_optimizer from .constructor import DefaultOptimizerConstructor, TransformerFinetuneConstructor -from .optimizers import LARS, LAMB +from .lamb import LAMB +from .lars import LARS __all__ = [ - 'LARS', 'LAMB', 'build_optimizer', + 'Adan', 'LARS', 'LAMB', 'build_optimizer', 'DefaultOptimizerConstructor', 'TransformerFinetuneConstructor' ] diff --git a/openbioseq/core/optimizer/adan.py b/openbioseq/core/optimizer/adan.py new file mode 100644 index 0000000..cd38e5a --- /dev/null +++ b/openbioseq/core/optimizer/adan.py @@ -0,0 +1,312 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List
+
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+from mmcv.runner.optimizer.builder import OPTIMIZERS
+
+
+@OPTIMIZERS.register_module()
+class Adan(Optimizer):
+    """Implements a PyTorch variant of Adan.
+
+    Adan was proposed in
+    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.
+    https://arxiv.org/abs/2208.06677
+    Arguments:
+        params (iterable): iterable of parameters to optimize
+            or dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float, float], optional): coefficients used
+            for computing running averages of gradient.
+            (default: (0.98, 0.92, 0.99))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (float, optional): decoupled weight decay
+            (L2 penalty) (default: 0)
+        max_grad_norm (float, optional): value used to clip
+            global grad norm (default: 0.0 no clip)
+        no_prox (bool): how to perform the decoupled weight decay
+            (default: False)
+        foreach (bool): if True would use torch._foreach implementation.
+            It's faster but uses slightly more memory.
+    """
+
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 betas=(0.98, 0.92, 0.99),
+                 eps=1e-8,
+                 weight_decay=0.0,
+                 max_grad_norm=0.0,
+                 no_prox=False,
+                 foreach: bool = True):
+        if not 0.0 <= max_grad_norm:
+            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
+        if not 0.0 <= lr:
+            raise ValueError('Invalid learning rate: {}'.format(lr))
+        if not 0.0 <= eps:
+            raise ValueError('Invalid epsilon value: {}'.format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError('Invalid beta parameter at index 0: {}'.format(
+                betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError('Invalid beta parameter at index 1: {}'.format(
+                betas[1]))
+        if not 0.0 <= betas[2] < 1.0:
+            raise ValueError('Invalid beta parameter at index 2: {}'.format(
+                betas[2]))
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            max_grad_norm=max_grad_norm,
+            no_prox=no_prox,
+            foreach=foreach)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Adan, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('no_prox', False)
+
+    @torch.no_grad()
+    def restart_opt(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                if p.requires_grad:
+                    state = self.state[p]
+                    # State initialization
+
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+                    # Exponential moving average of gradient difference
+                    state['exp_avg_diff'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self):
+        """Performs a single optimization step."""
+        if self.defaults['max_grad_norm'] > 0:
+            device = self.param_groups[0]['params'][0].device
+            global_grad_norm = torch.zeros(1, device=device)
+
+            max_grad_norm = torch.tensor(
+                self.defaults['max_grad_norm'], device=device)
+            for group in
self.param_groups: + + for p in group['params']: + if p.grad is not None: + grad = p.grad + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + group['eps'] + + clip_global_grad_norm = \ + torch.clamp(max_grad_norm / global_grad_norm, max=1.0) + else: + clip_global_grad_norm = 1.0 + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_diffs = [] + pre_grads = [] + + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support + # by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1**group['step'] + bias_correction2 = 1.0 - beta2**group['step'] + bias_correction3 = 1.0 - beta3**group['step'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + + if 'pre_grad' not in state or group['step'] == 1: + # at first step grad wouldn't be clipped + # by `clip_global_grad_norm` + # this is only to simplify implementation + state['pre_grad'] = p.grad + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + exp_avg_diffs.append(state['exp_avg_diff']) + pre_grads.append(state['pre_grad']) + + kwargs = dict( + params=params_with_grad, + grads=grads, + exp_avgs=exp_avgs, + exp_avg_sqs=exp_avg_sqs, + exp_avg_diffs=exp_avg_diffs, + pre_grads=pre_grads, + beta1=beta1, + beta2=beta2, + beta3=beta3, + bias_correction1=bias_correction1, + bias_correction2=bias_correction2, + bias_correction3_sqrt=math.sqrt(bias_correction3), + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + no_prox=group['no_prox'], + clip_global_grad_norm=clip_global_grad_norm, + ) + if group['foreach']: + copy_grads = _multi_tensor_adan(**kwargs) + else: + copy_grads = _single_tensor_adan(**kwargs) + + for p, copy_grad in zip(params_with_grad, copy_grads): + self.state[p]['pre_grad'] = copy_grad + + +def _single_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + copy_grads = [] + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_diff = exp_avg_diffs[i] + pre_grad = pre_grads[i] + + grad = grad.mul_(clip_global_grad_norm) + copy_grads.append(grad.clone()) + + diff = grad - pre_grad + update = grad + beta2 * diff + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t + exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t + exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t + + denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps) + update = exp_avg / bias_correction1 + update.add_(beta2 * exp_avg_diff / bias_correction2).div_(denom) + + if no_prox: + param.mul_(1 - lr * weight_decay) + param.add_(update, alpha=-lr) + else: + param.add_(update, alpha=-lr) + param.div_(1 + lr * weight_decay) + return copy_grads + + 
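Since `Adan` is registered with `@OPTIMIZERS.register_module()` and exported from `openbioseq.core.optimizer`, it can be selected from a config dict in the same way as the `AdamW` setting in the BERT config above. A minimal sketch with placeholder hyper-parameters (illustrative values only, not tuned for any of the gRNA experiments):

```python
# Hypothetical optimizer config built on the constructor signature above;
# every numeric value here is a placeholder, not a recommended setting.
optimizer = dict(
    type='Adan',                 # resolved through the mmcv OPTIMIZERS registry
    lr=2e-3,
    betas=(0.98, 0.92, 0.99),    # (beta1, beta2, beta3), matching the defaults
    eps=1e-8,
    weight_decay=0.02,
    max_grad_norm=0.0,           # 0.0 disables the optimizer's own global grad clipping
    no_prox=False,               # keep the default handling of decoupled weight decay
    foreach=True)                # use the torch._foreach path (_multi_tensor_adan below)
```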
+def _multi_tensor_adan(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    exp_avg_diffs: List[Tensor],
+    pre_grads: List[Tensor],
+    *,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    bias_correction1: float,
+    bias_correction2: float,
+    bias_correction3_sqrt: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    no_prox: bool,
+    clip_global_grad_norm: Tensor,
+):
+    if clip_global_grad_norm < 1.0:
+        torch._foreach_mul_(grads, clip_global_grad_norm.item())
+    copy_grads = [g.clone() for g in grads]
+
+    diff = torch._foreach_sub(grads, pre_grads)
+    # NOTE: line below while looking identical gives different result,
+    # due to float precision errors.
+    # using mul+add produces identical results to single-tensor,
+    # using add+alpha doesn't
+    # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2))
+    update = torch._foreach_add(grads, diff, alpha=beta2)
+
+    torch._foreach_mul_(exp_avgs, beta1)
+    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m_t
+
+    torch._foreach_mul_(exp_avg_diffs, beta2)
+    torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2)  # diff_t
+
+    torch._foreach_mul_(exp_avg_sqs, beta3)
+    torch._foreach_addcmul_(
+        exp_avg_sqs, update, update, value=1 - beta3)  # n_t
+
+    denom = torch._foreach_sqrt(exp_avg_sqs)
+    torch._foreach_div_(denom, bias_correction3_sqrt)
+    torch._foreach_add_(denom, eps)
+
+    update = torch._foreach_div(exp_avgs, bias_correction1)
+    # NOTE: same issue as above.
+    # beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) # noqa
+    # using faster version by default. uncomment for tests to pass
+    # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) # noqa
+    torch._foreach_add_(
+        update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2))
+    torch._foreach_div_(update, denom)
+
+    if no_prox:
+        torch._foreach_mul_(params, 1 - lr * weight_decay)
+        torch._foreach_add_(params, update, alpha=-lr)  # apply update, as in the single-tensor path
+    else:
+        torch._foreach_add_(params, update, alpha=-lr)
+        torch._foreach_div_(params, 1 + lr * weight_decay)
+    return copy_grads
diff --git a/openbioseq/core/optimizer/optimizers.py b/openbioseq/core/optimizer/lamb.py
similarity index 64%
rename from openbioseq/core/optimizer/optimizers.py
rename to openbioseq/core/optimizer/lamb.py
index 0134f20..906c4df 100644
--- a/openbioseq/core/optimizer/optimizers.py
+++ b/openbioseq/core/optimizer/lamb.py
@@ -1,123 +1,46 @@
+# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
+
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + import math import torch from mmcv.runner.optimizer.builder import OPTIMIZERS -from torch.optim.optimizer import Optimizer, required -from torch.optim import * - - -@OPTIMIZERS.register_module() -class LARS(Optimizer): - r"""Implements layer-wise adaptive rate scaling for SGD. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float): base learning rate (\gamma_0) - momentum (float, optional): momentum factor (default: 0) ("m") - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - ("\beta") - dampening (float, optional): dampening for momentum (default: 0) - eta (float, optional): LARS coefficient - nesterov (bool, optional): enables Nesterov momentum (default: False) - - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. - Large Batch Training of Convolutional Networks: - https://arxiv.org/abs/1708.03888 - - Example: - >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, - >>> weight_decay=1e-4, eta=1e-3) - >>> optimizer.zero_grad() - >>> loss_fn(model(input), target).backward() - >>> optimizer.step() - """ - - def __init__(self, - params, - lr=required, - momentum=0, - dampening=0, - weight_decay=0, - eta=0.001, - nesterov=False): - if lr is not required and lr < 0.0: - raise ValueError("Invalid learning rate: {}".format(lr)) - if momentum < 0.0: - raise ValueError("Invalid momentum value: {}".format(momentum)) - if weight_decay < 0.0: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay)) - if eta < 0.0: - raise ValueError("Invalid LARS coefficient value: {}".format(eta)) - - defaults = dict( - lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, eta=eta) - if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") - - super(LARS, self).__init__(params, defaults) - - def __setstate__(self, state): - super(LARS, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('nesterov', False) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - eta = group['eta'] - nesterov = group['nesterov'] - lr = group['lr'] - lars_exclude = group.get('lars_exclude', False) - - for p in group['params']: - if p.grad is None: - continue - - d_p = p.grad - - if lars_exclude: - local_lr = 1. - else: - weight_norm = torch.norm(p).item() - grad_norm = torch.norm(d_p).item() - # Compute local learning rate for this layer - local_lr = eta * weight_norm / \ - (grad_norm + weight_decay * weight_norm) - - actual_lr = local_lr * lr - d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) - if momentum != 0: - param_state = self.state[p] - if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = \ - torch.clone(d_p).detach() - else: - buf = param_state['momentum_buffer'] - buf.mul_(momentum).add_(d_p, alpha=1 - dampening) - if nesterov: - d_p = d_p.add(buf, alpha=momentum) - else: - d_p = buf - p.add_(-d_p) - - return loss +from torch.optim.optimizer import Optimizer @OPTIMIZERS.register_module() diff --git a/openbioseq/core/optimizer/lars.py b/openbioseq/core/optimizer/lars.py new file mode 100644 index 0000000..4794520 --- /dev/null +++ b/openbioseq/core/optimizer/lars.py @@ -0,0 +1,118 @@ +import torch +from mmcv.runner.optimizer.builder import OPTIMIZERS +from torch.optim.optimizer import Optimizer, required +from torch.optim import * + + +@OPTIMIZERS.register_module() +class LARS(Optimizer): + r"""Implements layer-wise adaptive rate scaling for SGD. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): base learning rate (\gamma_0) + momentum (float, optional): momentum factor (default: 0) ("m") + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + ("\beta") + dampening (float, optional): dampening for momentum (default: 0) + eta (float, optional): LARS coefficient + nesterov (bool, optional): enables Nesterov momentum (default: False) + + Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
+ Large Batch Training of Convolutional Networks: + https://arxiv.org/abs/1708.03888 + + Example: + >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, + >>> weight_decay=1e-4, eta=1e-3) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + """ + + def __init__(self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + eta=0.001, + nesterov=False): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay)) + if eta < 0.0: + raise ValueError("Invalid LARS coefficient value: {}".format(eta)) + + defaults = dict( + lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, eta=eta) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + super(LARS, self).__init__(params, defaults) + + def __setstate__(self, state): + super(LARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + eta = group['eta'] + nesterov = group['nesterov'] + lr = group['lr'] + lars_exclude = group.get('lars_exclude', False) + + for p in group['params']: + if p.grad is None: + continue + + d_p = p.grad + + if lars_exclude: + local_lr = 1. 
+ else: + weight_norm = torch.norm(p).item() + grad_norm = torch.norm(d_p).item() + # Compute local learning rate for this layer + local_lr = eta * weight_norm / \ + (grad_norm + weight_decay * weight_norm) + + actual_lr = local_lr * lr + d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = \ + torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + p.add_(-d_p) + + return loss diff --git a/openbioseq/datasets/data_sources/bio_seq_source.py b/openbioseq/datasets/data_sources/bio_seq_source.py index 9b3179a..a0e73ad 100644 --- a/openbioseq/datasets/data_sources/bio_seq_source.py +++ b/openbioseq/datasets/data_sources/bio_seq_source.py @@ -3,21 +3,26 @@ from torch.utils.data import Dataset, DataLoader from tqdm import tqdm +from openbioseq.utils import print_log from ..registry import DATASOURCES from ..utils import read_file -def binarize(data_list, mapping): +def binarize(data_list, mapping, max_seq_length=None, data_splitor=None): assert isinstance(data_list, list) and len(data_list) > 0 token_list = list() num_entries = max(mapping.values()) + 1 for _seq in data_list: + if data_splitor is not None: + _seq = _seq.split(data_splitor)[-1] try: + seq_len = min(len(_seq), max_seq_length) onehot_seq = torch.zeros( - num_entries, (len(_seq)), dtype=torch.float32) + num_entries, (seq_len), dtype=torch.float32) for _idx, _str in enumerate(_seq): - map_idx = mapping[_str] - onehot_seq[map_idx, _idx] = 1 + if _idx < seq_len: + map_idx = mapping[_str] + onehot_seq[map_idx, _idx] = 1 token_list.append(onehot_seq) except: print(f"Error seq:", _seq) @@ -27,19 +32,26 @@ def binarize(data_list, mapping): class TokenizeDataset(Dataset): """ Tokenize string to binary encoding """ - def __init__(self, data, mapping): + def __init__(self, data, mapping, max_seq_length=None, data_splitor=None): super().__init__() self.data = data self.mapping = mapping + self.max_seq_length = max_seq_length if max_seq_length is not None else int(1e10) + self.data_splitor = data_splitor self.num_entries = max(mapping.values()) + 1 - + def __getitem__(self, idx): _seq = self.data[idx] - onehot_seq = torch.zeros(self.num_entries, (len(_seq)), dtype=torch.float32) + if self.data_splitor is not None: + _seq = _seq.split(self.data_splitor)[-1] + seq_len = min(len(_seq), self.max_seq_length) + onehot_seq = torch.zeros( + self.num_entries, (seq_len), dtype=torch.float32) try: for _idx, _str in enumerate(_seq): - map_idx = self.mapping[_str] - onehot_seq[map_idx, _idx] = 1 + if _idx < seq_len: + map_idx = self.mapping[_str] + onehot_seq[map_idx, _idx] = 1 except: print(f"Error seq:", _seq) return onehot_seq @@ -74,7 +86,11 @@ def __init__(self, word_splitor="", data_splitor=" ", mapping_name="ACGT", - return_label=True, data_type="classification"): + has_labels=True, + return_label=True, + data_type="classification", + max_seq_length=None, + max_data_length=None): assert file_list is None or isinstance(file_list, list) assert word_splitor in ["", " ", ",", ";", ".",] assert data_splitor in [" ", ",", ";", ".", "\t",] @@ -89,9 +105,10 @@ def __init__(self, lines = list() for file in file_list: lines += read_file(os.path.join(root, file)) - self.has_labels = len(lines[0].split(data_splitor)) >= 2 + self.has_labels = len(lines[0].split(data_splitor)) >= 2 and 
has_labels self.return_label = return_label self.data_type = data_type + print_log("Total file length: {}".format(len(lines)), logger='root') # preprocess if self.has_labels: @@ -105,17 +122,24 @@ def __init__(self, else: # assert self.return_label is False self.labels = None - data = [l.strip()[-1:] for l in lines] + data = [l.strip() for l in lines] + if max_data_length is not None: + assert isinstance(max_data_length, (int, float)) + data = data[:max_data_length] + print_log("Used data length: {}".format(len(data)), logger='root') mapping = getattr(self, mapping_name) max_file_len = 100000 num_workers = max(6, int(max_file_len / max_file_len + 1)) + data_splitor = None if self.has_labels else data_splitor if num_workers <= 1: - self.data = binarize(data, mapping) + self.data = binarize( + data, mapping, max_seq_length=max_seq_length, data_splitor=data_splitor) else: tokens = None - tokenizer = TokenizeDataset(data, mapping) + tokenizer = TokenizeDataset( + data, mapping, max_seq_length=max_seq_length, data_splitor=data_splitor) process_loader = DataLoader(tokenizer, batch_size=num_workers * 1000, shuffle=False, num_workers=num_workers) for i, _tokens in tqdm(enumerate(process_loader)): diff --git a/openbioseq/models/selfsup/bert.py b/openbioseq/models/selfsup/bert.py index c69f867..3fd6727 100644 --- a/openbioseq/models/selfsup/bert.py +++ b/openbioseq/models/selfsup/bert.py @@ -1,3 +1,4 @@ +import random import torch from openbioseq.utils import print_log @@ -18,7 +19,8 @@ class BERT(BaseModel): backbone (dict): Config dict for encoder. neck (dict): Config dict for encoder. Defaults to None. head (dict): Config dict for loss functions. Defaults to None. - pretrained (str, optional): Path to pre-trained weights. Default: None. + mask_ratio (float): Masking ratio for MLM pre-training. Default to 0.15. + pretrained (str, optional): Path to pre-trained weights. Default to None. init_cfg (dict): Config dict for weight initialization. Defaults to None. """ @@ -28,6 +30,7 @@ def __init__(self, neck=None, head=None, mask_ratio=0.15, + spin_stride=[], pretrained=None, init_cfg=None, **kwargs): @@ -37,6 +40,9 @@ def __init__(self, self.neck = builder.build_neck(neck) self.head = builder.build_head(head) self.mask_ratio = mask_ratio + self.spin_stride = list() if not isinstance(spin_stride, (tuple, list)) \ + else list(spin_stride) + self.patch_size = getattr(self.backbone, 'patch_size', 1) if self.patch_size > 1: self.padding = AdaptivePadding1d( @@ -95,7 +101,14 @@ def forward_train(self, data, **kwargs): data = self.padding(data) B, _, L = data.size() - mask = torch.bernoulli(torch.full([1, L], self.mask_ratio)).cuda() + if len(self.spin_stride) > 0: + spin = random.choices(self.spin_stride, k=1)[0] + assert L % spin == 0 and spin >= 1 + spin_L = L // spin + mask = torch.bernoulli(torch.full([1, spin_L], self.mask_ratio)).cuda() + mask = mask.view(1, spin_L, 1).expand(1, spin_L, spin).reshape(1, L) + else: + mask = torch.bernoulli(torch.full([1, L], self.mask_ratio)).cuda() latent, _ = self.backbone(data, mask=None) latent = latent.reshape(-1, latent.size(2)) # (B, L, C) -> (BxL, C) data_rec = self.neck(latent)
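For reference, the `spin_stride` masking added to `BERT.forward_train` samples one stride per batch and draws a single Bernoulli decision per block of `spin` positions, so masked tokens come in contiguous runs of length `spin`. A standalone sketch of just that expansion, using toy sizes that are not taken from the configs and running on CPU rather than `.cuda()`:

```python
import torch

L, spin, mask_ratio = 12, 4, 0.15   # toy sequence length / stride / masking ratio
assert L % spin == 0 and spin >= 1

spin_L = L // spin
# one Bernoulli decision per block of `spin` positions ...
mask = torch.bernoulli(torch.full([1, spin_L], mask_ratio))
# ... repeated `spin` times along the sequence: shape (1, L), masked positions in runs
mask = mask.view(1, spin_L, 1).expand(1, spin_L, spin).reshape(1, L)
print(mask)  # e.g. tensor([[0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.]])
```

With `spin_stride=[]` (the default) the expansion is skipped and a plain per-position Bernoulli mask is drawn, exactly as before this change.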