From a64a0b22fd8b730340a05a150545e9e0b618df46 Mon Sep 17 00:00:00 2001 From: hktk Date: Thu, 21 Mar 2024 17:16:53 +0800 Subject: [PATCH 1/7] add util for loss spike save and decode. --- .../tests/utils/test_loss_spike_utils.py | 30 ++++ atorch/atorch/utils/loss_spike_utils.py | 146 ++++++++++++++++++ 2 files changed, 176 insertions(+) create mode 100644 atorch/atorch/tests/utils/test_loss_spike_utils.py create mode 100644 atorch/atorch/utils/loss_spike_utils.py diff --git a/atorch/atorch/tests/utils/test_loss_spike_utils.py b/atorch/atorch/tests/utils/test_loss_spike_utils.py new file mode 100644 index 000000000..da188b72b --- /dev/null +++ b/atorch/atorch/tests/utils/test_loss_spike_utils.py @@ -0,0 +1,30 @@ +import os +import unittest +from atorch.utils.loss_spike_utils import TokenLossSpike + + +class LossSpikeTest(unittest.TestCase): + def init(self): + self.min_iter = 100 + self.min_loss = 4.0 + + sample_data_paths = [] + each_sample_len = 2 + + if not os.path.exists("loss_spike"): + os.mkdir("loss_spike") + + self.loss_ins = TokenLossSpike("loss_spike", sample_data_paths, + each_sample_len, self.min_iter, + self.min_loss) + + def test_save(self): + self.init() + self.loss_ins.save_loss( + "test_loss.txt", 4.05, 103, + "2.44,2.33,4.05", + "20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2") + + def test_decode(self): + self.init() + self.loss_ins.decode_loss_spike("res.txt", None) diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py new file mode 100644 index 000000000..720d0b600 --- /dev/null +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -0,0 +1,146 @@ +""" loss spike utils +Save loss spike to files; +Decode loss spike and save Corresponding sample to a file. +""" +import os +import datetime + +import numpy as np +from atorch.common.log_utils import default_logger as logger + + +class LossSpikeBase: + def __init__(self, loss_spike_save_dir, sample_data_paths, each_sample_len, min_iter, min_loss, + loss_info_splitter='\t', loss_sample_str_splitter=','): + """ + init params + Args: + loss_spike_save_dir: str, The directory where loss spike files are stored + sample_data_paths: 样本存储的路径信息, 可以是用户自定义的结构,如名称+路径的tuple list: + [("wikipedia", "corpus/base"), ("zhihu", "/dataset/fd5061f6/data/tokenize_data/zhihu.lazy")] + each_sample_len: int, 单个样本的长度 + min_iter: int, 记录loss所需的最小迭代轮数 + min_loss: float, 记录loss所需的最小loss阈值 + loss_info_splitter: str, 存储loss尖刺信息时所用的分隔符, 默认值='\t' + loss_sample_str_splitter: str, 一批loss或者sample信息组成的str所用的分隔符, 默认值=',' + """ + self.loss_spike_save_dir = loss_spike_save_dir + self.sample_data_paths = sample_data_paths + self.each_sample_len = each_sample_len + self.min_iter = min_iter + self.min_loss = min_loss + self.loss_info_splitter = loss_info_splitter + self.loss_sample_str_splitter = loss_sample_str_splitter + + if not os.path.exists(loss_spike_save_dir): + raise ValueError("Param loss_spike_save_dir not exist!") + logger.info("Loss spike init success") + + @staticmethod + def get_data_file_len(fpath, dtype): + with open(fpath) as f: + f.seek(0, 2) + return f.tell() // dtype.itemsize + + +class TokenLossSpike(LossSpikeBase): + def save_loss(self, file_name, cur_loss, cur_iter, losses_str, sample_infos_str): + """ + 存储尖刺loss及对应的一些信息 + Args: + file_name: str, loss spike file name + cur_loss: float, current loss value + cur_iter: int, current iteration + losses_str: A string of loss concatenated with splitter, 如: + "2.31,2.30,10.98,1.56" + sample_infos_str: A string of sample id info concatenated with 
splitter, 如: + "20-17-1385697-14158189-936633,20-17-1385697-14158189-936633" + + """ + file_path = os.path.join(self.loss_spike_save_dir, file_name) + if cur_loss > self.min_loss and cur_iter > self.min_iter: + logger.info(f"save loss={cur_loss}, iter={cur_iter}") + # 存储结构需要定义 + cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + info = self.loss_info_splitter.join([cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str]) + info = info + "\n" + with open(file_path, 'a+') as w: + w.write(info) + + def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss=None): + """ + 根据文件中的尖刺loss等信息,解析出对应的样本到文件 + Args: + result_file_path: str, 存储解码后的样本内容的文件地址 + tokenizer: instance, 样本文件中的存储的样本, 需要时的解码器 + min_iter: int, 解码loss、sample所需的最小迭代轮数, 不传的话,使用存储时的初始化的阈值 + min_loss: float, 解码loss、sample所需的最小loss阈值, 不传的话,使用存储时的初始化的阈值 + + Returns: + + """ + if min_iter is None: + min_iter = self.min_iter + if min_loss is None: + min_loss = self.min_loss + + with open(result_file_path, "w") as fw: + # 遍历整个目录 + for loss_spike_file in os.listdir(self.loss_spike_save_dir): + file_path = os.path.join(self.loss_spike_save_dir, loss_spike_file) + with open(file_path) as fr: + for line in fr: + # process file content by line + # 内容结构: f"{ctime}\t{iteration}\t{loss}\t{loss_str}\t{sample_ids_str}\n" + fcontent = line.strip().split(self.loss_info_splitter) + cur_iter = int(fcontent[1]) + cur_loss = float(fcontent[2]) + loss_str = fcontent[3] + sample_infos_str = fcontent[4] + if cur_iter < min_iter or cur_loss < min_loss: + logger.info(f"The content with iter={cur_iter} and loss={cur_loss} will not be parsed!") + continue + # parse + logger.info(f"Parse content with iter={cur_iter} and loss={cur_loss}!") + ds, text, max_loss = self.parse_sample_content(loss_str, sample_infos_str, tokenizer) + if ds is None: + continue + fw.write(f"=========={ds} {max_loss}================\n") + fw.write(f"{text}\n\n\n\n") + + def parse_sample_content(self, losses_str, sample_infos_str, tokenizer): + losses = [float(e) for e in losses_str.split(self.loss_sample_str_splitter)] # 解析loss es + sample_infos = sample_infos_str.split(self.loss_sample_str_splitter) + if len(losses) != len(sample_infos): + logger.warn("batch loss length != batch sample length") + return None, None, None + + losses = np.array(losses) # 解析loss es + idx = losses.argmax(-1) + max_loss = losses[idx] + sample_with_max_loss = sample_infos[idx] + ds, data = self.fetch(sample_with_max_loss) + if ds is None: + return None, None, None + if tokenizer is not None: + data = tokenizer.decode(data) + return ds, data, max_loss + + def fetch(self, each_sample_info): + # 这里是比较定制化的,不同应用场景可以构建不同的子类 + # 20-17-1385697-14158189-936633 + # {scatter_id}-{dsid}-{idx}-{raw_id}-{sample_id} + scatter_id, dsid, _, _, sample_id = each_sample_info.split('-') # 这个结构是在初始化定义好的 + + ds_info = self.sample_data_paths[int(dsid)] # 用列表的索引?0是train data,1是对应的NAMED_CORPORA[e].PATH + + datapath = f'{ds_info[1]}.scatter/{scatter_id}.lazy/text' # 这怎么办啊,太定制了 + + if not os.path.exists(datapath): + logger.warn("sample data path not exist") + return None, None + flen = self.get_data_file_len(datapath, np.dtype(np.int32)) + sample_cnt = flen // self.each_sample_len + f = np.memmap(datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len)) # 磁盘写内存 + data = f[int(sample_id)] + return ds_info[0], data From a077103546dd5d305d20ec8f2150a0edc7d6db76 Mon Sep 17 00:00:00 2001 From: hktk Date: Wed, 27 Mar 2024 14:21:27 +0800 Subject: [PATCH 2/7] 
=?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- atorch/atorch/utils/loss_spike_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py index 720d0b600..fb9384bc1 100644 --- a/atorch/atorch/utils/loss_spike_utils.py +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -132,9 +132,9 @@ def fetch(self, each_sample_info): # {scatter_id}-{dsid}-{idx}-{raw_id}-{sample_id} scatter_id, dsid, _, _, sample_id = each_sample_info.split('-') # 这个结构是在初始化定义好的 - ds_info = self.sample_data_paths[int(dsid)] # 用列表的索引?0是train data,1是对应的NAMED_CORPORA[e].PATH + ds_info = self.sample_data_paths[int(dsid)] # 用列表的索引, 0是train data,1是对应的样本PATH - datapath = f'{ds_info[1]}.scatter/{scatter_id}.lazy/text' # 这怎么办啊,太定制了 + datapath = f'{ds_info[1]}.scatter/{scatter_id}.lazy/text' if not os.path.exists(datapath): logger.warn("sample data path not exist") From 8b19ab536f8cf12465a838f6a183f0164efac298 Mon Sep 17 00:00:00 2001 From: hktk Date: Wed, 27 Mar 2024 14:33:00 +0800 Subject: [PATCH 3/7] Format fix. --- .../tests/utils/test_loss_spike_utils.py | 18 +++++-- atorch/atorch/utils/loss_spike_utils.py | 51 ++++++++++++++----- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/atorch/atorch/tests/utils/test_loss_spike_utils.py b/atorch/atorch/tests/utils/test_loss_spike_utils.py index da188b72b..408611da1 100644 --- a/atorch/atorch/tests/utils/test_loss_spike_utils.py +++ b/atorch/atorch/tests/utils/test_loss_spike_utils.py @@ -1,5 +1,6 @@ import os import unittest + from atorch.utils.loss_spike_utils import TokenLossSpike @@ -14,16 +15,23 @@ def init(self): if not os.path.exists("loss_spike"): os.mkdir("loss_spike") - self.loss_ins = TokenLossSpike("loss_spike", sample_data_paths, - each_sample_len, self.min_iter, - self.min_loss) + self.loss_ins = TokenLossSpike( + "loss_spike", + sample_data_paths, + each_sample_len, + self.min_iter, + self.min_loss, + ) def test_save(self): self.init() self.loss_ins.save_loss( - "test_loss.txt", 4.05, 103, + "test_loss.txt", + 4.05, + 103, "2.44,2.33,4.05", - "20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2") + "20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2", + ) def test_decode(self): self.init() diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py index fb9384bc1..996534631 100644 --- a/atorch/atorch/utils/loss_spike_utils.py +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -2,16 +2,25 @@ Save loss spike to files; Decode loss spike and save Corresponding sample to a file. 
""" -import os import datetime +import os import numpy as np + from atorch.common.log_utils import default_logger as logger class LossSpikeBase: - def __init__(self, loss_spike_save_dir, sample_data_paths, each_sample_len, min_iter, min_loss, - loss_info_splitter='\t', loss_sample_str_splitter=','): + def __init__( + self, + loss_spike_save_dir, + sample_data_paths, + each_sample_len, + min_iter, + min_loss, + loss_info_splitter="\t", + loss_sample_str_splitter=",", + ): """ init params Args: @@ -61,13 +70,17 @@ def save_loss(self, file_name, cur_loss, cur_iter, losses_str, sample_infos_str) if cur_loss > self.min_loss and cur_iter > self.min_iter: logger.info(f"save loss={cur_loss}, iter={cur_iter}") # 存储结构需要定义 - cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') - info = self.loss_info_splitter.join([cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str]) + cur_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + info = self.loss_info_splitter.join( + [cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str] + ) info = info + "\n" - with open(file_path, 'a+') as w: + with open(file_path, "a+") as w: w.write(info) - def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss=None): + def decode_loss_spike( + self, result_file_path, tokenizer, min_iter=None, min_loss=None + ): """ 根据文件中的尖刺loss等信息,解析出对应的样本到文件 Args: @@ -98,18 +111,26 @@ def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss loss_str = fcontent[3] sample_infos_str = fcontent[4] if cur_iter < min_iter or cur_loss < min_loss: - logger.info(f"The content with iter={cur_iter} and loss={cur_loss} will not be parsed!") + logger.info( + f"The content with iter={cur_iter} and loss={cur_loss} will not be parsed!" + ) continue # parse - logger.info(f"Parse content with iter={cur_iter} and loss={cur_loss}!") - ds, text, max_loss = self.parse_sample_content(loss_str, sample_infos_str, tokenizer) + logger.info( + f"Parse content with iter={cur_iter} and loss={cur_loss}!" 
+ ) + ds, text, max_loss = self.parse_sample_content( + loss_str, sample_infos_str, tokenizer + ) if ds is None: continue fw.write(f"=========={ds} {max_loss}================\n") fw.write(f"{text}\n\n\n\n") def parse_sample_content(self, losses_str, sample_infos_str, tokenizer): - losses = [float(e) for e in losses_str.split(self.loss_sample_str_splitter)] # 解析loss es + losses = [ + float(e) for e in losses_str.split(self.loss_sample_str_splitter) + ] # 解析loss es sample_infos = sample_infos_str.split(self.loss_sample_str_splitter) if len(losses) != len(sample_infos): logger.warn("batch loss length != batch sample length") @@ -130,17 +151,19 @@ def fetch(self, each_sample_info): # 这里是比较定制化的,不同应用场景可以构建不同的子类 # 20-17-1385697-14158189-936633 # {scatter_id}-{dsid}-{idx}-{raw_id}-{sample_id} - scatter_id, dsid, _, _, sample_id = each_sample_info.split('-') # 这个结构是在初始化定义好的 + scatter_id, dsid, _, _, sample_id = each_sample_info.split("-") # 这个结构是在初始化定义好的 ds_info = self.sample_data_paths[int(dsid)] # 用列表的索引, 0是train data,1是对应的样本PATH - datapath = f'{ds_info[1]}.scatter/{scatter_id}.lazy/text' + datapath = f"{ds_info[1]}.scatter/{scatter_id}.lazy/text" if not os.path.exists(datapath): logger.warn("sample data path not exist") return None, None flen = self.get_data_file_len(datapath, np.dtype(np.int32)) sample_cnt = flen // self.each_sample_len - f = np.memmap(datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len)) # 磁盘写内存 + f = np.memmap( + datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len) + ) # 磁盘写内存 data = f[int(sample_id)] return ds_info[0], data From 5419ccfb326ade5affadbe96d90f747f4248e772 Mon Sep 17 00:00:00 2001 From: hktk Date: Fri, 29 Mar 2024 10:39:00 +0800 Subject: [PATCH 4/7] Format fix 2. --- atorch/atorch/utils/loss_spike_utils.py | 28 +++++++------------------ 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py index 996534631..1e12d21b9 100644 --- a/atorch/atorch/utils/loss_spike_utils.py +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -71,16 +71,12 @@ def save_loss(self, file_name, cur_loss, cur_iter, losses_str, sample_infos_str) logger.info(f"save loss={cur_loss}, iter={cur_iter}") # 存储结构需要定义 cur_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - info = self.loss_info_splitter.join( - [cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str] - ) + info = self.loss_info_splitter.join([cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str]) info = info + "\n" with open(file_path, "a+") as w: w.write(info) - def decode_loss_spike( - self, result_file_path, tokenizer, min_iter=None, min_loss=None - ): + def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss=None): """ 根据文件中的尖刺loss等信息,解析出对应的样本到文件 Args: @@ -111,26 +107,18 @@ def decode_loss_spike( loss_str = fcontent[3] sample_infos_str = fcontent[4] if cur_iter < min_iter or cur_loss < min_loss: - logger.info( - f"The content with iter={cur_iter} and loss={cur_loss} will not be parsed!" - ) + logger.info(f"The content with iter={cur_iter} and loss={cur_loss} will not be parsed!") continue # parse - logger.info( - f"Parse content with iter={cur_iter} and loss={cur_loss}!" 
- ) - ds, text, max_loss = self.parse_sample_content( - loss_str, sample_infos_str, tokenizer - ) + logger.info(f"Parse content with iter={cur_iter} and loss={cur_loss}!") + ds, text, max_loss = self.parse_sample_content(loss_str, sample_infos_str, tokenizer) if ds is None: continue fw.write(f"=========={ds} {max_loss}================\n") fw.write(f"{text}\n\n\n\n") def parse_sample_content(self, losses_str, sample_infos_str, tokenizer): - losses = [ - float(e) for e in losses_str.split(self.loss_sample_str_splitter) - ] # 解析loss es + losses = [float(e) for e in losses_str.split(self.loss_sample_str_splitter)] # 解析loss es sample_infos = sample_infos_str.split(self.loss_sample_str_splitter) if len(losses) != len(sample_infos): logger.warn("batch loss length != batch sample length") @@ -162,8 +150,6 @@ def fetch(self, each_sample_info): return None, None flen = self.get_data_file_len(datapath, np.dtype(np.int32)) sample_cnt = flen // self.each_sample_len - f = np.memmap( - datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len) - ) # 磁盘写内存 + f = np.memmap(datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len)) # 磁盘写内存 data = f[int(sample_id)] return ds_info[0], data From 5e395bae5dcf9566099fa3bfff47cc088ccbbc34 Mon Sep 17 00:00:00 2001 From: hktk Date: Wed, 3 Apr 2024 10:22:53 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- atorch/atorch/utils/loss_spike_utils.py | 46 +++++++++++++------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py index 1e12d21b9..4ba438741 100644 --- a/atorch/atorch/utils/loss_spike_utils.py +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -25,13 +25,14 @@ def __init__( init params Args: loss_spike_save_dir: str, The directory where loss spike files are stored - sample_data_paths: 样本存储的路径信息, 可以是用户自定义的结构,如名称+路径的tuple list: + sample_data_paths: The path information stored in the sample can be a user-defined structure, + such as name + path, tuple list: [("wikipedia", "corpus/base"), ("zhihu", "/dataset/fd5061f6/data/tokenize_data/zhihu.lazy")] - each_sample_len: int, 单个样本的长度 - min_iter: int, 记录loss所需的最小迭代轮数 - min_loss: float, 记录loss所需的最小loss阈值 - loss_info_splitter: str, 存储loss尖刺信息时所用的分隔符, 默认值='\t' - loss_sample_str_splitter: str, 一批loss或者sample信息组成的str所用的分隔符, 默认值=',' + each_sample_len: int, The length of a single sample + min_iter: int, Record the minimum iterations required for loss + min_loss: float, Record the minimum loss threshold required for loss + loss_info_splitter: str, Delimiter used to store loss spike information, default ='\t' + loss_sample_str_splitter: str, Delimiter used by str for a batch of loss or sample information, default =',' """ self.loss_spike_save_dir = loss_spike_save_dir self.sample_data_paths = sample_data_paths @@ -55,21 +56,21 @@ def get_data_file_len(fpath, dtype): class TokenLossSpike(LossSpikeBase): def save_loss(self, file_name, cur_loss, cur_iter, losses_str, sample_infos_str): """ - 存储尖刺loss及对应的一些信息 + Store spike loss and corresponding information Args: file_name: str, loss spike file name cur_loss: float, current loss value cur_iter: int, current iteration - losses_str: A string of loss concatenated with splitter, 如: + losses_str: A string of loss concatenated with splitter, eg: "2.31,2.30,10.98,1.56" - sample_infos_str: A string of sample id info concatenated with splitter, 如: + 
sample_infos_str: A string of sample id info concatenated with splitter, eg: "20-17-1385697-14158189-936633,20-17-1385697-14158189-936633" """ file_path = os.path.join(self.loss_spike_save_dir, file_name) if cur_loss > self.min_loss and cur_iter > self.min_iter: logger.info(f"save loss={cur_loss}, iter={cur_iter}") - # 存储结构需要定义 + # define structure cur_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") info = self.loss_info_splitter.join([cur_time, str(cur_iter), str(cur_loss), losses_str, sample_infos_str]) info = info + "\n" @@ -78,12 +79,12 @@ def save_loss(self, file_name, cur_loss, cur_iter, losses_str, sample_infos_str) def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss=None): """ - 根据文件中的尖刺loss等信息,解析出对应的样本到文件 + According to the information such as spike loss in the file, the corresponding sample is parsed to the file Args: - result_file_path: str, 存储解码后的样本内容的文件地址 - tokenizer: instance, 样本文件中的存储的样本, 需要时的解码器 - min_iter: int, 解码loss、sample所需的最小迭代轮数, 不传的话,使用存储时的初始化的阈值 - min_loss: float, 解码loss、sample所需的最小loss阈值, 不传的话,使用存储时的初始化的阈值 + result_file_path: str, The address of the file that stores the contents of the decoded sample + tokenizer: instance, + min_iter: int, minimum iterations required for decode loss、sample + min_loss: float, minimum loss required for decode loss、sample Returns: @@ -94,13 +95,13 @@ def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss min_loss = self.min_loss with open(result_file_path, "w") as fw: - # 遍历整个目录 + # Traverse the entire directory for loss_spike_file in os.listdir(self.loss_spike_save_dir): file_path = os.path.join(self.loss_spike_save_dir, loss_spike_file) with open(file_path) as fr: for line in fr: # process file content by line - # 内容结构: f"{ctime}\t{iteration}\t{loss}\t{loss_str}\t{sample_ids_str}\n" + # structure: f"{ctime}\t{iteration}\t{loss}\t{loss_str}\t{sample_ids_str}\n" fcontent = line.strip().split(self.loss_info_splitter) cur_iter = int(fcontent[1]) cur_loss = float(fcontent[2]) @@ -124,7 +125,7 @@ def parse_sample_content(self, losses_str, sample_infos_str, tokenizer): logger.warn("batch loss length != batch sample length") return None, None, None - losses = np.array(losses) # 解析loss es + losses = np.array(losses) idx = losses.argmax(-1) max_loss = losses[idx] sample_with_max_loss = sample_infos[idx] @@ -136,12 +137,13 @@ def parse_sample_content(self, losses_str, sample_infos_str, tokenizer): return ds, data, max_loss def fetch(self, each_sample_info): - # 这里是比较定制化的,不同应用场景可以构建不同的子类 + # Here is more customized, different application scenarios can build different subclasses # 20-17-1385697-14158189-936633 # {scatter_id}-{dsid}-{idx}-{raw_id}-{sample_id} - scatter_id, dsid, _, _, sample_id = each_sample_info.split("-") # 这个结构是在初始化定义好的 + scatter_id, dsid, _, _, sample_id = each_sample_info.split("-") # structure is defined during initialization - ds_info = self.sample_data_paths[int(dsid)] # 用列表的索引, 0是train data,1是对应的样本PATH + # Using the index of the list, 0 is the train data and 1 is the corresponding sample PATH + ds_info = self.sample_data_paths[int(dsid)] datapath = f"{ds_info[1]}.scatter/{scatter_id}.lazy/text" @@ -150,6 +152,6 @@ def fetch(self, each_sample_info): return None, None flen = self.get_data_file_len(datapath, np.dtype(np.int32)) sample_cnt = flen // self.each_sample_len - f = np.memmap(datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len)) # 磁盘写内存 + f = np.memmap(datapath, dtype=np.int32, shape=(sample_cnt, self.each_sample_len)) # Disk to 
memory data = f[int(sample_id)] return ds_info[0], data From e66915fc420e799dddb8f4923006423a5104db8f Mon Sep 17 00:00:00 2001 From: hktk Date: Fri, 12 Apr 2024 16:15:33 +0800 Subject: [PATCH 6/7] Add readme.md. --- atorch/atorch/utils/loss_spike_utils.py | 2 +- atorch/docs/README-LOSS-SPIKE-UTIL.md | Bin 0 -> 3968 bytes 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 atorch/docs/README-LOSS-SPIKE-UTIL.md diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py index 4ba438741..f67136d57 100644 --- a/atorch/atorch/utils/loss_spike_utils.py +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -1,6 +1,6 @@ """ loss spike utils Save loss spike to files; -Decode loss spike and save Corresponding sample to a file. +Decode loss spike and save Corresponding sample to a file. Using doc see ../docs/README-LOSS-SPIKE-UTIL.md """ import datetime import os diff --git a/atorch/docs/README-LOSS-SPIKE-UTIL.md b/atorch/docs/README-LOSS-SPIKE-UTIL.md new file mode 100644 index 0000000000000000000000000000000000000000..d0d403747c707c6561ee89984406867117a0a79d GIT binary patch literal 3968 zcmeHKT~iuo6z=u@iXFU&GKk_FJZ7~m{e1bNl{e3yv)rChX|UW0o|hN+YW<>B5XSwYS*x@5 zDyy%td^48JWG-4ov%UzXNgKp7(hqVNulU*-KUozfKxzLK-+Fen_>3G$&X7VllTz-< zwwf>mA?w;JV|(7%Z9-teojZ3@_tj`JAV_Lmoi&LtUR7Oqr$oOsuBmG76yB!~%RiF$BDK3XRPHnKB1AmdZ zv;`42=m7~lKq)Jepqt2*au}`*?N^~o-?)Eb!Z+y;Onop2yLMOwM>#MW91Tv~ zn!0)W^U>gBaOzg@*6q=8-}uL!dj?B1fb!wAoSBQMVXI<@nSZF5lemM6bm*&OLbm(m zTweL##ly!aSYSbg5ga>f7KsL6l)F!!_RJz-1*a?gX`ue=Kw*uMk$Z~i5DyS6Upz5i z&$FX-ws3$FK(c)8iLSjfDm5TukT7;?dS{cbAM1^*-gx8nveQ2dt&A$w0}v-cuNN77 zfUH2jk4uo$J1_OdF4&^B1G{%$=78q)=E?w;c_G`aX2H!xYXe-DR`qrVL|s}SMz@W2 zezPr>k9BQNZ*3#WATq!%6ks4YTmk0Jf%s>(%C}E>^{`u6`ZK#Jvhq`r9iOr03-Eh+ zOGPqIkJ(3h^`10!#|Avd7OcK2DmjMcT$X3tg(?Z`aV zSb2?Ps0sX+cbJ;B;^-V#o-=iDg&(nPb^&^-O1Zt?*v%biYSHs^<&Q(4O5Ln6lDU1_NeDSQCD}gSI$F+?6Y5VqVYuvsLpStf6p9uKmhdXBN$j}zAB~-m# zGTOC^R)q-DX3;7D#C|EeuUwdVmVL=;8~RxRI)r0j^$)tv$`wk?obtboJQKUG&M zLPVrR^yUmg-ody%(@-WW9}tIHK&6b8I(#;?eO~CmPYk_r1lVVlCRB}Yw^*wTx9HF! zR-gpFRpr$J%e~VzU|1A@Tb}<(*ACfIiIxi8_L?O}mgJh5&GDyg^X(4p7_S$9wOl_s z1Dovs^(w;`L;o`p>OiXqtLjjWOXF}tpd_Fp9Nh;cXu4dDI~l%1Kl%0 Date: Tue, 16 Apr 2024 16:20:45 +0800 Subject: [PATCH 7/7] Optimize code structure. 
--- atorch/atorch/tests/test_loss_spike_utils.py | 49 ++++++++++++++++++ .../tests/utils/test_loss_spike_utils.py | 38 -------------- atorch/atorch/utils/loss_spike_utils.py | 13 +++-- atorch/docs/README-LOSS-SPIKE-UTIL.md | Bin 3968 -> 3204 bytes 4 files changed, 55 insertions(+), 45 deletions(-) create mode 100644 atorch/atorch/tests/test_loss_spike_utils.py delete mode 100644 atorch/atorch/tests/utils/test_loss_spike_utils.py diff --git a/atorch/atorch/tests/test_loss_spike_utils.py b/atorch/atorch/tests/test_loss_spike_utils.py new file mode 100644 index 000000000..a955f9b82 --- /dev/null +++ b/atorch/atorch/tests/test_loss_spike_utils.py @@ -0,0 +1,49 @@ +import os +import unittest + +from atorch.utils.loss_spike_utils import TokenLossSpike + + +class LossSpikeTest(unittest.TestCase): + def init(self): + self.min_iter = 100 + self.min_loss = 4.0 + + sample_data_paths = [("wikipedia", "corpus/base"), ("wikipedia", "corpus/base"), ("wikipedia", "corpus/base")] + each_sample_len = 2 + + if not os.path.exists("utils/loss_spike"): + os.mkdir("utils/loss_spike") + + self.loss_ins = TokenLossSpike( + "utils/loss_spike", + sample_data_paths, + each_sample_len, + self.min_iter, + self.min_loss, + ) + + def setUp(self): + self.init() + + def test_save(self): + self.loss_ins.save_loss( + "test_loss.txt", + 4.05, + 103, + losses_str="2.44,2.33,4.05", + sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2", + ) + + def test_decode(self): + self.loss_ins.decode_loss_spike("res.txt", None) + + def test_parse(self): + self.loss_ins.parse_sample_content( + losses_str="2.44,2.33,4.05", + sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2", + tokenizer=None, + ) + + def test_fetch(self): + self.loss_ins.fetch("20-1-1385697-14158189-2") diff --git a/atorch/atorch/tests/utils/test_loss_spike_utils.py b/atorch/atorch/tests/utils/test_loss_spike_utils.py deleted file mode 100644 index 408611da1..000000000 --- a/atorch/atorch/tests/utils/test_loss_spike_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import unittest - -from atorch.utils.loss_spike_utils import TokenLossSpike - - -class LossSpikeTest(unittest.TestCase): - def init(self): - self.min_iter = 100 - self.min_loss = 4.0 - - sample_data_paths = [] - each_sample_len = 2 - - if not os.path.exists("loss_spike"): - os.mkdir("loss_spike") - - self.loss_ins = TokenLossSpike( - "loss_spike", - sample_data_paths, - each_sample_len, - self.min_iter, - self.min_loss, - ) - - def test_save(self): - self.init() - self.loss_ins.save_loss( - "test_loss.txt", - 4.05, - 103, - "2.44,2.33,4.05", - "20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2", - ) - - def test_decode(self): - self.init() - self.loss_ins.decode_loss_spike("res.txt", None) diff --git a/atorch/atorch/utils/loss_spike_utils.py b/atorch/atorch/utils/loss_spike_utils.py index f67136d57..fedf46460 100644 --- a/atorch/atorch/utils/loss_spike_utils.py +++ b/atorch/atorch/utils/loss_spike_utils.py @@ -54,20 +54,19 @@ def get_data_file_len(fpath, dtype): class TokenLossSpike(LossSpikeBase): - def save_loss(self, file_name, cur_loss, cur_iter, losses_str, sample_infos_str): + def save_loss(self, file_name, cur_loss, cur_iter, *args, **kargs): """ Store spike loss and corresponding information Args: file_name: str, loss spike file name - cur_loss: float, current loss value + cur_loss: float, current avg loss value cur_iter: int, current iteration - losses_str: A string of loss concatenated with splitter, 
eg: - "2.31,2.30,10.98,1.56" - sample_infos_str: A string of sample id info concatenated with splitter, eg: - "20-17-1385697-14158189-936633,20-17-1385697-14158189-936633" + args/kargs: any custom data in string format. """ file_path = os.path.join(self.loss_spike_save_dir, file_name) + losses_str = kargs["losses_str"] + sample_infos_str = kargs["sample_infos_str"] if cur_loss > self.min_loss and cur_iter > self.min_iter: logger.info(f"save loss={cur_loss}, iter={cur_iter}") # define structure @@ -119,7 +118,7 @@ def decode_loss_spike(self, result_file_path, tokenizer, min_iter=None, min_loss fw.write(f"{text}\n\n\n\n") def parse_sample_content(self, losses_str, sample_infos_str, tokenizer): - losses = [float(e) for e in losses_str.split(self.loss_sample_str_splitter)] # 解析loss es + losses = [float(e) for e in losses_str.split(self.loss_sample_str_splitter)] sample_infos = sample_infos_str.split(self.loss_sample_str_splitter) if len(losses) != len(sample_infos): logger.warn("batch loss length != batch sample length") diff --git a/atorch/docs/README-LOSS-SPIKE-UTIL.md b/atorch/docs/README-LOSS-SPIKE-UTIL.md index d0d403747c707c6561ee89984406867117a0a79d..a2b6c967bd29399b5b70ac52318298c2148a33a5 100644 GIT binary patch literal 3204 zcmeHJ&2HQ_5Wedv2J*q{1+D$@uIm`sTYCr)6h)Js>O!U^u0>d)1d{S!5FiiIzFeN9 zGeb#HS5AubRv$_t>YrnVm5)pyFg!OgeiTNFhjLA(s? z4kTs~;4>29JLKB#;ujWU0W9y>%`}%@rj7K4+x} z7C#rdXlZ9bT2A@JKHo0SSNn5bSa@<=KPxu6VCkfk&rk8mAKjr!W<}baTq?8MbnBHi zsrJh5*30E`v_}5#R!i~uhT80F3|gr@i63A2<4!zWHC)PhJCYH{(aB*DlOfNuy)IR1 zWTkd-9~+A2u+0+ry-iad}ftvOP_MY2s8f@vHs#cF6{2- z+NdsH{pO4B0T17zz+fST6DqritMM0UF|!C2qyhYGtn$)pEXPr02INq!rQgfzKb98< z92s{M%btY<_(B2Fk;;($p1|R+Hn)Lu2}({NXCn!_)+!*vb|4MR=ju^APm+V~e9qK8 zp)LUj#3G<|q^S)|Pz*%YmuTO%1_l}k>`&-HrKa_C6hafPG&${(FI33}_OtOJrETVrp@^KPBYauez+0%L_-Nn!5X`@HcQ0nXrd zaBGz$(E9=-i4quA{E`@7%mTJOmG^%2)&#skn^Q$P=|FMQrTHIJEMOnM&%CbGhb#Y{ zj$>EogG}^qB0T!v1$E%D4_79Rwmp8PEG{sQdx_)&_p)p;M-D!H{8N1T_?N&;k{Mg^ zfwa>|0U;OgN6j8eG+UnBB-R$zE2-(Tc}@7>~{yI{BID}HF7q0Slj zp)T$T$CP|7+P-}Tu}UE@!yZmpebW^5&nkl)&;uxmI=F>q8MJRWX$;M*HU^0W)6L^u zKN%zhrbCd8mAm>PpKVFHkA_FJZ7~m{e1bNl{e3yv)rChX|UW0o|hN+YW<>B5XSwYS*x@5 zDyy%td^48JWG-4ov%UzXNgKp7(hqVNulU*-KUozfKxzLK-+Fen_>3G$&X7VllTz-< zwwf>mA?w;JV|(7%Z9-teojZ3@_tj`JAV_Lmoi&LtUR7Oqr$oOsuBmG76yB!~%RiF$BDK3XRPHnKB1AmdZ zv;`42=m7~lKq)Jepqt2*au}`*?N^~o-?)Eb!Z+y;Onop2yLMOwM>#MW91Tv~ zn!0)W^U>gBaOzg@*6q=8-}uL!dj?B1fb!wAoSBQMVXI<@nSZF5lemM6bm*&OLbm(m zTweL##ly!aSYSbg5ga>f7KsL6l)F!!_RJz-1*a?gX`ue=Kw*uMk$Z~i5DyS6Upz5i z&$FX-ws3$FK(c)8iLSjfDm5TukT7;?dS{cbAM1^*-gx8nveQ2dt&A$w0}v-cuNN77 zfUH2jk4uo$J1_OdF4&^B1G{%$=78q)=E?w;c_G`aX2H!xYXe-DR`qrVL|s}SMz@W2 zezPr>k9BQNZ*3#WATq!%6ks4YTmk0Jf%s>(%C}E>^{`u6`ZK#Jvhq`r9iOr03-Eh+ zOGPqIkJ(3h^`10!#|Avd7OcK2DmjMcT$X3tg(?Z`aV zSb2?Ps0sX+cbJ;B;^-V#o-=iDg&(nPb^&^-O1Zt?*v%biYSHs^<&Q(4O5Ln6lDU1_NeDSQCD}gSI$F+?6Y5VqVYuvsLpStf6p9uKmhdXBN$j}zAB~-m# zGTOC^R)q-DX3;7D#C|EeuUwdVmVL=;8~RxRI)r0j^$)tv$`wk?obtboJQKUG&M zLPVrR^yUmg-ody%(@-WW9}tIHK&6b8I(#;?eO~CmPYk_r1lVVlCRB}Yw^*wTx9HF! zR-gpFRpr$J%e~VzU|1A@Tb}<(*ACfIiIxi8_L?O}mgJh5&GDyg^X(4p7_S$9wOl_s z1Dovs^(w;`L;o`p>OiXqtLjjWOXF}tpd_Fp9Nh;cXu4dDI~l%1Kl%0
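
A minimal usage sketch of the TokenLossSpike utility added by this series, assuming the post-patch-7/7 API (keyword-only losses_str/sample_infos_str in save_loss). Directory names, the per-sample length, and the sample-id strings below are illustrative placeholders; the "{scatter_id}-{dsid}-{idx}-{raw_id}-{sample_id}" id format and the "{path}.scatter/{scatter_id}.lazy/text" int32 memmap layout are the conventions TokenLossSpike.fetch() expects, and the tokenizer argument may be None or any object with a decode() method.

# Minimal usage sketch for atorch.utils.loss_spike_utils.TokenLossSpike.
# Paths, lengths and sample ids are placeholders for illustration only.
import os

from atorch.utils.loss_spike_utils import TokenLossSpike

save_dir = "loss_spike"               # directory holding loss-spike records
os.makedirs(save_dir, exist_ok=True)  # must exist before constructing

loss_ins = TokenLossSpike(
    loss_spike_save_dir=save_dir,
    # (name, path) pairs; fetch() uses index 1 as the sample data path
    sample_data_paths=[("wikipedia", "corpus/base"), ("zhihu", "corpus/zhihu")],
    each_sample_len=2048,             # tokens per sample in the memmap files
    min_iter=100,                     # only record after this iteration
    min_loss=4.0,                     # only record when the avg loss exceeds this
)

# Inside the training loop: record a spike. After patch 7/7, the per-sample
# losses and sample ids must be passed as keyword arguments.
loss_ins.save_loss(
    "loss_spike_rank0.txt",
    4.05,                             # current average loss
    103,                              # current iteration
    losses_str="2.44,2.33,4.05",      # per-sample losses, "," separated
    sample_infos_str="20-1-1385697-14158189-0,20-1-1385697-14158189-1,20-1-1385697-14158189-2",
)

# Offline: walk every file in save_dir and write the highest-loss sample of
# each recorded batch to a result file. Pass a tokenizer with a decode()
# method to get text instead of raw token ids, or None to skip decoding.
loss_ins.decode_loss_spike("decoded_loss_spike.txt", tokenizer=None)

The decode step only reconstructs samples whose data files exist at the expected memmap path, mirroring the behaviour of fetch() in the diff above; entries below min_iter/min_loss, or with mismatched loss and sample-id counts, are skipped with a log message.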