Skip to content
This repository has been archived by the owner on Aug 26, 2020. It is now read-only.

Commit

Permalink
fix: remove unnecessary name argument from download and extract function
Browse files Browse the repository at this point in the history
  • Loading branch information
mvsusp authored May 7, 2019
1 parent 6fc0585 commit 52cb4b6
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 15 deletions.
5 changes: 2 additions & 3 deletions src/sagemaker_containers/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,12 @@ def read_json(path): # type: (str) -> dict
return json.load(f)


def download_and_extract(uri, name, path): # type: (str, str, str) -> None
def download_and_extract(uri, path): # type: (str, str) -> None
"""Download, prepare and install a compressed tar file from S3 or local directory as an entry point.
SageMaker Python SDK saves the user provided entry points as compressed tar files in S3
Args:
name (str): name of the entry point.
uri (str): the location of the entry point.
path (bool): The path where the script will be installed. It will not download and install the
if the path already has the user entry point.
Expand All @@ -134,7 +133,7 @@ def download_and_extract(uri, name, path): # type: (str, str, str) -> None
shutil.rmtree(path)
shutil.move(uri, path)
else:
shutil.copy2(uri, os.path.join(path, name))
shutil.copy2(uri, path)


def s3_download(url, dst): # type: (str, str) -> None
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker_containers/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None): # type: (str, str
(module): the imported module
"""
_warning_cache_deprecation(cache)
_files.download_and_extract(uri, name, _env.code_dir)
_files.download_and_extract(uri, _env.code_dir)

prepare(_env.code_dir, name)
install(_env.code_dir)
Expand Down Expand Up @@ -271,7 +271,7 @@ def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, w
env_vars = env_vars or {}
env_vars = env_vars.copy()

_files.download_and_extract(uri, name, _env.code_dir)
_files.download_and_extract(uri, _env.code_dir)

prepare(_env.code_dir, name)
install(_env.code_dir)
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker_containers/_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ def run(self, wait=True, capture_error=False):
_logging.log_script_invocation(cmd, self._env_vars)

if wait:
process = check_error(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)
process = check_error(cmd, _errors.ExecuteUserScriptError,
capture_error=capture_error, cwd=_env.code_dir)
else:
process = create(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)
process = create(cmd, _errors.ExecuteUserScriptError,
capture_error=capture_error, cwd=_env.code_dir)

self._tear_down()

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker_containers/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run(uri,
env_vars = env_vars or {}
env_vars = env_vars.copy()

_files.download_and_extract(uri, user_entry_point, _env.code_dir)
_files.download_and_extract(uri, _env.code_dir)

install(user_entry_point, _env.code_dir, capture_error)

Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_run_module_wait(chmod, download_and_extract):
entry_point.run(uri='s3://url', user_entry_point='launcher.sh', args=['42'],
capture_error=True, runner=runner)

download_and_extract.assert_called_with('s3://url', 'launcher.sh', _env.code_dir)
download_and_extract.assert_called_with('s3://url', _env.code_dir)
runner.run.assert_called_with(True, True)
chmod.assert_called_with(os.path.join(_env.code_dir, 'launcher.sh'), 511)

Expand Down
6 changes: 3 additions & 3 deletions test/unit/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_write_failure_file():
@patch('shutil.move')
def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
uri = _env.channel_path('code')
_files.download_and_extract(uri, 'train.sh', _env.code_dir)
_files.download_and_extract(uri, _env.code_dir)
s3_download.assert_not_called()

rmtree.assert_any_call(_env.code_dir)
Expand All @@ -127,7 +127,7 @@ def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
@patch('shutil.copy2')
def test_download_and_and_extract_file(copy, s3_download):
uri = _env.channel_path('code')
_files.download_and_extract(uri, 'train.sh', _env.code_dir)
_files.download_and_extract(uri, _env.code_dir)

s3_download.assert_not_called()
copy.assert_called_with(uri, os.path.join(_env.code_dir, 'train.sh'))
copy.assert_called_with(uri, _env.code_dir)
5 changes: 2 additions & 3 deletions test/unit/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,8 @@ def test_run_no_wait(log_script_invocation, create, executable):
def test_run_module_wait(download_and_extract, write_env_vars, install, run, wait, cache):
with pytest.warns(DeprecationWarning):
_modules.run_module(uri='s3://url', args=['42'], wait=wait, cache=cache)
module_name = 'default_user_module_name'

download_and_extract.assert_called_with('s3://url', module_name, _env.code_dir)
download_and_extract.assert_called_with('s3://url', _env.code_dir)
write_env_vars.assert_called_with({})
install.assert_called_with(_env.code_dir)

Expand All @@ -191,7 +190,7 @@ def test_import_module(reload, import_module, install, download_and_extract):

_modules.import_module('s3://bucket/my-module')

download_and_extract.assert_called_with('s3://bucket/my-module', 'default_user_module_name', _env.code_dir)
download_and_extract.assert_called_with('s3://bucket/my-module', _env.code_dir)
install.assert_called_with(_env.code_dir)
reload.assert_called_with(import_module(_modules.DEFAULT_MODULE_NAME))

Expand Down

0 comments on commit 52cb4b6

Please sign in to comment.