From 653d537ddcd3da8d4316f3809f89585c468d5446 Mon Sep 17 00:00:00 2001 From: Divyanshu Patel Date: Fri, 11 Oct 2024 17:55:48 +0530 Subject: [PATCH] iam role fixes --- soda/core/soda/common/aws_credentials.py | 2 +- .../redshift/soda/data_sources/redshift_data_source.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/soda/core/soda/common/aws_credentials.py b/soda/core/soda/common/aws_credentials.py index 8dcac8ca9..36ca1957f 100644 --- a/soda/core/soda/common/aws_credentials.py +++ b/soda/core/soda/common/aws_credentials.py @@ -58,7 +58,7 @@ def assume_role(self, role_session_name: str): aws_session_token=self.session_token, ) - assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, RoleSessionName=role_session_name) + assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, ExternalId=self.external_id, RoleSessionName=role_session_name) credentials_dict = assumed_role_object["Credentials"] return AwsCredentials( region_name=self.region_name, diff --git a/soda/redshift/soda/data_sources/redshift_data_source.py b/soda/redshift/soda/data_sources/redshift_data_source.py index 9f07bacd3..f4e5f850d 100644 --- a/soda/redshift/soda/data_sources/redshift_data_source.py +++ b/soda/redshift/soda/data_sources/redshift_data_source.py @@ -22,6 +22,9 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di self.connect_timeout = data_source_properties.get("connection_timeout_sec") self.username = data_source_properties.get("username") self.password = data_source_properties.get("password") + self.dbuser = data_source_properties.get("dbuser") + self.dbname = data_source_properties.get("dbname") + self.cluster_id = data_source_properties.get("cluster_id") if not self.username or not self.password: aws_credentials = AwsCredentials( @@ -31,6 +34,7 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di session_token=data_source_properties.get("session_token"), region_name=data_source_properties.get("region", "eu-west-1"), profile_name=data_source_properties.get("profile_name"), + external_id=data_source_properties.get("external_id"), ) self.username, self.password = self.__get_cluster_credentials(aws_credentials) @@ -60,9 +64,9 @@ def __get_cluster_credentials(self, aws_credentials: AwsCredentials): aws_session_token=resolved_aws_credentials.session_token, ) - cluster_name = self.host.split(".")[0] - username = self.username - db_name = self.database + cluster_name = self.cluster_id if self.cluster_id else self.host.split(".")[0] + username = self.dbuser if self.dbuser else self.username + db_name = self.dbname if self.dbname else self.database cluster_creds = client.get_cluster_credentials( DbUser=username, DbName=db_name, ClusterIdentifier=cluster_name, AutoCreate=False, DurationSeconds=3600 )