add util for loss spike save and decode. #1044

Open · wants to merge 7 commits into master
49 changes: 49 additions & 0 deletions atorch/atorch/tests/test_loss_spike_utils.py
@@ -0,0 +1,49 @@
import os
import unittest

from atorch.utils.loss_spike_utils import TokenLossSpike


class LossSpikeTest(unittest.TestCase):
def init(self):
self.min_iter = 100
self.min_loss = 4.0

sample_data_paths = [("wikipedia", "corpus/base"), ("wikipedia", "corpus/base"), ("wikipedia", "corpus/base")]
each_sample_len = 2

        if not os.path.exists("utils/loss_spike"):
            # makedirs creates intermediate directories and tolerates an existing path
            os.makedirs("utils/loss_spike", exist_ok=True)

self.loss_ins = TokenLossSpike(
"utils/loss_spike",
sample_data_paths,
each_sample_len,
self.min_iter,
self.min_loss,
)

def setUp(self):
self.init()

def test_save(self):
self.loss_ins.save_loss(
"test_loss.txt",
4.05,
103,
losses_str="2.44,2.33,4.05",
sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2",
)

def test_decode(self):
self.loss_ins.decode_loss_spike("res.txt", None)

def test_parse(self):
self.loss_ins.parse_sample_content(
losses_str="2.44,2.33,4.05",
sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2",
tokenizer=None,
)

def test_fetch(self):
self.loss_ins.fetch("20-1-1385697-14158189-2")
156 changes: 156 additions & 0 deletions atorch/atorch/utils/loss_spike_utils.py
@@ -0,0 +1,156 @@
""" loss spike utils
Save loss spike to files;
Decode loss spike and save Corresponding sample to a file. Using doc see ../docs/README-LOSS-SPIKE-UTIL.md
"""
import datetime
import os

import numpy as np

from atorch.common.log_utils import default_logger as logger


class LossSpikeBase:
def __init__(
self,
loss_spike_save_dir,
sample_data_paths,
each_sample_len,
min_iter,
min_loss,
loss_info_splitter="\t",
loss_sample_str_splitter=",",
):
"""
init params
Args:
loss_spike_save_dir: str, The directory where loss spike files are stored
sample_data_paths: The path information stored in the sample can be a user-defined structure,
such as name + path, tuple list:
[("wikipedia", "corpus/base"), ("zhihu", "/dataset/fd5061f6/data/tokenize_data/zhihu.lazy")]
each_sample_len: int, The length of a single sample
min_iter: int, Record the minimum iterations required for loss
min_loss: float, Record the minimum loss threshold required for loss
loss_info_splitter: str, Delimiter used to store loss spike information, default ='\t'
loss_sample_str_splitter: str, Delimiter used by str for a batch of loss or sample information, default =','
"""
self.loss_spike_save_dir = loss_spike_save_dir
self.sample_data_paths = sample_data_paths
self.each_sample_len = each_sample_len
self.min_iter = min_iter
self.min_loss = min_loss
self.loss_info_splitter = loss_info_splitter
self.loss_sample_str_splitter = loss_sample_str_splitter

if not os.path.exists(loss_spike_save_dir):
            raise ValueError("Param loss_spike_save_dir does not exist!")
logger.info("Loss spike init success")

@staticmethod
    def get_data_file_len(fpath, dtype):
        # Open in binary mode so tell() reports a byte offset, then convert bytes to an element count.
        with open(fpath, "rb") as f:
            f.seek(0, 2)
            return f.tell() // dtype.itemsize


class TokenLossSpike(LossSpikeBase):
def save_loss(self, file_name, cur_loss, cur_iter, *args, **kargs):
"""
Store spike loss and corresponding information
Args:
file_name: str, loss spike file name
cur_loss: float, current avg loss value
cur_iter: int, current iteration
args/kargs: any custom data in string format.

"""
file_path = os.path.join(self.loss_spike_save_dir, file_name)
losses_str = kargs["losses_str"]
sample_infos_str = kargs["sample_infos_str"]
if cur_loss > self.min_loss and cur_iter > self.min_iter:
logger.info(f"save loss={cur_loss}, iter={cur_iter}")
# define structure
cur_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
info = self.loss_info_splitter.join([cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str])
info = info + "\n"
with open(file_path, "a+") as w:
w.write(info)

def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss=None):
"""
According to the information such as spike loss in the file, the corresponding sample is parsed to the file
Args:
result_file_path: str, The address of the file that stores the contents of the decoded sample
tokenizer: instance,
min_iter: int, minimum iterations required for decode loss、sample
min_loss: float, minimum loss required for decode loss、sample

Returns:

"""
if min_iter is None:
min_iter = self.min_iter
if min_loss is None:
min_loss = self.min_loss

with open(result_file_path, "w") as fw:
# Traverse the entire directory
for loss_spike_file in os.listdir(self.loss_spike_save_dir):
file_path = os.path.join(self.loss_spike_save_dir, loss_spike_file)
with open(file_path) as fr:
for line in fr:
# process file content by line
# structure: f"{ctime}\t{iteration}\t{loss}\t{loss_str}\t{sample_ids_str}\n"
fcontent = line.strip().split(self.loss_info_splitter)
cur_iter = int(fcontent[1])
cur_loss = float(fcontent[2])
loss_str = fcontent[3]
sample_infos_str = fcontent[4]
if cur_iter < min_iter or cur_loss < min_loss:
logger.info(f"The content with iter={cur_iter} and loss={cur_loss} will not be parsed!")
continue
# parse
logger.info(f"Parse content with iter={cur_iter} and loss={cur_loss}!")
ds, text, max_loss = self.parse_sample_content(loss_str, sample_infos_str, tokenizer)
if ds is None:
continue
fw.write(f"=========={ds} {max_loss}================\n")
fw.write(f"{text}\n\n\n\n")

def parse_sample_content(self, losses_str, sample_infos_str, tokenizer):
losses = [float(e) for e in losses_str.split(self.loss_sample_str_splitter)]
sample_infos = sample_infos_str.split(self.loss_sample_str_splitter)
if len(losses) != len(sample_infos):
logger.warn("batch loss length != batch sample length")
return None, None, None

losses = np.array(losses)
idx = losses.argmax(-1)
max_loss = losses[idx]
sample_with_max_loss = sample_infos[idx]
ds, data = self.fetch(sample_with_max_loss)
if ds is None:
return None, None, None
if tokenizer is not None:
data = tokenizer.decode(data)
return ds, data, max_loss

def fetch(self, each_sample_info):
Collaborator comment:

Since fetch is a user-defined method, either define it in the base class as an abstract method, or give the class initialization a parameter (fetch_func) that is provided by the user. (A sketch of this suggestion follows the file diff below.)
        # This part is application-specific; different scenarios can implement different subclasses
# 20-17-1385697-14158189-936633
# {scatter_id}-{dsid}-{idx}-{raw_id}-{sample_id}
scatter_id, dsid, _, _, sample_id = each_sample_info.split("-") # structure is defined during initialization

        # ds_info is a (name, path) tuple: index 0 is the dataset name, index 1 is the sample data path
ds_info = self.sample_data_paths[int(dsid)]

datapath = f"{ds_info[1]}.scatter/{scatter_id}.lazy/text"

if not os.path.exists(datapath):
            logger.warn("sample data path does not exist")
return None, None
flen = self.get_data_file_len(datapath, np.dtype(np.int32))
sample_cnt = flen // self.each_sample_len
f = np.memmap(datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len)) # Disk to memory
data = f[int(sample_id)]
return ds_info[0], data
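A minimal sketch of the collaborator's suggestion above, assuming the rest of LossSpikeBase stays as in this PR; only fetch and fetch_func come from the comment, while the subclass name CustomFetchLossSpike and the surrounding scaffolding are illustrative:

```python
from abc import ABC, abstractmethod


class LossSpikeBase(ABC):
    # __init__, get_data_file_len, etc. as in this PR (elided here).

    @abstractmethod
    def fetch(self, each_sample_info):
        """Map one sample-info string to (dataset_name, token_ids); subclasses define the format."""


class TokenLossSpike(LossSpikeBase):
    def fetch(self, each_sample_info):
        ...  # the scatter_id-dsid-idx-raw_id-sample_id implementation from this PR


# Alternative to subclassing: accept a user-provided callable at init time.
class CustomFetchLossSpike(LossSpikeBase):
    def __init__(self, fetch_func, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fetch_func = fetch_func

    def fetch(self, each_sample_info):
        return self.fetch_func(each_sample_info)
```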
82 changes: 82 additions & 0 deletions atorch/docs/README-LOSS-SPIKE-UTIL.md
@@ -0,0 +1,82 @@
### User's manual

1. Init an instance with args.
2. Record spike losses with the save_loss func.
3. Decode losses with the decode_loss_spike func.

### Example code

```python
from atorch.utils.loss_spike_utils import TokenLossSpike


loss_spike_ins = TokenLossSpike(
loss_spike_save_dir="loss_spike_save_dir",
sample_data_paths=[("wikipedia", "corpus/base"), ("zhihu", "/dataset/fd5061f6/data/tokenize_data/zhihu.lazy")],
each_sample_len=4,
min_iter=2000,
min_loss=10,
loss_info_splitter='\t',
loss_sample_str_splitter=','
)

loss_spike_ins.save_loss(file_name="",
cur_loss=4,
cur_iter=100,
losses_str="2.44,2.33,4.05",
sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2")

loss_spike_ins.decode_loss_spike(result_file_path="",
tokenizer=None,
min_iter=None,
min_loss=None)
```

### Parameter interpretation

```python
loss_spike_ins = TokenLossSpike(
loss_spike_save_dir="loss_spike_save_dir",
sample_data_paths=[("wikipedia", "corpus/base"), ("zhihu", "/dataset/fd5061f6/data/tokenize_data/zhihu.lazy")],
each_sample_len=4,
min_iter=2000,
min_loss=10,
loss_info_splitter='\t',
loss_sample_str_splitter=','
)
```

1. loss_spike_save_dir, the directory storing loss files; make sure it is an **existing directory**.
2. sample_data_paths and each_sample_len are only used by the decode_loss_spike func: **if you only call save_loss, they can be passed as None**.
    1. sample_data_paths: the mapping between sample data names and their file paths.
    2. each_sample_len: the length of a single sample, used when decoding sample data.
3. min_iter, min_loss
    1. min_iter: only iterations greater than min_iter are recorded.
    2. min_loss: only losses greater than min_loss are recorded.
4. loss_info_splitter, the delimiter joining the fields of each record: f"{ctime}\t{iteration}\t{loss}\t{loss_str}\t{sample_infos_str}\n" (see the example record after this list).
5. loss_sample_str_splitter, default value is ",": **these strings are assembled by the user, so the user must use a matching splitter**.
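For illustration, one record written by save_loss under the default splitters looks like the following; the timestamp and loss values are made up, and the field order comes from the format string above:

```python
# One line appended to the loss spike file: ctime, iteration, loss, losses_str, sample_infos_str joined by '\t'
"2024-01-01 10:00:00\t2301\t10.3\t9.9,10.2,10.8\t20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2"
```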

```python
loss_spike_ins.save_loss(file_name="",
cur_loss=4,
cur_iter=100,
losses_str="2.44,2.33,4.05",
sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2")
```

1. file_name, the file for saving the loss.
2. cur_loss, the mean loss of the batch.
3. cur_iter, the current iteration.
4. losses_str, the per-sample batch losses joined with the splitter.
5. sample_infos_str, the batch sample info (**this structure is fixed in the current subclass implementation; other formats would require another subclass or abstracted parameters**). A usage sketch follows this list.
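A minimal sketch of assembling these arguments in a training loop; batch_losses, batch_sample_infos, and step are hypothetical names standing in for whatever the training code tracks:

```python
# Hypothetical per-step values produced by the training code.
batch_losses = [9.9, 10.2, 10.8]
batch_sample_infos = ["20-1-1385697-14158189-2", "20-1-1385697-14158189-3", "20-1-1385697-14158189-4"]
step = 2301

loss_spike_ins.save_loss(
    file_name="loss_spike_rank0.txt",                       # hypothetical file name
    cur_loss=sum(batch_losses) / len(batch_losses),         # mean loss of the batch
    cur_iter=step,
    losses_str=",".join(str(loss) for loss in batch_losses),
    sample_infos_str=",".join(batch_sample_infos),
)
```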

```python
loss_spike_ins.decode_loss_spike(result_file_path="",
tokenizer=None,
min_iter=None,
min_loss=None)
```

1. result_file_path, the file where the decoded samples are saved.
2. tokenizer, default value is None; if provided, the sample token ids are decoded with the tokenizer.decode func (see the sketch after this list).
3. min_iter, min_loss, optional overrides of the thresholds given at init.
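As an example of the tokenizer argument, a Hugging Face tokenizer is used here as an assumption; any object exposing a decode(token_ids) method works, since the stored samples are int32 token id arrays:

```python
from transformers import AutoTokenizer  # assumption: any tokenizer with a decode() method would do

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model name

loss_spike_ins.decode_loss_spike(
    result_file_path="decoded_loss_spike_samples.txt",  # hypothetical output path
    tokenizer=tokenizer,
    min_iter=2000,   # only decode records saved after iteration 2000
    min_loss=10.0,   # only decode records whose mean loss exceeded 10
)
```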