-
Notifications
You must be signed in to change notification settings - Fork 304
/
Copy pathinstall.py
27 lines (24 loc) · 1.17 KB
/
install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
import patch
import subprocess
import sys
from utils.python_utils import pip_install_requirements
def patch_dalle2():
    """Apply the bundled ``dalle2_pytorch.patch`` to the installed dalle2_pytorch package.

    The patch file is read from the directory containing this script and is
    applied with one leading path component stripped (``strip=1``), rooted at
    the directory of the installed ``dalle2_pytorch`` package.

    Exits the process with status 1 if the patch file cannot be parsed or the
    patch cannot be applied.
    """
    # Imported here rather than at module top. Presumably this is because the
    # package is installed by pip_install_requirements_dalle2(), which the
    # __main__ block runs before this function.
    import dalle2_pytorch

    current_dir = os.path.dirname(os.path.abspath(__file__))
    dalle2_dir = os.path.dirname(dalle2_pytorch.__file__)
    patch_path = os.path.join(current_dir, "dalle2_pytorch.patch")

    dalle2_patch = patch.fromfile(patch_path)
    # patch.fromfile() returns False (not a PatchSet) when the file cannot be
    # read or parsed; calling .apply() on that would raise AttributeError.
    if not dalle2_patch:
        print(f"Failed to parse patch file {patch_path}. Exit.", file=sys.stderr)
        sys.exit(1)

    # apply() reports failure through a falsy return value, not an exception.
    if not dalle2_patch.apply(strip=1, root=dalle2_dir):
        print("Failed to patch dalle2_pytorch/dalle2_pytorch.py. Exit.", file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is absent under `python -S`.
        sys.exit(1)
def pip_install_requirements_dalle2():
    """Install this model's requirements, then upgrade pandas and pyarrow.

    Runs the shared ``pip_install_requirements()`` helper first, then invokes
    pip in the current interpreter to upgrade pandas and pyarrow. Raises
    ``subprocess.CalledProcessError`` if the pip upgrade command fails.
    """
    pip_install_requirements()

    # Why the second step: DALLE2_pytorch depends on embedding-reader
    # (https://github.com/lucidrains/DALLE2-pytorch/blob/00e07b7d61e21447d55e6d06d5c928cf8b67601d/setup.py#L34),
    # and embedding-reader pins old pandas/pyarrow versions
    # (https://github.com/rom1504/embedding-reader/blob/a4fd55830a502685600ed8ef07947cd1cb92b083/requirements.txt#L5).
    # Upgrading them afterwards keeps this environment compatible with the other models.
    upgrade_targets = ['pandas', 'pyarrow']
    pip_cmd = [sys.executable, '-m', 'pip', 'install', '-U']
    subprocess.check_call(pip_cmd + upgrade_targets)
if __name__ == '__main__':
    # Order matters: dependencies are installed first, so dalle2_pytorch is
    # importable by the time patch_dalle2() runs.
    for install_step in (pip_install_requirements_dalle2, patch_dalle2):
        install_step()