From 52cb4b6a66a08294e6e87ae988ceb481dfd8b5f0 Mon Sep 17 00:00:00 2001 From: Marcio Vinicius dos Santos Date: Tue, 7 May 2019 14:30:43 -0700 Subject: [PATCH] fix: remove unnecessary name argument from download and extract function --- src/sagemaker_containers/_files.py | 5 ++--- src/sagemaker_containers/_modules.py | 4 ++-- src/sagemaker_containers/_process.py | 6 ++++-- src/sagemaker_containers/entry_point.py | 2 +- test/unit/test_entry_point.py | 2 +- test/unit/test_files.py | 6 +++--- test/unit/test_modules.py | 5 ++--- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/sagemaker_containers/_files.py b/src/sagemaker_containers/_files.py index 6d11b6f..302aa27 100644 --- a/src/sagemaker_containers/_files.py +++ b/src/sagemaker_containers/_files.py @@ -105,13 +105,12 @@ def read_json(path): # type: (str) -> dict return json.load(f) -def download_and_extract(uri, name, path): # type: (str, str, str) -> None +def download_and_extract(uri, path): # type: (str, str) -> None """Download, prepare and install a compressed tar file from S3 or local directory as an entry point. SageMaker Python SDK saves the user provided entry points as compressed tar files in S3 Args: - name (str): name of the entry point. uri (str): the location of the entry point. path (bool): The path where the script will be installed. It will not download and install the if the path already has the user entry point. @@ -134,7 +133,7 @@ def download_and_extract(uri, name, path): # type: (str, str, str) -> None shutil.rmtree(path) shutil.move(uri, path) else: - shutil.copy2(uri, os.path.join(path, name)) + shutil.copy2(uri, path) def s3_download(url, dst): # type: (str, str) -> None diff --git a/src/sagemaker_containers/_modules.py b/src/sagemaker_containers/_modules.py index 0f42997..bd8a50b 100644 --- a/src/sagemaker_containers/_modules.py +++ b/src/sagemaker_containers/_modules.py @@ -238,7 +238,7 @@ def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None): # type: (str, str (module): the imported module """ _warning_cache_deprecation(cache) - _files.download_and_extract(uri, name, _env.code_dir) + _files.download_and_extract(uri, _env.code_dir) prepare(_env.code_dir, name) install(_env.code_dir) @@ -271,7 +271,7 @@ def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, w env_vars = env_vars or {} env_vars = env_vars.copy() - _files.download_and_extract(uri, name, _env.code_dir) + _files.download_and_extract(uri, _env.code_dir) prepare(_env.code_dir, name) install(_env.code_dir) diff --git a/src/sagemaker_containers/_process.py b/src/sagemaker_containers/_process.py index f410396..bebf803 100644 --- a/src/sagemaker_containers/_process.py +++ b/src/sagemaker_containers/_process.py @@ -98,9 +98,11 @@ def run(self, wait=True, capture_error=False): _logging.log_script_invocation(cmd, self._env_vars) if wait: - process = check_error(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error) + process = check_error(cmd, _errors.ExecuteUserScriptError, + capture_error=capture_error, cwd=_env.code_dir) else: - process = create(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error) + process = create(cmd, _errors.ExecuteUserScriptError, + capture_error=capture_error, cwd=_env.code_dir) self._tear_down() diff --git a/src/sagemaker_containers/entry_point.py b/src/sagemaker_containers/entry_point.py index faf5c91..252fffb 100644 --- a/src/sagemaker_containers/entry_point.py +++ b/src/sagemaker_containers/entry_point.py @@ -80,7 +80,7 @@ def run(uri, env_vars = env_vars or {} env_vars = env_vars.copy() - _files.download_and_extract(uri, user_entry_point, _env.code_dir) + _files.download_and_extract(uri, _env.code_dir) install(user_entry_point, _env.code_dir, capture_error) diff --git a/test/unit/test_entry_point.py b/test/unit/test_entry_point.py index 00b5af7..9a517ae 100644 --- a/test/unit/test_entry_point.py +++ b/test/unit/test_entry_point.py @@ -89,7 +89,7 @@ def test_run_module_wait(chmod, download_and_extract): entry_point.run(uri='s3://url', user_entry_point='launcher.sh', args=['42'], capture_error=True, runner=runner) - download_and_extract.assert_called_with('s3://url', 'launcher.sh', _env.code_dir) + download_and_extract.assert_called_with('s3://url', _env.code_dir) runner.run.assert_called_with(True, True) chmod.assert_called_with(os.path.join(_env.code_dir, 'launcher.sh'), 511) diff --git a/test/unit/test_files.py b/test/unit/test_files.py index 6ac7f5d..92d1b90 100644 --- a/test/unit/test_files.py +++ b/test/unit/test_files.py @@ -115,7 +115,7 @@ def test_write_failure_file(): @patch('shutil.move') def test_download_and_and_extract_source_dir(move, rmtree, s3_download): uri = _env.channel_path('code') - _files.download_and_extract(uri, 'train.sh', _env.code_dir) + _files.download_and_extract(uri, _env.code_dir) s3_download.assert_not_called() rmtree.assert_any_call(_env.code_dir) @@ -127,7 +127,7 @@ def test_download_and_and_extract_source_dir(move, rmtree, s3_download): @patch('shutil.copy2') def test_download_and_and_extract_file(copy, s3_download): uri = _env.channel_path('code') - _files.download_and_extract(uri, 'train.sh', _env.code_dir) + _files.download_and_extract(uri, _env.code_dir) s3_download.assert_not_called() - copy.assert_called_with(uri, os.path.join(_env.code_dir, 'train.sh')) + copy.assert_called_with(uri, _env.code_dir) diff --git a/test/unit/test_modules.py b/test/unit/test_modules.py index 6cb58a3..491b67a 100644 --- a/test/unit/test_modules.py +++ b/test/unit/test_modules.py @@ -174,9 +174,8 @@ def test_run_no_wait(log_script_invocation, create, executable): def test_run_module_wait(download_and_extract, write_env_vars, install, run, wait, cache): with pytest.warns(DeprecationWarning): _modules.run_module(uri='s3://url', args=['42'], wait=wait, cache=cache) - module_name = 'default_user_module_name' - download_and_extract.assert_called_with('s3://url', module_name, _env.code_dir) + download_and_extract.assert_called_with('s3://url', _env.code_dir) write_env_vars.assert_called_with({}) install.assert_called_with(_env.code_dir) @@ -191,7 +190,7 @@ def test_import_module(reload, import_module, install, download_and_extract): _modules.import_module('s3://bucket/my-module') - download_and_extract.assert_called_with('s3://bucket/my-module', 'default_user_module_name', _env.code_dir) + download_and_extract.assert_called_with('s3://bucket/my-module', _env.code_dir) install.assert_called_with(_env.code_dir) reload.assert_called_with(import_module(_modules.DEFAULT_MODULE_NAME))