forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhub.py
174 lines (136 loc) · 5.39 KB
/
hub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import importlib
import os
import shutil
import sys
import tempfile
import zipfile
if sys.version_info[0] == 2:
from urlparse import urlparse
from urllib2 import urlopen # noqa f811
else:
from urllib.request import urlopen
from urllib.parse import urlparse
import torch
import torch.utils.model_zoo as model_zoo
# Branch used by load() when the `github` argument does not name a tag/branch.
MASTER_BRANCH = 'master'
# Name of the environment variable load() consults for the hub directory.
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
# Hub directory used by load() when the environment variable is not set.
DEFAULT_TORCH_HUB_DIR = '~/.torch/hub'
# Number of bytes requested per read() call while downloading a file.
READ_DATA_CHUNK = 8192
# Directory where downloaded repos are stored; stays None until set_dir() is
# called or load() resolves it from the environment/default.
hub_dir = None
def _check_module_exists(name):
if sys.version_info >= (3, 4):
import importlib.util
return importlib.util.find_spec(name) is not None
elif sys.version_info >= (3, 3):
# Special case for python3.3
import importlib.find_loader
return importlib.find_loader(name) is not None
else:
# NB: imp doesn't handle hierarchical module names (names contains dots).
try:
import imp
imp.find_module(name)
except Exception:
return False
return True
def _remove_if_exists(path):
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
shutil.rmtree(path)
def _git_archive_link(repo, branch):
return 'https://github.com/' + repo + '/archive/' + branch + '.zip'
def _download_url_to_file(url, filename):
    r"""
    Download the content at `url` and write it to `filename`.

    The body is streamed in chunks of READ_DATA_CHUNK bytes, so the whole
    payload is never held in memory at once. A progress line is written to
    stderr before the download starts.

    Args:
        url: address of the resource to fetch.
        filename: path of the local file to (over)write with the payload.
    """
    # Terminate the message with a newline so later stderr output starts on
    # its own line.
    sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
    response = urlopen(url)
    # try/finally rather than `with`: the Python 2 urllib2 response object is
    # not a context manager, and the connection must be closed on any outcome.
    try:
        with open(filename, 'wb') as f:
            while True:
                data = response.read(READ_DATA_CHUNK)
                if len(data) == 0:
                    break
                f.write(data)
    finally:
        response.close()
def _load_attr_from_module(module_name, func_name):
m = importlib.import_module(module_name)
# Check if callable is defined in the module
if func_name not in dir(m):
return None
return getattr(m, func_name)
def set_dir(d):
    r"""
    Choose the local directory in which downloaded models & weights are stored.

    Calling this is optional. The value is stored in the module-level `hub_dir`.
    While `hub_dir` is unset, `load` reads the `TORCH_HUB_DIR` env variable and
    falls back to `~/.torch/hub`, creating the directory if it is missing.

    Args:
        d: path to the local folder that should hold downloaded models & weights.
    """
    global hub_dir
    hub_dir = d
def load(github, model, force_reload=False, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    The repo archive is downloaded into the hub directory (see `set_dir`),
    unpacked into a folder named '{repo_owner}_{repo_name}_{branch}' and put at
    the front of `sys.path` so that its `hubconf.py` can be imported. The
    entrypoint `model` found in `hubconf` is then called with `*args` and
    `**kwargs`.

    NOTE: this downloads and executes code from the given repo; only use it
    with repos you trust.

    Args:
        github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model: Required, a string of entrypoint name defined in repo's hubconf.py
        force_reload: Optional, whether to discard the existing cache and force a fresh download.
            Default is `False`.
        *args: Optional, the corresponding args for callable `model`.
        **kwargs: Optional, the corresponding kwargs for callable `model`.

    Returns:
        a single model with corresponding pretrained weights.

    Raises:
        ValueError: if `model` is not a string or `github` is malformed.
        RuntimeError: if a package listed in hubconf's `dependencies` is missing,
            or if `model` is not defined in hubconf or is not callable.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # Setup hub_dir to save downloaded files
    global hub_dir
    if hub_dir is None:
        hub_dir = os.getenv(ENV_TORCH_HUB_DIR, DEFAULT_TORCH_HUB_DIR)

    if '~' in hub_dir:
        hub_dir = os.path.expanduser(hub_dir)

    if not os.path.exists(hub_dir):
        os.makedirs(hub_dir)

    # Parse github repo information, failing with a readable message instead
    # of an opaque tuple-unpacking error when the format is wrong.
    repo_info, colon, branch = github.partition(':')
    if not colon:
        branch = MASTER_BRANCH
    owner_and_name = repo_info.split('/')
    if not branch or ':' in branch or len(owner_and_name) != 2 or not all(owner_and_name):
        raise ValueError(
            'Invalid input: github should have format '
            '"repo_owner/repo_name[:tag_name]", got {!r}'.format(github))
    repo_owner, repo_name = owner_and_name

    # Download zipped code from github
    url = _git_archive_link(repo_info, branch)
    # A branch/tag may contain '/', which must not create nested cache paths.
    safe_branch = branch.replace('/', '_')
    cached_file = os.path.join(hub_dir, safe_branch + '.zip')
    # Include the owner so forks sharing a repo name do not reuse each other's cache.
    repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, safe_branch]))

    use_cache = (not force_reload) and os.path.exists(repo_dir)

    # Github uses '{repo_name}-{branch_name}' as folder name which is not importable
    # We need to manually rename it to a stable, importable location
    # Unzip the code and rename the base folder
    if use_cache:
        sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
    else:
        _remove_if_exists(cached_file)
        _download_url_to_file(url, cached_file)

        # Close the archive before deleting it; an open handle leaks and
        # prevents removal of the file on Windows.
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            # Github renames folder repo-v1.x.x to repo-1.x.x
            extracted_repo_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
            _remove_if_exists(extracted_repo)
            cached_zipfile.extractall(hub_dir)

        _remove_if_exists(cached_file)
        _remove_if_exists(repo_dir)
        shutil.move(extracted_repo, repo_dir)  # rename the repo

    # Make Python interpreter aware of the repo. Drop earlier entries for the
    # same directory so repeated calls do not grow sys.path.
    sys.path[:] = [entry for entry in sys.path if entry != repo_dir]
    sys.path.insert(0, repo_dir)

    # A 'hubconf' imported by a previous call (another repo, or this repo before
    # a forced reload) would otherwise be served from the module cache.
    sys.modules.pop('hubconf', None)
    if hasattr(importlib, 'invalidate_caches'):
        # The repo files may have just been created; refresh the finders' caches.
        importlib.invalidate_caches()

    dependencies = _load_attr_from_module('hubconf', 'dependencies')

    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if missing_deps:
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))

    func = _load_attr_from_module('hubconf', model)
    if func is None:
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

    # Check if func is callable
    if not callable(func):
        raise RuntimeError('{} is not callable'.format(func))

    # Call the function
    return func(*args, **kwargs)