Skip to content

Commit

Permalink
Create S3 client for smart_open from session (#4886)
Browse files Browse the repository at this point in the history
* Create S3 client for smart_open from session

* Fix the tests by determining which method to use via the endpoint_url

* Replace rows on conflict
  • Loading branch information
AetherUnbound authored Sep 11, 2024
1 parent f54612d commit 9943028
Showing 1 changed file with 15 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ def _insert_tags(tags_buffer: types.TagsBuffer, postgres_conn_id: str):
postgres_conn_id=postgres_conn_id,
default_statement_timeout=constants.INSERT_TIMEOUT,
)
postgres.insert_rows(constants.TEMP_TABLE_NAME, tags_buffer, executemany=True)
postgres.insert_rows(
constants.TEMP_TABLE_NAME,
tags_buffer,
executemany=True,
replace=True,
)


@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
Expand All @@ -76,7 +81,15 @@ def parse_and_insert_labels(
deserialize_json=True,
)

s3_client = S3Hook(aws_conn_id=AWS_CONN_ID).get_client_type("s3")
# If an endpoint is defined for the hook, use the `get_client_type` method
# to retrieve the S3 client. Otherwise, create the client from the session
# so that Airflow doesn't override the endpoint default we want on the S3 client
hook = S3Hook(aws_conn_id=AWS_CONN_ID)
if hook.conn_config.endpoint_url:
get_client = hook.get_client_type
else:
get_client = hook.get_session().client
s3_client = get_client("s3")
with smart_open.open(
f"{s3_bucket}/{s3_prefix}",
transport_params={"buffer_size": file_buffer_size, "client": s3_client},
Expand Down

0 comments on commit 9943028

Please sign in to comment.