
Commit

Update the system test `lib/system-tests/tests/example_comprehend_document_classifier.ts` to use example files from S3 (#44368)
vincbeck authored Nov 26, 2024
1 parent f0af1b3 commit 6d075cb
Showing 1 changed file with 44 additions and 83 deletions.
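In short: the test previously downloaded the two sample PDFs from GitHub over HTTP at runtime; after this commit it expects them to already sit in an S3 bucket, located through three new system-test variables (BUCKET_NAME, BUCKET_KEY_DISCHARGE, BUCKET_KEY_DOCTORS_NOTES). A minimal staging sketch using boto3 — bucket and key names here are illustrative placeholders, not values defined by the commit:

import boto3

s3 = boto3.client("s3")

# Hypothetical fixture bucket/keys; these would be exposed to the test as the
# BUCKET_NAME, BUCKET_KEY_DISCHARGE, and BUCKET_KEY_DOCTORS_NOTES variables.
BUCKET = "comprehend-system-test-fixtures"
FIXTURES = {
    "sample-docs/discharge-summary.pdf": "discharge-summary.pdf",
    "sample-docs/doctors-notes.pdf": "doctors-notes.pdf",
}

for key, local_path in FIXTURES.items():
    # upload_file(Filename, Bucket, Key)
    s3.upload_file(local_path, BUCKET, key)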
@@ -16,12 +16,10 @@
 # under the License.
 from __future__ import annotations
 
-import os
 from datetime import datetime
 
-from airflow import DAG, settings
+from airflow import DAG
 from airflow.decorators import task, task_group
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.operators.comprehend import (
@@ -36,31 +34,27 @@
 from airflow.providers.amazon.aws.sensors.comprehend import (
     ComprehendCreateDocumentClassifierCompletedSensor,
 )
-from airflow.providers.amazon.aws.transfers.http_to_s3 import HttpToS3Operator
 from airflow.utils.trigger_rule import TriggerRule
 
 from providers.tests.system.amazon.aws.utils import SystemTestContextBuilder
 
 ROLE_ARN_KEY = "ROLE_ARN"
-sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+BUCKET_NAME_KEY = "BUCKET_NAME"
+BUCKET_KEY_DISCHARGE_KEY = "BUCKET_KEY_DISCHARGE"
+BUCKET_KEY_DOCTORS_NOTES = "BUCKET_KEY_DOCTORS_NOTES"
+sys_test_context_task = (
+    SystemTestContextBuilder()
+    .add_variable(ROLE_ARN_KEY)
+    .add_variable(BUCKET_NAME_KEY)
+    .add_variable(BUCKET_KEY_DISCHARGE_KEY)
+    .add_variable(BUCKET_KEY_DOCTORS_NOTES)
+    .build()
+)
 
 DAG_ID = "example_comprehend_document_classifier"
 ANNOTATION_BUCKET_KEY = "training-labels/label.csv"
 TRAINING_DATA_PREFIX = "training-docs"
 
-# To create a custom document classifier, we need a minimum of 10 documents for each label.
-# for testing purpose, we will generate 10 copies of each document referenced below.
-PUBLIC_DATA_SOURCES = [
-    {
-        "fileName": "discharge-summary.pdf",
-        "endpoint": "aws-samples/amazon-comprehend-examples/blob/master/building-custom-classifier/sample-docs/discharge-summary.pdf?raw=true",
-    },
-    {
-        "fileName": "doctors-notes.pdf",
-        "endpoint": "aws-samples/amazon-comprehend-examples/blob/master/building-custom-classifier/sample-docs/doctors-notes.pdf?raw=true",
-    },
-]
-
 # Annotations file won't allow headers
 # label,document name,page number
 
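The retained comment above documents the annotations format: a headerless CSV with label, document name, page number columns. Assuming the copy-naming scheme introduced below (ten copies per document under training-docs/), the file's rows could be generated like this — the label strings are illustrative, not taken from the commit:

# Illustrative sketch of the headerless annotations CSV; the actual label
# values used by the test are not shown in this diff.
TRAINING_DATA_PREFIX = "training-docs"
LABELS = {"DISCHARGE_SUMMARY": "discharge-summary", "DOCTORS_NOTES": "doctors-notes"}

rows = [
    f"{label},{TRAINING_DATA_PREFIX}/{stem}-{i}.pdf,1"  # label,document name,page number
    for label, stem in LABELS.items()
    for i in range(10)
]
annotations_csv = "\n".join(rows)  # no header row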
@@ -119,74 +113,27 @@ def delete_classifier(document_classifier_arn: str):
 )
 
 
-@task_group
-def copy_data_to_s3(bucket: str, sources: list[dict], prefix: str, number_of_copies=1):
-    """
-    Copy some sample data to S3 using HttpToS3Operator.
-    :param bucket: Name of the Amazon S3 bucket to send the data.
-    :param prefix: Folder to store the files
-    :param number_of_copies: Number of files to create for a document from the sources
-    :param sources: Public available data locations
-    """
-
-    """
-    EX: If number_of_copies is 2, sources has file name 'file.pdf', and prefix is 'training-docs'.
-    Will generate two copies and upload to s3:
-        - training-docs/file-0.pdf
-        - training-docs/file-1.pdf
-    """
-
-    http_to_s3_configs = [
+@task
+def create_kwargs_discharge():
+    return [
         {
-            "endpoint": source["endpoint"],
-            "s3_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+            "source_bucket_key": str(test_context[BUCKET_KEY_DISCHARGE_KEY]),
+            "dest_bucket_key": f"{TRAINING_DATA_PREFIX}/discharge-summary-{counter}.pdf",
         }
-        for source in sources
+        for counter in range(10)
     ]
-    copy_to_s3_configs = [
+
+
+@task
+def create_kwargs_doctors_notes():
+    return [
         {
-            "source_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
-            "dest_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+            "source_bucket_key": str(test_context[BUCKET_KEY_DOCTORS_NOTES]),
+            "dest_bucket_key": f"{TRAINING_DATA_PREFIX}/doctors-notes-{counter}.pdf",
        }
-        for counter in range(number_of_copies)
-        for source in sources
+        for counter in range(10)
     ]
-
-    @task
-    def create_connection(conn_id):
-        conn = Connection(
-            conn_id=conn_id,
-            conn_type="http",
-            host="https://github.com/",
-        )
-        session = settings.Session()
-        session.add(conn)
-        session.commit()
-
-    @task(trigger_rule=TriggerRule.ALL_DONE)
-    def delete_connection(conn_id):
-        session = settings.Session()
-        conn_to_details = session.query(Connection).filter(Connection.conn_id == conn_id).first()
-        session.delete(conn_to_details)
-        session.commit()
-
-    http_to_s3_task = HttpToS3Operator.partial(
-        task_id="http_to_s3_task",
-        http_conn_id=http_conn_id,
-        s3_bucket=bucket,
-    ).expand_kwargs(http_to_s3_configs)
-
-    s3_copy_task = S3CopyObjectOperator.partial(
-        task_id="s3_copy_task",
-        source_bucket_name=bucket,
-        dest_bucket_name=bucket,
-        meta_data_directive="REPLACE",
-    ).expand_kwargs(copy_to_s3_configs)
-
-    chain(create_connection(http_conn_id), http_to_s3_task, s3_copy_task, delete_connection(http_conn_id))
 
 
 with DAG(
     dag_id=DAG_ID,
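The replacement code leans on Airflow dynamic task mapping: a @task returns a list of kwargs dicts, and Operator.partial(...).expand_kwargs(...) fans that list out into one mapped task instance per dict — here, ten S3 copies per label, with no HTTP connection to create and tear down. A self-contained sketch of the same pattern (DAG id, buckets, and keys are placeholders, not values from this commit):

from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator

with DAG(dag_id="mapping_sketch", start_date=datetime(2024, 1, 1), schedule=None):

    @task
    def make_copy_kwargs():
        # One dict per mapped task instance.
        return [
            {
                "source_bucket_key": "fixtures/doc.pdf",
                "dest_bucket_key": f"training-docs/doc-{i}.pdf",
            }
            for i in range(10)
        ]

    # partial() carries the arguments shared by every copy; expand_kwargs()
    # expands into ten mapped S3CopyObjectOperator instances at runtime.
    S3CopyObjectOperator.partial(
        task_id="copy_fixture",
        source_bucket_name="fixture-bucket",
        dest_bucket_name="working-bucket",
        meta_data_directive="REPLACE",
    ).expand_kwargs(make_copy_kwargs())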
@@ -199,7 +146,6 @@ def delete_connection(conn_id):
     env_id = test_context["ENV_ID"]
     classifier_name = f"{env_id}-custom-document-classifier"
     bucket_name = f"{env_id}-comprehend-document-classifier"
-    http_conn_id = f"{env_id}-git"
 
     input_data_configurations = {
         "S3Uri": f"s3://{bucket_name}/{ANNOTATION_BUCKET_KEY}",
@@ -219,6 +165,22 @@ def delete_connection(conn_id):
         bucket_name=bucket_name,
     )
 
+    discharge_kwargs = create_kwargs_discharge()
+    s3_copy_discharge_task = S3CopyObjectOperator.partial(
+        task_id="s3_copy_discharge_task",
+        source_bucket_name=test_context[BUCKET_NAME_KEY],
+        dest_bucket_name=bucket_name,
+        meta_data_directive="REPLACE",
+    ).expand_kwargs(discharge_kwargs)
+
+    doctors_notes_kwargs = create_kwargs_doctors_notes()
+    s3_copy_doctors_notes_task = S3CopyObjectOperator.partial(
+        task_id="s3_copy_doctors_notes_task",
+        source_bucket_name=test_context[BUCKET_NAME_KEY],
+        dest_bucket_name=bucket_name,
+        meta_data_directive="REPLACE",
+    ).expand_kwargs(doctors_notes_kwargs)
+
     upload_annotation_file = S3CreateObjectOperator(
         task_id="upload_annotation_file",
         s3_bucket=bucket_name,
@@ -236,10 +198,9 @@ def delete_connection(conn_id):
     chain(
         test_context,
         create_bucket,
+        s3_copy_discharge_task,
+        s3_copy_doctors_notes_task,
         upload_annotation_file,
-        copy_data_to_s3(
-            bucket=bucket_name, sources=PUBLIC_DATA_SOURCES, prefix=TRAINING_DATA_PREFIX, number_of_copies=10
-        ),
         # TEST BODY
         document_classifier_workflow(),
         # TEST TEARDOWN
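For readers unfamiliar with chain(): it wires its arguments into a linear dependency, so the two new mapped copy tasks simply become setup steps before the annotation upload and the test body. A minimal equivalent (DAG id and task names are placeholders):

from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.operators.empty import EmptyOperator

with DAG(dag_id="chain_sketch", start_date=datetime(2024, 1, 1), schedule=None):
    setup = EmptyOperator(task_id="setup")
    body = EmptyOperator(task_id="body")
    teardown = EmptyOperator(task_id="teardown")
    chain(setup, body, teardown)  # same as setup >> body >> teardown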
