Skip to content

Commit

Permalink
feat: update s3 data source for us-isof partition (#311)
Browse files Browse the repository at this point in the history
* feat: update s3 data source for us-isof partition

* add built in dataaset iso region mapping
  • Loading branch information
oyangz authored Jul 18, 2024
1 parent 5e70ca7 commit a38c880
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/fmeval/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
BUILT_IN_DATASET_PREFIX = "s3://fmeval/datasets"
BUILT_IN_DATASET_DEFAULT_REGION = "us-west-2"

# Mapping of iso region to built in dataset region in the same partition
BUILT_IN_DATASET_ISO_REGIONS = {"us-isof-south-1": "us-isof-south-1", "us-isof-east-1": "us-isof-south-1"}

# Environment variable for disabling telemetry
DISABLE_FMEVAL_TELEMETRY = "DISABLE_FMEVAL_TELEMETRY"

Expand Down
31 changes: 23 additions & 8 deletions src/fmeval/data_loaders/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import urllib.parse
from typing import IO
from abc import ABC, abstractmethod
from fmeval.constants import BUILT_IN_DATASET_PREFIX, BUILT_IN_DATASET_DEFAULT_REGION
from fmeval.constants import (
BUILT_IN_DATASET_PREFIX,
BUILT_IN_DATASET_DEFAULT_REGION,
BUILT_IN_DATASET_ISO_REGIONS,
)
from fmeval.exceptions import EvalAlgorithmClientError


Expand Down Expand Up @@ -110,14 +114,25 @@ def __reduce__(self):

def get_s3_client(uri: str) -> boto3.client:
"""
Util method to return boto3 s3 client. For built-in datasets, the boto3 client region is default to us-west-2 as
the bucket is not accessible in opt-in regions.
Util method to return boto3 s3 client. For built-in datasets, the boto3 client region is default to us-west-2 for
commercial regions as the bucket is not accessible in opt-in regions.
For us-isof partition, built-in datasets are located in us-isof-south-1 region.
:param uri: s3 dataset uri
:return: boto3 s3 client
"""
s3_client = (
boto3.client("s3", region_name=BUILT_IN_DATASET_DEFAULT_REGION)
if uri.startswith(BUILT_IN_DATASET_PREFIX)
else boto3.client("s3")
)
session = boto3.session.Session()
region = session.region_name
if region in BUILT_IN_DATASET_ISO_REGIONS.keys():
s3_client = (
boto3.client("s3", region_name=BUILT_IN_DATASET_ISO_REGIONS[region], verify=False)
if uri.startswith(BUILT_IN_DATASET_PREFIX)
else boto3.client("s3", verify=False)
)
else:
s3_client = (
boto3.client("s3", region_name=BUILT_IN_DATASET_DEFAULT_REGION)
if uri.startswith(BUILT_IN_DATASET_PREFIX)
else boto3.client("s3")
)
return s3_client
36 changes: 30 additions & 6 deletions test/unit/data_loaders/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import botocore.errorfactory

from unittest.mock import patch, Mock, mock_open

from fmeval.constants import BUILT_IN_DATASET_ISO_REGIONS
from fmeval.data_loaders.data_sources import LocalDataFile, S3DataFile, S3Uri, get_s3_client
from fmeval.eval_algorithms import DATASET_CONFIGS, TREX
from fmeval.exceptions import EvalAlgorithmClientError
Expand Down Expand Up @@ -81,25 +83,47 @@ def test_key(self, uri, key):
assert s3_uri.key == key


def test_get_s3_client_built_in_dataset():
@pytest.mark.parametrize(
"run_region, dataset_region",
[
("us-west-2", "us-west-2"),
("ap-east-1", "us-west-2"),
("us-isof-south-1", "us-isof-south-1"),
("us-isof-east-1", "us-isof-south-1"),
],
)
@patch("boto3.session.Session")
def test_get_s3_client_built_in_dataset(mock_session_class, run_region, dataset_region):
"""
GIVEN a built-in dataset s3 path
WHEN get_s3_client is called
THEN the boto3 s3 client is created with region name "us-west-2"
THEN the boto3 s3 client is created with corresponding built-in dataset region name
"""
with patch("boto3.client") as mock_client:
mock_instance = mock_session_class.return_value
mock_instance.region_name = run_region
dataset_uri = DATASET_CONFIGS[TREX].dataset_uri
s3_client = get_s3_client(dataset_uri)
mock_client.assert_called_once_with("s3", region_name="us-west-2")
if dataset_region in BUILT_IN_DATASET_ISO_REGIONS.values():
mock_client.assert_called_once_with("s3", region_name=dataset_region, verify=False)
else:
mock_client.assert_called_once_with("s3", region_name=dataset_region)


def test_get_s3_client_custom_dataset():
@pytest.mark.parametrize("region", ["us-west-2", "ap-east-1", "us-isof-south-1", "us-isof-east-1"])
@patch("boto3.session.Session")
def test_get_s3_client_custom_dataset(mock_session_class, region):
"""
GIVEN a custom dataset s3 path
WHEN get_s3_client is called
THEN the boto3 s3 client is created without region name
"""
with patch("boto3.client") as mock_client:
dataset_uri = S3_PREFIX + DATASET_URI
mock_instance = mock_session_class.return_value
mock_instance.region_name = region
dataset_uri = dataset_uri = S3_PREFIX + DATASET_URI
s3_client = get_s3_client(dataset_uri)
mock_client.assert_called_once_with("s3")
if region in BUILT_IN_DATASET_ISO_REGIONS.keys():
mock_client.assert_called_once_with("s3", verify=False)
else:
mock_client.assert_called_once_with("s3")

0 comments on commit a38c880

Please sign in to comment.