diff --git a/catalog/dags/data_augmentation/rekognition/add_rekognition_labels.py b/catalog/dags/data_augmentation/rekognition/add_rekognition_labels.py
index 1e4d2848b90..3a7149aba5a 100644
--- a/catalog/dags/data_augmentation/rekognition/add_rekognition_labels.py
+++ b/catalog/dags/data_augmentation/rekognition/add_rekognition_labels.py
@@ -55,7 +55,12 @@ def _insert_tags(tags_buffer: types.TagsBuffer, postgres_conn_id: str):
         postgres_conn_id=postgres_conn_id,
         default_statement_timeout=constants.INSERT_TIMEOUT,
     )
-    postgres.insert_rows(constants.TEMP_TABLE_NAME, tags_buffer, executemany=True)
+    postgres.insert_rows(
+        constants.TEMP_TABLE_NAME,
+        tags_buffer,
+        executemany=True,
+        replace=True,
+    )
 
 
 @task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
@@ -76,7 +81,15 @@ def parse_and_insert_labels(
         deserialize_json=True,
     )
 
-    s3_client = S3Hook(aws_conn_id=AWS_CONN_ID).get_client_type("s3")
+    # If an endpoint is defined for the hook, use the `get_client_type` method
+    # to retrieve the S3 client. Otherwise, create the client from the session
+    # so that Airflow doesn't override the endpoint default we want on the S3 client
+    hook = S3Hook(aws_conn_id=AWS_CONN_ID)
+    if hook.conn_config.endpoint_url:
+        get_client = hook.get_client_type
+    else:
+        get_client = hook.get_session().client
+    s3_client = get_client("s3")
     with smart_open.open(
         f"{s3_bucket}/{s3_prefix}",
         transport_params={"buffer_size": file_buffer_size, "client": s3_client},
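
For reference, a minimal standalone sketch of the endpoint-aware client selection in the second hunk, separated from the DAG context. The connection ID "aws_default" is a hypothetical placeholder (the DAG passes AWS_CONN_ID), and the branch comments restate the diff's own reasoning rather than documented Airflow behavior:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# "aws_default" is a placeholder connection ID for illustration only;
# the DAG itself uses AWS_CONN_ID.
hook = S3Hook(aws_conn_id="aws_default")

if hook.conn_config.endpoint_url:
    # The Airflow connection carries an explicit endpoint override, so let
    # the hook build the client and apply that endpoint.
    s3_client = hook.get_client_type("s3")
else:
    # No endpoint on the connection: build the client directly from the
    # underlying boto3 session so that boto3 resolves the default AWS
    # endpoint itself instead of the hook overriding it.
    s3_client = hook.get_session().client("s3")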