Skip to content

Commit

Permalink
Modify add_license_url DAG to use batched_update (#4370)
Browse files Browse the repository at this point in the history
Co-authored-by: Madison Swain-Bowden <[email protected]>
  • Loading branch information
krysal and AetherUnbound authored May 28, 2024
1 parent 183e6c6 commit 3329d30
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 133 deletions.
233 changes: 100 additions & 133 deletions catalog/dags/maintenance/add_license_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
from airflow.exceptions import AirflowSkipException
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.param import Param
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from psycopg2._json import Json
from tabulate import tabulate

from common import slack
from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID
from common.licenses import get_license_info_from_license_pair
from common.sql import RETURN_ROW_COUNT, PostgresHook
from common.sql import PostgresHook
from database.batched_update.constants import DAG_ID as BATCHED_UPDATE_DAG_ID


DAG_ID = "add_license_url"
Expand Down Expand Up @@ -54,148 +55,113 @@ def run_sql(


@task
def get_license_groups(query: str, ti=None) -> list[tuple[str, str]]:
def get_licenses(ti=None) -> list[tuple[str, str, str]]:
"""
Get license groups of rows that don't have a `license_url` in their
`meta_data` field.
`meta_data` field and notify the start of the DAG.
:return: List of (license, version) tuples.
:return: List of license_info tuples.
"""
query = dedent("""
SELECT license, license_version, count(identifier)
FROM image WHERE meta_data->>'license_url' IS NULL
GROUP BY license, license_version
""")
license_groups = run_sql(query, dag_task=ti.task)

total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
licenses, invalid = [], []
headers = ["license", "version", "count"]
tabulate_params = {
"headers": headers,
"showindex": True,
"tablefmt": "rounded_grid",
"floatfmt": ".1f",
"intfmt": ",",
}

for row in license_groups:
license_, license_version, _ = row
license_info = get_license_info_from_license_pair(license_, license_version)
if license_info is None:
invalid.append(row)
else:
licenses.append(license_info)

license_groups = [lg for lg in license_groups if lg not in invalid]

message = (
f"""
Starting `{DAG_ID}` DAG. Found {len(license_groups):.0f} license groups with {total_nulls:.0f}
records to back fill `license_url` in `meta_data`.\nCount per license-version:
```
{tabulate(license_groups, **tabulate_params)}
```
"""
if license_groups
else f"""
No license groups found with records missing `license_url` in `meta_data`. The `{DAG_ID}` DAG is done.
"""
)

message = f"""
Starting `{DAG_ID}` DAG. Found {len(license_groups)} license groups with {total_nulls}
records without `license_url` in `meta_data` left.\nCount per license-version:
{licenses_detailed}
"""
if invalid:
message += f"""
\nThe following *invalid license(s)* were found and will be skipped:
```
{tabulate(invalid, **tabulate_params)}
```
"""

slack.send_message(
message,
username="Airflow DAG Data Normalization - license_url",
dag_id=DAG_ID,
)

return [(group[0], group[1]) for group in license_groups]
return licenses


@task(max_active_tis_per_dag=1, execution_timeout=timedelta(hours=36))
def update_license_url(license_group: tuple[str, str], batch_size: int, ti=None) -> int:
"""
Add license_url to meta_data batching all records with the same license.
:param license_group: tuple of license and version
:param batch_size: number of records to update in one update statement
:param ti: automatically passed by Airflow, used to set the execution timeout.
"""
license_, version = license_group
license_info = get_license_info_from_license_pair(license_, version)
if license_info is None:
raise AirflowSkipException(
f"No license pair ({license_}, {version}) in the license map."
)
*_, license_url = license_info

logging.info(
f"Will add `license_url` in `meta_data` for records with license "
f"{license_} {version} to {license_url}."
)
def get_license_conf(license_info) -> dict:
license_, license_version, license_url = license_info
license_url_dict = {"license_url": license_url}
query_id = f"add_license_url_{license_}_{license_version}"
for char_to_remove in [".", "-"]:
query_id = query_id.replace(char_to_remove, "_")

conf = {
"query_id": query_id,
"table_name": "image",
"select_query": (
f"WHERE license = '{license_}' AND license_version = '{license_version}' "
f"AND meta_data->>'license_url' IS NULL"
),
# Merge existing metadata with the new license_url
"update_query": f"SET meta_data = ({Json(license_url_dict)}::jsonb || meta_data), updated_on = now()",
"update_timeout": 259200, # 3 days in seconds
"dry_run": False,
"resume_update": False,
}
return conf

# Merge existing metadata with the new license_url
update_query = dedent(
f"""
UPDATE image
SET meta_data = ({Json(license_url_dict)}::jsonb || meta_data), updated_on = now()
WHERE identifier IN (
SELECT identifier
FROM image
WHERE license = '{license_}' AND license_version = '{version}'
AND meta_data->>'license_url' IS NULL
LIMIT {batch_size}
FOR UPDATE SKIP LOCKED
);
"""
)
total_updated = 0
updated_count = 1
while updated_count:
updated_count = run_sql(
update_query,
log_sql=total_updated == 0,
method="run",
handler=RETURN_ROW_COUNT,
autocommit=True,
dag_task=ti.task,
)
total_updated += updated_count
logger.info(f"Updated {total_updated} rows with {license_url}.")

return total_updated


@task(trigger_rule=TriggerRule.ALL_DONE)
def report_completion(updated, query: str, ti=None):
"""
Check for null in `meta_data` and send a message to Slack with the statistics
of the DAG run.
:param updated: total number of records updated
:param query: SQL query to get the count of records left with `license_url` as NULL
:param ti: automatically passed by Airflow, used to set the execution timeout.
"""
total_updated = sum(updated) if updated else 0

license_groups = run_sql(query, dag_task=ti.task)
total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
)

message = f"""
`{DAG_ID}` DAG run completed. Updated {total_updated} record(s) with `license_url` in the
`meta_data` field. Found {len(license_groups)} license groups with {total_nulls} record(s) left pending.
"""
if total_nulls != 0:
message += f"\nCount per license-version:\n{licenses_detailed}"

slack.send_message(
message,
username="Airflow DAG Data Normalization - license_url",
dag_id=DAG_ID,
)

@task
def get_confs(licenses, batch_size: int) -> list[dict]:
if not licenses:
raise AirflowSkipException(
"Found no licenses to backfill. No DAG config is required."
)

@task(trigger_rule=TriggerRule.ALL_DONE)
def report_failed_license_pairs(dag_run=None):
"""
Send a message to Slack with the license-version pairs that could not be found
in the license map.
"""
skipped_tasks = [
dag_task
for dag_task in dag_run.get_task_instances(state=State.SKIPPED)
if "update_license_url" in dag_task.task_id
return [
{"batch_size": batch_size, **get_license_conf(license_info)}
for license_info in licenses
]

if not skipped_tasks:
raise AirflowSkipException

message = (
f"""
One or more license pairs could not be found in the license map while running
the `{DAG_ID}` DAG. See the logs for more details:
"""
) + "\n".join(
f" - <{dag_task.log_url}|{dag_task.task_id}>" for dag_task in skipped_tasks[:5]
)

slack.send_alert(
message,
username="Airflow DAG Data Normalization - license_url",
@task
def notify_slack():
slack.send_message(
"Finished processing the groups of licenses.",
username=f"Airflow DAG Data Normalization - {DAG_ID}",
dag_id=DAG_ID,
)

Expand All @@ -221,18 +187,19 @@ def report_failed_license_pairs(dag_run=None):
},
)
def add_license_url():
query = dedent("""
SELECT license, license_version, count(identifier)
FROM image WHERE meta_data->>'license_url' IS NULL
GROUP BY license, license_version
""")

license_groups = get_license_groups(query)
updated = update_license_url.partial(batch_size="{{ params.batch_size }}").expand(
license_group=license_groups
)
report_completion(updated, query)
updated >> report_failed_license_pairs()
licenses = get_licenses()

trigger = TriggerDagRunOperator.partial(
task_id="trigger_batched_update",
trigger_dag_id=BATCHED_UPDATE_DAG_ID,
wait_for_completion=True,
execution_timeout=timedelta(hours=5),
max_active_tis_per_dag=1,
map_index_template="""{{ task.conf['query_id'] }}""",
retries=0,
).expand(conf=get_confs(licenses, batch_size="{{ params.batch_size }}"))

trigger >> notify_slack()


add_license_url()
1 change: 1 addition & 0 deletions catalog/requirements-prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ psycopg2-binary
requests-file==2.0.*
requests-oauthlib
retry==0.9.2
tabulate==0.9.0
tldextract==5.1.2

0 comments on commit 3329d30

Please sign in to comment.