
Commit 3c4a302

first commit

0 parents  commit 3c4a302

File tree

8 files changed: +1958 -0 lines changed

bert4pytorch/__init__.py

Lines changed: 5 additions & 0 deletions

#! -*- coding: utf-8 -*-

# file_utils is a sibling module of this package, so import it relatively
from . import file_utils

__version__ = '0.1.0'

bert4pytorch/ema.py

Lines changed: 58 additions & 0 deletions

class EMA():
    '''
    Exponential moving average (EMA) of the model weights.
    Note that this is different from adaptive learning-rate optimizers such as Adam,
    which keep exponential moving averages of the first- and second-order gradient
    moments; the two are completely different things.

    Example:
        # initialization
        ema = EMA(model, 0.999)

        # during training, update the EMA weights right after each parameter update
        def train():
            optimizer.step()
            ema.update()

        # before eval, apply the EMA weights; after eval, restore the original model parameters
        def evaluate():
            ema.apply_ema_weights()
            # evaluate
            # to save the EMA model, call torch.save() before restore()
            ema.restore()
    '''
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        # shadow weights (the per-layer moving-average weights at the current step)
        self.ema_weights = {}
        # stores the original model weights during evaluate(); after evaluation
        # the parameters are restored from here
        self.model_weights = {}

        # initialize ema_weights from the model weights
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.ema_weights[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.ema_weights
                new_average = (1.0 - self.decay) * param.data + self.decay * self.ema_weights[name]
                self.ema_weights[name] = new_average.clone()

    def apply_ema_weights(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.ema_weights
                self.model_weights[name] = param.data
                param.data = self.ema_weights[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.model_weights
                param.data = self.model_weights[name]
        self.model_weights = {}
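
For context, here is the call pattern from the docstring as a minimal runnable sketch; the toy model, optimizer, and training data are hypothetical stand-ins, and only the EMA class itself comes from this commit:

import torch
import torch.nn as nn

# hypothetical toy setup; any model/optimizer pair works the same way
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = EMA(model, decay=0.999)

for step in range(100):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()          # refresh the shadow weights after each parameter update

ema.apply_ema_weights()   # evaluate with the smoothed weights
# torch.save(model.state_dict(), 'ema_model.pt')  # save the EMA weights here if desired
ema.restore()             # switch back to the raw training weights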

bert4pytorch/file_utils.py

Lines changed: 265 additions & 0 deletions

from __future__ import (absolute_import, division, print_function, unicode_literals)

import sys
import json
import logging
import os
import shutil
import tempfile
import fnmatch
from functools import wraps
from hashlib import sha256
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                                   Path.home() / '.pytorch_pretrained_bert'))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        try:
            response = requests.head(url, allow_redirects=True)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except EnvironmentError:
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode('utf-8')
    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
        matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                output_string = json.dumps(meta)
                if sys.version_info[0] == 2 and isinstance(output_string, str):
                    output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
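
To illustrate the download-and-cache flow implemented above, here is a minimal usage sketch; the URL is a hypothetical placeholder rather than a real checkpoint location, and it assumes the bert4pytorch package is importable:

import os
from bert4pytorch.file_utils import cached_path, filename_to_url

# downloads on the first call, then reuses the cached copy on later calls
local_path = cached_path("https://example.com/models/pytorch_model.bin")

# reverse lookup via the sidecar .json metadata written by get_from_cache
url, etag = filename_to_url(os.path.basename(local_path))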
