Skip to content

Commit

Permalink
update setup in v0.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Lupin1998 committed Jan 9, 2023
1 parent 40e1ce9 commit d8c1c9a
Show file tree
Hide file tree
Showing 12 changed files with 352 additions and 82 deletions.
7 changes: 3 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,15 @@ venv.bak/
# mypy
.mypy_cache/

openmixup/version.py
version.py
# custom
data
.vscode
.idea

# custom
*.pkl
*.pkl.json
*.log.json
work_dirs/
/openbioseq/.mim
tools/exp_bash/
pretrains

Expand All @@ -138,3 +136,4 @@ configs/selfsup/processed
# configs/semisup
# configs/selfsup
*.json
*.toml
4 changes: 4 additions & 0 deletions .style.yapf
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[style]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2020-2021 Open-MMLab.
Copyright 2021-2022 CAIRI AI Lab.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include requirements/*.txt
recursive-include openbioseq/.mim/tools *.sh *.py
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# OpenBioSeq
[![PyPI](https://img.shields.io/pypi/v/OpenBioSeq)](https://pypi.org/project/OpenBioSeq)
[![license](https://img.shields.io/badge/license-Apache--2.0-%23B7A800)](https://github.com/Westlake-AI/OpenBioSeq/blob/main/LICENSE)
![open issues](https://img.shields.io/github/issues-raw/Westlake-AI/OpenBioSeq?color=%23FF9600)
[![issue resolution](https://img.shields.io/badge/issue%20resolution-1%20d-%23009763)](https://github.com/Westlake-AI/OpenBioSeq/issues)

**News**

Expand Down
61 changes: 59 additions & 2 deletions openbioseq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,60 @@
from .version import __version__, short_version
import warnings

__all__ = ['__version__', 'short_version']
import mmcv
from packaging.version import parse

from .version import __version__


def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a tuple of integers.

    This method is usually used for comparing two versions. For pre-release
    versions: alpha < beta < rc.

    Args:
        version_str (str): The version string (PEP 440 compatible).
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int]: The version info in digits (integers).
    """
    parsed = parse(version_str)
    assert parsed.release, f'failed to parse version {version_str}'
    # Truncate, then zero-pad, the numeric release segment to `length` parts.
    digits = list(parsed.release)[:length]
    digits += [0] * (length - len(digits))
    if parsed.is_prerelease:
        # Pre-releases sort below finals via negative markers:
        # alpha (-3) < beta (-2) < rc (-1); unknown kinds get -4.
        pre_rank = {'a': -3, 'b': -2, 'rc': -1}
        if parsed.pre:
            if parsed.pre[0] in pre_rank:
                digits += [pre_rank[parsed.pre[0]], parsed.pre[-1]]
            else:
                warnings.warn(f'unknown prerelease version {parsed.pre[0]}, '
                              'version checking may go wrong')
                digits += [-4, parsed.pre[-1]]
        else:
            # parse() may flag a prerelease (e.g. dev releases) with pre=None.
            digits += [-4, 0]
    elif parsed.is_postrelease:
        # Post-releases sort above the corresponding final release.
        digits += [1, parsed.post]
    else:
        digits += [0, 0]
    return tuple(digits)


# Supported mmcv versions (inclusive bounds), checked once at import time.
mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.7.0'
mmcv_version = digit_version(mmcv.__version__)


# Fail fast with an actionable message if the installed mmcv is out of range.
assert (digit_version(mmcv_minimum_version) <= mmcv_version
        <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'


__all__ = ['__version__', 'digit_version']
4 changes: 2 additions & 2 deletions openbioseq/core/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
CustomCosineAnnealingHook
from .builder import build_hook, build_addtional_scheduler, build_optimizer
from .deepcluster_hook import DeepClusterHook
from .ema_hook import EMAHook
from .ema_hook import EMAHook, SwitchEMAHook
from .extractor import Extractor
from .lr_scheduler import StepFixCosineAnnealingLrUpdaterHook
from .momentum_hook import CosineHook, StepHook, CosineScheduleHook, StepScheduleHook
Expand All @@ -27,6 +27,6 @@
'build_hook', 'build_addtional_scheduler', 'build_optimizer',
'DeepClusterHook', 'ODCHook', 'PreciseBNHook', 'SwAVHook',
'StepFixCosineAnnealingLrUpdaterHook', 'CosineHook', 'StepHook', 'CosineScheduleHook', 'StepScheduleHook',
'EMAHook', 'Extractor', 'SAVEHook', 'SSLMetricHook', 'ValidateHook',
'EMAHook', 'SwitchEMAHook', 'Extractor', 'SAVEHook', 'SSLMetricHook', 'ValidateHook',
'DistOptimizerHook', 'Fp16OptimizerHook',
]
148 changes: 148 additions & 0 deletions openbioseq/core/hooks/ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,151 @@ def _swap_ema_parameters(self):
ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
value.data.copy_(ema_buffer.data)
ema_buffer.data.copy_(temp)


@HOOKS.register_module()
class SwitchEMAHook(Hook):
    r"""Exponential Moving Average Hook with periodic parameter switching.

    Use Exponential Moving Average on all parameters of model in training
    process. All parameters have a ema backup, which update by the formula
    as below. EMAHook takes priority over EvalHook and CheckpointSaverHook!
    In addition to the plain EMA update, this hook can periodically copy the
    EMA weights back into the online model ("switch").

    .. math::
        Xema\_{t+1} = \text{momentum} \times Xema\_{t} +
            (1 - \text{momentum}) \times X_t

    Args:
        momentum (float): The momentum used for updating ema parameter.
            Defaults to 0.9999.
        resume_from (str): The checkpoint path. Defaults to None.
        warmup (string): Type of warmup used. It can be None (use no warmup),
            'constant', 'linear' or 'exp'. Defaults to None.
        warmup_iters (int): The number of iterations that warmup lasts, i.e.,
            warmup by iteration. Defaults to 0.
        warmup_ratio (float): Momentum used at the beginning of warmup equals
            to warmup_ratio * momentum. Defaults to 0.9.
        switch_params (bool): Whether to copy the EMA parameters back into
            the online model during training. Defaults to False.
        switch_by_iter (bool): If True, switching is evaluated per iteration
            in ``after_train_iter``; otherwise per epoch in
            ``before_train_epoch``. Defaults to False.
        switch_interval (int): Interval (iterations or epochs, depending on
            ``switch_by_iter``) used when deciding whether to switch.
            Defaults to 100.
        full_params_ema (bool): Whether to register EMA parameters by
            `named_parameters()` or `state_dict()`, which influences
            performances of models with BN variants. Defaults to False.
        update_interval (int): Update ema parameter every interval iteration.
            Defaults to 1.
    """

    def __init__(self,
                 momentum=0.9999,
                 resume_from=None,
                 warmup=None,
                 warmup_iters=0,
                 warmup_ratio=0.9,
                 switch_params=False,
                 switch_by_iter=False,
                 switch_interval=100,
                 full_params_ema=False,
                 update_interval=1,
                 **kwargs):
        assert isinstance(update_interval, int) and update_interval > 0
        assert momentum > 0 and momentum < 1
        self.momentum = momentum
        # Effective momentum for the current iteration (warmup-adjusted).
        self.regular_momentum = momentum
        self.checkpoint = resume_from
        if warmup is not None:
            if warmup not in ['constant', 'linear', 'exp']:
                raise ValueError(
                    f'"{warmup}" is not a supported type for warming up!')
            assert warmup_iters > 0 and 0 < warmup_ratio <= 1.0
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio
        self.update_interval = update_interval

        self.switch_params = switch_params
        self.switch_by_iter = switch_by_iter
        self.switch_interval = switch_interval
        self.full_params_ema = full_params_ema
        # True until the first post-warmup EMA update; switching is
        # suppressed while warming up.
        self.is_warmup = True

    def get_warmup_momentum(self, cur_iters):
        """Return the warmup-scaled momentum for iteration ``cur_iters``.

        All three schedules ramp from ``warmup_ratio * momentum`` up to
        ``momentum`` over ``warmup_iters`` iterations.
        """
        if self.warmup == 'constant':
            warmup_m = self.warmup_ratio * self.momentum
        elif self.warmup == 'linear':
            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
            warmup_m = (1 - k) * self.momentum
        elif self.warmup == 'exp':
            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
            warmup_m = k * self.momentum
        return warmup_m

    def before_run(self, runner):
        """To resume model with its ema parameters more friendly.

        Register ema parameter as ``named_buffer`` to model so the EMA
        state is saved/loaded together with ordinary checkpoints.
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        if self.full_params_ema:
            # state_dict() also covers buffers (e.g. BN running stats).
            self.model_parameters = dict(model.state_dict())
        else:
            self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        if self.checkpoint is not None:
            from mmcv.runner import load_checkpoint
            load_checkpoint(model, self.checkpoint, strict=False)
            # runner.resume(self.checkpoint)
        self.model_buffers = dict(model.named_buffers(recurse=True))

    def after_train_iter(self, runner):
        """Update ema parameter every self.interval iterations."""
        if self.every_n_iters(runner, self.update_interval):
            curr_iter = runner.iter
            if self.warmup is None or curr_iter > self.warmup_iters:
                self.regular_momentum = self.momentum
                self.is_warmup = False
            else:
                self.regular_momentum = self.get_warmup_momentum(curr_iter)
                self.is_warmup = True
            for name, parameter in self.model_parameters.items():
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                # ema = m * ema + (1 - m) * param, in place on the buffer.
                buffer_parameter.mul_(self.regular_momentum).add_(
                    parameter.data, alpha=1. - self.regular_momentum)
            # copy EMA to the model
            # NOTE(review): ``not every_n_iters`` switches on every update
            # step EXCEPT multiples of ``switch_interval`` — possibly the
            # inverse of the intent; confirm against upstream SwitchEMA.
            if self.switch_params and self.switch_by_iter:
                if not self.is_warmup:
                    if not self.every_n_iters(runner, self.switch_interval):
                        self._switch_ema_parameters()

    def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """We recover model's parameter from ema backup after last epoch's
        EvalHook."""
        self._swap_ema_parameters()
        # NOTE(review): as in after_train_iter, ``not every_n_epochs``
        # switches on all epochs except multiples of the interval — verify.
        if self.switch_params and not self.switch_by_iter:  # copy EMA to the model
            if not self.is_warmup:
                if not self.every_n_epochs(runner, self.switch_interval):
                    self._switch_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameter of model with parameter in ema_buffer."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)

    def _switch_ema_parameters(self):
        """Switch the parameter of model to parameters in ema_buffer.

        One-way copy (EMA -> model); the EMA buffers are left unchanged.
        """
        for name, value in self.model_parameters.items():
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
27 changes: 27 additions & 0 deletions openbioseq/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) CAIRI AI Lab. All rights reserved

# Single source of truth for the package version string.
__version__ = '0.1.1'

def parse_version_info(version_str):
    """Parse a version string into a tuple.

    Args:
        version_str (str): The version string.

    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
    """
    parts = []
    for token in version_str.split('.'):
        if token.isdigit():
            # Plain numeric component, e.g. "3" -> 3.
            parts.append(int(token))
        elif 'rc' in token:
            # Release-candidate component, e.g. "0rc1" -> 0, 'rc1'.
            segs = token.split('rc')
            parts.append(int(segs[0]))
            parts.append(f'rc{segs[1]}')
    return tuple(parts)


# Cached tuple form of ``__version__`` for programmatic comparisons.
version_info = parse_version_info(__version__)

__all__ = ['__version__', 'version_info', 'parse_version_info']
4 changes: 4 additions & 0 deletions requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mmcls >= 0.21.0
mmcv-full>=1.4.7
mmdet >= 2.16.0
mmsegmentation >= 0.20.2
24 changes: 24 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[bdist_wheel]
universal=1

[aliases]
test=pytest

[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = OpenBioSeq
known_third_party = PIL,detectron2,faiss,matplotlib,mmcls,mmcv,mmdet,mmseg,numpy,packaging,pytest,pytorch_sphinx_theme,scipy,seaborn,six,sklearn,svm_helper,timm,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,confectionary,nd,ty,formating,dows
Loading

0 comments on commit d8c1c9a

Please sign in to comment.