Skip to content
This repository has been archived by the owner on Aug 26, 2020. It is now read-only.

Commit

Permalink
Fix choosing region for S3 client (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu authored Sep 26, 2018
1 parent 65e2556 commit 6816d0b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
CHANGELOG
=========

2.2.2
=====

* bug-fix: Fix choosing region for S3 client

2.2.1
=====

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def read(file_name):

setup(
name='sagemaker_containers',
version='2.2.1',
version='2.2.2',
description='Open source library for creating containers to run on Amazon SageMaker.',

packages=packages,
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker_containers/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import six
from six.moves.urllib.parse import urlparse

from sagemaker_containers import _errors, _files, _logging
from sagemaker_containers import _errors, _files, _logging, _params

logger = _logging.get_logger()

Expand All @@ -45,7 +45,9 @@ def s3_download(url, dst): # type: (str, str) -> None

bucket, key = url.netloc, url.path.lstrip('/')

s3 = boto3.resource('s3', region_name=os.environ.get('AWS_REGION'))
region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
s3 = boto3.resource('s3', region_name=region)

s3.Bucket(bucket).download_file(key, dst)


Expand Down
7 changes: 5 additions & 2 deletions test/unit/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest
from six import PY2

from sagemaker_containers import _errors, _modules
from sagemaker_containers import _errors, _modules, _params
import test

builtins_open = '__builtin__.open' if PY2 else 'builtins.open'
Expand All @@ -34,9 +34,12 @@
[('S3://my-bucket/path/to/my-file', 'my-bucket', 'path/to/my-file', '/tmp/my-file'),
('s3://my-bucket/my-file', 'my-bucket', 'my-file', '/tmp/my-file')])
def test_s3_download(resource, url, bucket_name, key, dst):
region = 'us-west-2'
os.environ[_params.REGION_NAME_ENV] = region

_modules.s3_download(url, dst)

chain = call('s3', region_name=os.environ.get('AWS_REGION')).Bucket(bucket_name).download_file(key, dst)
chain = call('s3', region_name=region).Bucket(bucket_name).download_file(key, dst)
assert resource.mock_calls == chain.call_list()


Expand Down

0 comments on commit 6816d0b

Please sign in to comment.