Skip to content

Commit

Permalink
Merge pull request #17 from muou55555/dev
Browse files Browse the repository at this point in the history
Optimize log and update requirements.txt
  • Loading branch information
HuangLK authored Apr 18, 2023
2 parents ee10d80 + 4c633b1 commit faedea5
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 40 deletions.
42 changes: 42 additions & 0 deletions common/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from loguru import logger
from common.utils import is_rank_0
from common.mp_wraps import rank_zero

# Log config
LOG_FILENAME = "ds_training.log"

class GetLogger:
    """Singleton proxy around the loguru logger for distributed training.

    Guarantees two things:
      * the ``LOG_FILENAME`` file sink is attached exactly once, no matter
        how many times ``GetLogger()`` is called, and
      * every logging method only emits on the global rank-0 process
        (via the ``rank_zero`` decorator), so multi-process runs do not
        produce duplicated log lines.
    """

    # The one shared instance; created lazily on first construction.
    __instance = None
    # True until the file sink has been registered once.
    __init_flag = True

    def __new__(cls, *args, **kwargs):
        # BUG FIX: object.__new__ accepts no extra arguments, so the
        # original ``super().__new__(cls, *args, **kwargs)`` would raise
        # TypeError if any argument were ever passed; drop them here.
        if cls.__instance is None:
            cls.__instance = super().__new__(cls)
        return cls.__instance

    def __init__(self):
        # __init__ runs on EVERY GetLogger() call even though __new__
        # returns the cached instance, so the sink setup must be guarded.
        if GetLogger.__init_flag:
            logger.add(LOG_FILENAME)
            # BUG FIX: the original wrote ``self.__init_flag: False`` — a
            # bare annotation, not an assignment — so the flag never
            # flipped and a duplicate file sink was added per call.
            GetLogger.__init_flag = False

    @rank_zero
    def trace(self, *args, **kwargs):
        logger.trace(*args, **kwargs)

    @rank_zero
    def debug(self, *args, **kwargs):
        logger.debug(*args, **kwargs)

    @rank_zero
    def info(self, *args, **kwargs):
        logger.info(*args, **kwargs)

    @rank_zero
    def warning(self, *args, **kwargs):
        logger.warning(*args, **kwargs)

    @rank_zero
    def error(self, *args, **kwargs):
        logger.error(*args, **kwargs)

logger_rank0 = GetLogger()
15 changes: 15 additions & 0 deletions common/mp_wraps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from functools import wraps
from common.utils import is_rank_0

__all__ = ["rank_zero"]


def rank_zero(func):
    """Decorator that runs *func* only on the global rank-0 process.

    On every other rank the call is a silent no-op returning ``None``,
    which keeps multi-process runs from duplicating side effects such as
    log output.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_rank_0():
            return func(*args, **kwargs)
        return None

    return wrapper
37 changes: 0 additions & 37 deletions utils.py → common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,11 @@
import sys
import time
import json

import transformers
import torch.distributed as dist
from loguru import logger as logger


logger.add(f'ds_training.log')


def is_rank_0() -> bool:
    """Return True when this process is global rank 0.

    A single-process run (torch.distributed never initialized) is treated
    as rank 0 so logging works outside distributed training as well.
    """
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


class LoggerRank0:
    """Thin proxy over the loguru logger that only emits on rank 0.

    Each method forwards its arguments to the matching ``logger`` method
    when :func:`is_rank_0` is true and is a silent no-op on every other
    rank, preventing duplicated log lines in multi-process training.
    """

    def trace(self, *args, **kwargs):
        if is_rank_0():
            logger.trace(*args, **kwargs)

    def debug(self, *args, **kwargs):
        if is_rank_0():
            logger.debug(*args, **kwargs)

    def info(self, *args, **kwargs):
        if is_rank_0():
            logger.info(*args, **kwargs)

    def warning(self, *args, **kwargs):
        if is_rank_0():
            logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs):
        if is_rank_0():
            logger.error(*args, **kwargs)

logger_rank0 = LoggerRank0()


def _make_w_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f_dirname = os.path.dirname(f)
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ numpy
sentencepiece
transformers==4.28.0
git+https://github.com/HuangLK/DeepSpeed.git@dev
flash_attn
flash_attn
einops
loguru
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from feeder import (
make_prompt_dataloader,
)
from utils import jload
from utils import logger_rank0 as logger
from common.utils import jload
from common.log import logger_rank0 as logger

warnings.filterwarnings("ignore")

Expand Down

0 comments on commit faedea5

Please sign in to comment.