diff --git a/edb/load_ext/main.py b/edb/load_ext/main.py index cc849833f55..c3ab4c2aff4 100644 --- a/edb/load_ext/main.py +++ b/edb/load_ext/main.py @@ -34,6 +34,7 @@ import shutil import subprocess import sys +import tempfile import tomllib import zipfile @@ -104,6 +105,37 @@ def get_pg_config(pg_config_path: pathlib.Path) -> dict[str, str]: return config +def install_edgedb_extension( + pkg: pathlib.Path, + ext_dir: pathlib.Path, +) -> None: + target = ext_dir / pkg.stem + print("Installing", target) + + with tempfile.TemporaryDirectory() as tdir, \ + zipfile.ZipFile(pkg) as z: + + ttarget = pathlib.Path(tdir) / pkg.stem + os.mkdir(ttarget) + + with z.open('MANIFEST.toml') as m: + manifest = tomllib.load(m) + + files = ['MANIFEST.toml'] + manifest['files'] + + for f in files: + target_file = target / f + ttarget_file = ttarget / f + + with z.open(f) as src: + with open(ttarget_file, "wb") as dst: + print("Installing", target_file) + shutil.copyfileobj(src, dst) + + os.makedirs(ext_dir, exist_ok=True) + shutil.move(ttarget, ext_dir) + + def load_ext_main( package: pathlib.Path, skip_edgedb: bool, @@ -114,9 +146,7 @@ def load_ext_main( from edb import buildmeta ext_dir = buildmeta.get_extension_dir_path() - os.makedirs(ext_dir, exist_ok=True) - print("Installing", ext_dir / package.name) - shutil.copyfile(package, ext_dir / package.name) + install_edgedb_extension(package, ext_dir) if not skip_postgres: if with_pg_config is None: diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 5b0f84c41da..0ddfd019848 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -46,7 +46,6 @@ import tomllib import uuid import weakref -import zipfile import immutables @@ -440,8 +439,11 @@ async def load_extension_packages(self, path: pathlib.Path) -> None: try: with os.scandir(path) as it: for entry in it: - if entry.is_file() and entry.name.endswith('.zip'): - exts.append(entry) + if ( + entry.is_dir() + and (pathlib.Path(entry) / 'MANIFEST.toml').exists() + ): + exts.append(pathlib.Path(entry)) except FileNotFoundError: pass @@ -471,27 +473,26 @@ async def load_extension_packages(self, path: pathlib.Path) -> None: async def _load_extension_package( self, - path: os.PathLike, + path: pathlib.Path, ext_packages: set[tuple[str, verutils.Version]], ) -> None: - with zipfile.ZipFile(path) as z: - with z.open('MANIFEST.toml') as m: - manifest = tomllib.load(m) - - name = manifest['name'] - version = verutils.parse_version(manifest['version']) - if (name, version) in ext_packages: - logger.info( - f"Extension package '{manifest['name']}' {version} " - f"already installed" - ) + with open(path / 'MANIFEST.toml', 'rb') as m: + manifest = tomllib.load(m) - return + name = manifest['name'] + version = verutils.parse_version(manifest['version']) + if (name, version) in ext_packages: + logger.info( + f"Extension package '{manifest['name']}' {version} " + f"already installed" + ) + + return - scripts = [] - for file in manifest['files']: - with z.open(file) as f: - scripts.append(f.read().decode('utf-8')) + scripts = [] + for file in manifest['files']: + with open(path / file, 'rb') as f: + scripts.append(f.read().decode('utf-8')) from edb.schema import schema as s_schema