Skip to content
This repository has been archived by the owner on Aug 26, 2020. It is now read-only.

Commit

Permalink
fix: extract module to correct location in download_and_install (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu authored Mar 19, 2020
1 parent a4519f8 commit 1d2c04e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
4 changes: 1 addition & 3 deletions src/sagemaker_containers/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,8 @@ def download_and_install(uri, name=DEFAULT_MODULE_NAME, cache=True):

if not should_use_cache:
with _files.tmpdir() as tmpdir:
dst = os.path.join(tmpdir, "tar_file")
_files.download_and_extract(uri, dst)
module_path = os.path.join(tmpdir, "module_dir")
os.makedirs(module_path)
_files.download_and_extract(uri, module_path)
prepare(module_path, name)
install(module_path)

Expand Down
38 changes: 26 additions & 12 deletions test/unit/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,31 @@ def test_import_module(reload, import_module, install, download_and_extract):
reload.assert_called_with(import_module(_modules.DEFAULT_MODULE_NAME))


def test_download_and_install_local_directory():
@patch("sagemaker_containers._modules.exists", return_value=False)
@patch("sagemaker_containers._files.tmpdir")
@patch("sagemaker_containers._files.download_and_extract")
@patch("sagemaker_containers._modules.prepare")
@patch("sagemaker_containers._modules.install")
def test_download_and_install(install, prepare, download_and_extract, files_tmpdir, module_exists):
files_tmpdir.return_value.__enter__.return_value = "tmp"
uri = "s3://foo/bar"
_modules.download_and_install(uri)

module_path = os.path.join("tmp", "module_dir")
download_and_extract.assert_called_with(uri, module_path)
prepare.assert_called_with(module_path, "default_user_module_name")
install.assert_called_with(module_path)


@patch("sagemaker_containers._files.s3_download")
@patch("tarfile.open")
@patch("sagemaker_containers._modules.prepare")
@patch("sagemaker_containers._modules.install")
def test_download_and_install_local_directory(install, prepare, tarfile, s3_download):
uri = "/opt/ml/input/data/code/sourcedir.tar.gz"
_modules.download_and_install(uri)

with patch("sagemaker_containers._files.s3_download") as s3_download, patch(
"tarfile.open"
) as tarfile, patch("sagemaker_containers._modules.prepare") as prepare, patch(
"sagemaker_containers._modules.install"
) as install:
_modules.download_and_install(uri)

s3_download.assert_not_called()
tarfile.assert_called_with(name="/opt/ml/input/data/code/sourcedir.tar.gz", mode="r:gz")
prepare.assert_called_once()
install.assert_called_once()
s3_download.assert_not_called()
tarfile.assert_called_with(name="/opt/ml/input/data/code/sourcedir.tar.gz", mode="r:gz")
prepare.assert_called_once()
install.assert_called_once()

0 comments on commit 1d2c04e

Please sign in to comment.