fixup(pipeline): Simplify sync_cities DAG
Load the data directly into the database and use dbt transforms to clean
and test it.
vperron committed Sep 10, 2024
1 parent 223b449 commit 923148d
Showing 10 changed files with 178 additions and 257 deletions.
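
The staging model that takes over the cleanup is not among the files shown below, but the removed Python code indicates roughly what it has to do: rename the API's camelCase fields, merge the municipal districts into the communes, and hard-code the EPCI SIREN for the Paris/Lyon/Marseille arrondissements. A minimal dbt sketch of `stg_decoupage_administratif__communes` along those lines, assuming the raw `communes` and `districts` tables keep the API column names as written by `extract_load_from_api` (geometry conversion of `centre` is omitted here):

-- Hypothetical sketch only: the real model is not visible in this diff.
WITH communes AS (
    SELECT * FROM {{ source('decoupage_administratif', 'communes') }}
),

districts AS (
    SELECT * FROM {{ source('decoupage_administratif', 'districts') }}
),

final AS (
    SELECT
        code,
        nom,
        "codeRegion"      AS region,
        "codeDepartement" AS departement,
        "codeEpci"        AS siren_epci,
        "codesPostaux"    AS codes_postaux,
        centre
    FROM communes

    UNION ALL

    SELECT
        code,
        TRIM(REPLACE(nom, 'Arrondissement', '')) AS nom,
        "codeRegion"                             AS region,
        "codeDepartement"                        AS departement,
        -- same mapping as the removed format_district() hack
        CASE "codeDepartement"
            WHEN '13' THEN '200054807'  -- Marseille
            WHEN '69' THEN '200046977'  -- Lyon
            WHEN '75' THEN '200054781'  -- Paris
        END                                      AS siren_epci,
        "codesPostaux"                           AS codes_postaux,
        centre
    FROM districts
)

SELECT * FROM final
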
2 changes: 1 addition & 1 deletion pipeline/dags/dag_utils/marts.py
@@ -22,7 +22,7 @@ def _export_di_dataset_to_s3(logical_date, run_id):
RESOURCES = {
"structures": "public_marts.marts_inclusion__structures",
"services": "public_marts.marts_inclusion__services",
"communes": "decoupage_administratif.communes",
"communes": "public_staging.stg_decoupage_administratif__communes",
}

for resource_name, sql_table in RESOURCES.items():
299 changes: 47 additions & 252 deletions pipeline/dags/sync_cities.py
@@ -1,267 +1,81 @@
import pendulum

from airflow.decorators import dag, task
from airflow.operators import empty, python
from airflow.operators import empty

from dag_utils import date
from dag_utils.notifications import format_failure, notify_webhook
from dag_utils import date, notifications
from dag_utils.virtualenvs import PYTHON_BIN_PATH

# TODO(vmttn): create a source dag factory
default_args = {
"on_failure_callback": lambda context: notify_webhook(
context,
conn_id="mattermost",
format_fn=format_failure,
)
}


@task.external_python(
python=str(PYTHON_BIN_PATH),
retries=2,
)
def extract_cities(run_id, logical_date):
import json
import urllib.parse

import httpx

from dag_utils import s3

def to_json_city(data):
return {
"code": data["code"],
"nom": data["nom"],
"region": data["codeRegion"],
"departement": data["codeDepartement"],
"siren_epci": data.get("codeEpci"),
"centre": data["centre"],
"codes_postaux": sorted(data["codesPostaux"]),
}

def format_district(raw_city):
raw_city["nom"] = raw_city["nom"].replace("Arrondissement", "").strip()
# FIXME(vperron) : this is a hack to get the SIREN of the EPCI for districts
# It might become invalid in the future but it's QUITE unlikely.
raw_city["codeEpci"] = {
"13": "200054807", # Marseille
"69": "200046977", # Lyon
"75": "200054781", # Paris
}[raw_city["codeDepartement"]]
return raw_city

def fetch_cities(districts_only=False):
params = {
"fields": (
"nom,code,codesPostaux,codeDepartement,codeRegion,codeEpci,centre"
),
"format": "json",
}
if districts_only:
params["type"] = "arrondissement-municipal"
response = httpx.get(
urllib.parse.urljoin(
"https://geo.api.gouv.fr/", f"communes?{urllib.parse.urlencode(params)}"
)
)
response.raise_for_status()
answer = response.json()
if districts_only:
answer = [format_district(raw_city) for raw_city in answer]

return answer

cities = sorted(
[to_json_city(c) for c in fetch_cities() + fetch_cities(districts_only=True)],
key=lambda c: c["nom"],
)
s3.store_content(
s3.source_file_path("api_geo", "communes.json", run_id, logical_date),
json.dumps(cities).encode(),
)


@task.external_python(
python=str(PYTHON_BIN_PATH),
retries=2,
)
def extract_metadata(run_id, logical_date):
import json

import httpx

from dag_utils import s3

def fetch_store_resource(name):
response = httpx.get(f"https://geo.api.gouv.fr/{name}")
response.raise_for_status()
answer = sorted(
response.json(),
key=lambda c: c["code"],
)
s3.store_content(
s3.source_file_path("api_geo", f"{name}.json", run_id, logical_date),
json.dumps(answer).encode(),
)

fetch_store_resource("departements")
fetch_store_resource("regions")
fetch_store_resource("epcis")


def load_cities_from_s3(run_id, logical_date):
def extract_load_from_api(run_id, logical_date):
import logging

from geoalchemy2 import Geometry
import pandas as pd
from furl import furl
from sqlalchemy import types

from dag_utils import pg, s3
from dag_utils.sources import utils as source_utils

DB_SCHEMA = "decoupage_administratif"
TABLE_NAME = "communes"
from dag_utils import pg

logger = logging.getLogger(__name__)

s3_file_path = s3.source_file_path(
source_id="api_geo",
filename="communes.json",
run_id=run_id,
logical_date=logical_date,
)

tmp_file_path = s3.download_file(s3_file_path)

logger.info("Downloading file s3_path=%s tmp_path=%s", s3_file_path, tmp_file_path)

df = source_utils.read_json(path=tmp_file_path)

def create_point(geom):
return f"POINT({geom['coordinates'][0]} {geom['coordinates'][1]})"

df["centre"] = df["centre"].apply(create_point)

pg.create_schema(DB_SCHEMA)

with pg.connect_begin() as conn:
df.to_sql(
f"{TABLE_NAME}_tmp",
con=conn,
schema=DB_SCHEMA,
if_exists="replace",
index=False,
dtype={
"centre": Geometry("POINT", srid=4326),
"codes_postaux": types.ARRAY(types.TEXT),
},
city_params = {
"fields": ("nom,code,codesPostaux,codeDepartement,codeRegion,codeEpci,centre"),
"format": "json",
}
city_dtypes = {
"centre": types.JSON,
"codesPostaux": types.ARRAY(types.TEXT),
}

for resource, query_params, table_name, dtypes in [
("departements", {"zone": "metro,drom,com"}, None, None),
("regions", {"zone": "metro,drom,com"}, None, None),
("epcis", None, None, None),
("communes", city_params, "communes", city_dtypes),
(
"communes",
city_params | {"type": "arrondissement-municipal"},
"districts",
city_dtypes,
),
]:
url = (
(furl("https://geo.api.gouv.fr") / resource)
.set(query_params=query_params)
.url
)

conn.execute(
f"""\
CREATE TABLE IF NOT EXISTS {DB_SCHEMA}.{TABLE_NAME} (
code TEXT PRIMARY KEY,
nom TEXT NOT NULL,
departement TEXT NOT NULL,
region TEXT NOT NULL,
siren_epci TEXT NULL,
centre GEOMETRY(Point, 4326) NOT NULL,
codes_postaux TEXT[] NOT NULL
);
TRUNCATE {DB_SCHEMA}.{TABLE_NAME};
INSERT INTO {DB_SCHEMA}.{TABLE_NAME}
SELECT * FROM {DB_SCHEMA}.{TABLE_NAME}_tmp;
DROP TABLE {DB_SCHEMA}.{TABLE_NAME}_tmp;"""
)


def load_metadata_from_s3(run_id, logical_date):
import logging

from sqlalchemy import types

from dag_utils import pg, s3
from dag_utils.sources import utils as source_utils

logger = logging.getLogger(__name__)

DB_SCHEMA = "decoupage_administratif"

def load_resource(name, columns, sql_schema, dtype=None):
s3_file_path = s3.source_file_path(
source_id="api_geo",
filename=f"{name}.json",
run_id=run_id,
logical_date=logical_date,
)
tmp_file_path = s3.download_file(s3_file_path)
logger.info(
"Downloading file s3_path=%s tmp_path=%s", s3_file_path, tmp_file_path
)
df = source_utils.read_json(path=tmp_file_path)
logger.info(f"> fetching resource={resource} from url={url}")
df = pd.read_json(url, dtype=False)
with pg.connect_begin() as conn:
print(f">>> creating table={name} with df={df}")
df = df[list(columns)]
schema = "decoupage_administratif"
table_name = table_name or resource
df.to_sql(
f"{name}_tmp",
f"{table_name}_tmp",
con=conn,
schema=DB_SCHEMA,
schema=schema,
if_exists="replace",
index=False,
dtype={} if not dtype else dtype,
dtype=dtypes,
)

print(f">>> replacing table={name} with schema={sql_schema}")
conn.execute(
f"""\
CREATE TABLE IF NOT EXISTS {DB_SCHEMA}.{name} (
{sql_schema}
);
TRUNCATE {DB_SCHEMA}.{name};
INSERT INTO {DB_SCHEMA}.{name}
SELECT * FROM {DB_SCHEMA}.{name}_tmp;
DROP TABLE {DB_SCHEMA}.{name}_tmp;"""
CREATE TABLE IF NOT EXISTS {schema}.{table_name}
(LIKE {schema}.{table_name}_tmp);
TRUNCATE {schema}.{table_name};
INSERT INTO {schema}.{table_name}
(SELECT * FROM {schema}.{table_name}_tmp);
DROP TABLE {schema}.{table_name}_tmp;
"""
)

pg.create_schema(DB_SCHEMA)

load_resource(
"regions",
("code", "nom"),
"""
code TEXT PRIMARY KEY,
nom TEXT NOT NULL
""",
)
load_resource(
"departements",
("code", "nom", "codeRegion"),
"""
code TEXT PRIMARY KEY,
nom TEXT NOT NULL,
codeRegion TEXT NOT NULL
""",
)
load_resource(
"epcis",
("code", "nom", "codesDepartements", "codesRegions", "population"),
"""
code TEXT PRIMARY KEY,
nom TEXT NOT NULL,
codesDepartements TEXT[] NOT NULL,
codesRegions TEXT[] NOT NULL,
population INT NOT NULL
""",
{
"codesRegions": types.ARRAY(types.TEXT),
"codesDepartements": types.ARRAY(types.TEXT),
},
)


@dag(
start_date=pendulum.datetime(2022, 1, 1, tz=date.TIME_ZONE),
default_args=default_args,
default_args=notifications.notify_failure_args(),
schedule="@monthly",
catchup=False,
tags=["source"],
@@ -270,26 +84,7 @@ def sync_cities():
start = empty.EmptyOperator(task_id="start")
end = empty.EmptyOperator(task_id="end")

load_metadata = python.ExternalPythonOperator(
task_id="load_metadata",
python=str(PYTHON_BIN_PATH),
python_callable=load_metadata_from_s3,
)

load_cities = python.ExternalPythonOperator(
task_id="load_cities",
python=str(PYTHON_BIN_PATH),
python_callable=load_cities_from_s3,
)

(
start
>> extract_metadata()
>> extract_cities()
>> load_metadata
>> load_cities
>> end
)
(start >> extract_load_from_api() >> end)


sync_cities()
1 change: 1 addition & 0 deletions pipeline/dbt/models/_sources.yml
@@ -212,6 +212,7 @@ sources:
schema: decoupage_administratif
tables:
- name: communes
- name: districts
- name: departements
- name: epcis
- name: regions
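
The commit message also mentions testing the data with dbt, but the tests themselves are not visible in this diff. As an illustration only, a hypothetical singular test could check that every commune references a known EPCI (assuming the raw column keeps the API's codeEpci name):

-- Hypothetical singular test, e.g. tests/assert_communes_reference_known_epci.sql:
-- returns communes whose EPCI code is absent from the epcis table; dbt fails the
-- test if any row comes back.
SELECT communes.code
FROM {{ source('decoupage_administratif', 'communes') }} AS communes
LEFT JOIN {{ source('decoupage_administratif', 'epcis') }} AS epcis
    ON communes."codeEpci" = epcis.code
WHERE
    communes."codeEpci" IS NOT NULL
    AND epcis.code IS NULL
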
@@ -3,7 +3,7 @@ WITH agences AS (
),

communes AS (
SELECT * FROM {{ source('decoupage_administratif', 'communes') }}
SELECT * FROM {{ ref('stg_decoupage_administratif__communes') }}
),

final AS (