From aed6016ed24fd19c3ecf9924b4607b3964e2aa51 Mon Sep 17 00:00:00 2001
From: Jaap Blom
Date: Thu, 2 Nov 2023 10:46:32 +0100
Subject: [PATCH] more prepared to download and properly use file from 1st worker as input

---
 feature_extraction.py | 19 +++++++---
 io_util.py            | 84 +++++++++++++++++++++++++++----------------
 models.py             | 15 ++++----
 worker.py             | 26 +++++++-------
 4 files changed, 92 insertions(+), 52 deletions(-)
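
Reviewer note: the helpers introduced in io_util.py assume a fixed S3 naming
scheme between the two VisXP workers. Below is a standalone sketch of that
round-trip (bucket name and source id are made up; the parsing line mirrors
source_id_from_s3_uri rather than importing io_util, which needs a DANE config):

    import os

    # what the 1st worker (VISXP_PREP) uploads
    prep_uri = "s3://my-bucket/assets/my-source-id/visxp_prep__my-source-id.tar.gz"

    # recover the source_id the same way source_id_from_s3_uri does
    fn = os.path.basename(prep_uri)
    source_id = fn[: -len(".tar.gz")].split("__")[1]
    assert source_id == "my-source-id"

    # where this worker will publish its features (cf. get_s3_output_file_uri)
    feature_uri = f"s3://my-bucket/assets/{source_id}/VISXP_FEATURES__{source_id}.pt"
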
diff --git a/feature_extraction.py b/feature_extraction.py
index 00690fe..31b71c4 100644
--- a/feature_extraction.py
+++ b/feature_extraction.py
@@ -9,7 +9,7 @@
 from data_handling import VisXPData
 from models import VisXPFeatureExtractionOutput
 from provenance import generate_full_provenance_chain
-from io_util import get_source_id, export_features
+from io_util import get_source_id, save_features_to_file
 
 logger = logging.getLogger(__name__)
 
@@ -59,10 +59,21 @@ def extract_features(
     for i, batch in enumerate(dataset.batches(batch_size=256)):
         batch_result = apply_model(batch=batch, model=model, device=device)
         result_list.append(batch_result)
-    result = torch.cat(result_list)
+    # concatenate results and save to file
+    result = torch.cat(result_list)
     destination = os.path.join(output_path, f"{source_id}.pt")
-    export_features(result, destination=destination)
+    file_saved = save_features_to_file(result, destination=destination)
+
+    if not file_saved:
+        return VisXPFeatureExtractionOutput(
+            500,
+            f"Could not save extracted features to {destination}",
+            destination,
+            None,
+        )
+
+    # generate provenance, since all went well
     provenance = generate_full_provenance_chain(
         start_time=start_time,
         input_path=input_path,
@@ -70,7 +81,7 @@
         output_path=destination,
     )
     return VisXPFeatureExtractionOutput(
-        200, "Succesfully extracted features", provenance
+        200, "Successfully extracted features", destination, provenance
     )
 
     # Binarize resulting feature matrix
diff --git a/io_util.py b/io_util.py
index a663f85..13b6cdc 100644
--- a/io_util.py
+++ b/io_util.py
@@ -2,21 +2,26 @@
 import os
 from time import time
 
 import torch
-from typing import List, Tuple, Optional
 from dane import Document
 from dane.config import cfg
 from dane.s3_util import S3Store, parse_s3_uri, validate_s3_uri
-from models import CallbackResponse, Provenance, VisXPFeatureExtractionOutput
+from models import (
+    CallbackResponse,
+    Provenance,
+    VisXPFeatureExtractionOutput,
+    VisXPFeatureExtractionInput,
+)
 
 logger = logging.getLogger(__name__)
 
 DANE_VISXP_PREP_TASK_KEY = "VISXP_PREP"
+OUTPUT_FILE_BASE_NAME = "VISXP_FEATURES"
 
 
 # assesses the output and makes sure input & output is handled properly
 def apply_desired_io_on_output(
-    input_file: str,
+    feature_extraction_input: VisXPFeatureExtractionInput,
     proc_result: VisXPFeatureExtractionOutput,
     delete_input_on_completion: bool,
     delete_output_on_completetion: bool,
@@ -29,9 +34,7 @@
         return {"state": proc_result.state, "message": proc_result.message}
 
     # step 3: process returned successfully, generate the output
-    source_id = get_source_id(
-        input_file
-    )  # TODO: this worker does not necessarily work per source, so consider how to capture output group
+    source_id = feature_extraction_input.source_id
     output_path = get_base_output_dir(source_id)  # TODO actually make sure this works
 
     # step 4: transfer the output to S3 (if configured so)
@@ -78,9 +81,39 @@ def get_base_output_dir(source_id: str = "") -> str:
     return os.path.join(*path_elements)
 
 
-def export_features(features: torch.Tensor, destination: str):
-    with open(os.path.join(destination), "wb") as f:
-        torch.save(obj=features, f=f)
+# output file name of the final .pt file that will be uploaded to S3
+# TODO decide whether to tar.gz this as well
+def get_output_file_name(source_id: str) -> str:
+    return f"{OUTPUT_FILE_BASE_NAME}__{source_id}.pt"
+
+
+# e.g. s3://<bucket>/assets/<source_id>
+def get_s3_base_uri(source_id: str) -> str:
+    return f"s3://{os.path.join(cfg.OUTPUT.S3_BUCKET, cfg.OUTPUT.S3_FOLDER_IN_BUCKET, source_id)}"
+
+
+# e.g. s3://<bucket>/assets/<source_id>/VISXP_FEATURES__<source_id>.pt
+def get_s3_output_file_uri(source_id: str) -> str:
+    return f"{get_s3_base_uri(source_id)}/{get_output_file_name(source_id)}"
+
+
+# e.g. s3://<bucket>/assets/<source_id>/visxp_prep__<source_id>.tar.gz
+# TODO add validation of 1st VisXP worker's S3 URI
+def source_id_from_s3_uri(s3_uri: str) -> str:
+    fn = os.path.basename(s3_uri)
+    source_id = fn[: -len(".tar.gz")].split("__")[1]
+    return source_id
+
+
+# saves the features to a local file, so it can be uploaded to S3
+def save_features_to_file(features: torch.Tensor, destination: str) -> bool:
+    try:
+        with open(destination, "wb") as f:
+            torch.save(obj=features, f=f)
+        return True
+    except Exception:
+        logger.exception("Failed to save features to file")
+        return False
 
 
 def delete_local_output(source_id: str) -> bool:
@@ -94,21 +127,8 @@ def transfer_output(output_dir: str) -> bool:
     return True
 
 
-def get_s3_base_url(source_id: str) -> str:
-    return f"s3://{os.path.join(cfg.OUTPUT.S3_BUCKET, cfg.OUTPUT.S3_FOLDER_IN_BUCKET, source_id)}"
-
-
-def obtain_files_to_upload_to_s3(output_dir: str) -> List[str]:
-    s3_file_list = []
-    for root, dirs, files in os.walk(output_dir):
-        for f in files:
-            s3_file_list.append(os.path.join(root, f))
-    return s3_file_list
-
-
-# TODO: implement or replace function calls
 def get_download_dir():
-    return ""
+    return os.path.join(cfg.FILE_SYSTEM.BASE_MOUNT, cfg.FILE_SYSTEM.INPUT_DIR)
 
 
 # NOTE: untested
 def delete_input_file(input_file: str, actually_delete: bool) -> bool:
@@ -141,14 +161,12 @@
     return True  # return True even if empty dirs were not removed
 
 
-def obtain_input_file(
-    handler, doc: Document
-) -> Tuple[Optional[str], Optional[Provenance]]:
+def obtain_input_file(handler, doc: Document) -> VisXPFeatureExtractionInput:
     # first fetch and validate the obtained S3 URI
     # TODO make sure this is a valid S3 URI
     s3_uri = _fetch_visxp_prep_s3_uri(handler, doc)
     if not validate_s3_uri(s3_uri):
-        return None, None
+        return VisXPFeatureExtractionInput(500, f"Invalid S3 URI: {s3_uri}")
 
     start_time = time()
     output_folder = get_download_dir()
@@ -161,7 +179,7 @@
 
     if success:
         # TODO uncompress the visxp_prep.tar.gz
-        download_provenance = Provenance(
+        provenance = Provenance(
            activity_name="download",
            activity_description="Download VISXP_PREP data",
            start_time_unix=start_time,
@@ -169,9 +187,15 @@
             input_data={},
             output_data={"file_path": input_file_path},
         )
-        return input_file_path, download_provenance
+        return VisXPFeatureExtractionInput(
+            200,
+            f"Successfully downloaded: {s3_uri}",
+            source_id_from_s3_uri(s3_uri),  # source_id
+            input_file_path,  # locally downloaded .tar.gz
+            provenance,
+        )
     logger.error("Failed to download VISXP_PREP data from S3")
-    return None, None
+    return VisXPFeatureExtractionInput(500, f"Failed to download: {s3_uri}")
 
 
 def _fetch_visxp_prep_s3_uri(handler, doc: Document) -> str:
diff --git a/models.py b/models.py
index b8a26d0..4215fc0 100644
--- a/models.py
+++ b/models.py
@@ -37,13 +37,16 @@ def to_json(self):
 
 @dataclass
 class VisXPFeatureExtractionInput:
-    state: int
-    message: str
-    provenance: Optional[Provenance]
+    state: int  # HTTP status code
+    message: str  # error/success message
+    source_id: str = ""  # <source_id> parsed from the input S3 URI
+    input_file_path: str = ""  # where the visxp_prep.tar.gz was downloaded
+    provenance: Optional[Provenance] = None  # mostly: how long did it take to download
 
 
 @dataclass
 class VisXPFeatureExtractionOutput:
-    state: int
-    message: str
-    provenance: Optional[Provenance]
+    state: int  # HTTP status code
+    message: str  # error/success message
+    output_file_path: str = ""  # where to store the extracted features
+    provenance: Optional[Provenance] = None  # feature extraction provenance
diff --git a/worker.py b/worker.py
index ffb8b50..080ecaa 100644
--- a/worker.py
+++ b/worker.py
@@ -12,9 +12,8 @@
     apply_desired_io_on_output,
     obtain_input_file,
     get_base_output_dir,
-    get_source_id,
     get_download_dir,
-    get_s3_base_url,
+    get_s3_output_file_uri,
 )
 from pika.exceptions import ChannelClosedByBroker  # type: ignore
 from feature_extraction import extract_features
@@ -139,20 +138,23 @@
 
         # obtain the input file
         # TODO make sure to download the output from S3
-        input_file_path, download_provenance = obtain_input_file(self.handler, doc)
-        if not input_file_path:
+        feature_extraction_input = obtain_input_file(self.handler, doc)
+        if feature_extraction_input.state != 200:
             return {
-                "state": 500,
-                "message": "Could not download the input from S3",
+                "state": feature_extraction_input.state,
+                "message": feature_extraction_input.message,
             }
-        if download_provenance and provenance.steps:
-            provenance.steps.append(download_provenance)
+        if feature_extraction_input.provenance and provenance.steps:
+            provenance.steps.append(feature_extraction_input.provenance)
 
-        output_path = "TODO"  # TODO think of this
+        # e.g. /mnt/dane-fs/output-files/<source_id>
+        output_path = get_base_output_dir(
+            feature_extraction_input.source_id
+        )  # TODO think of this
 
         # step 1: apply model to extract features
         proc_result = extract_features(
-            input_file_path,
+            feature_extraction_input.input_file_path,
             model_path=cfg.VISXP_EXTRACT.MODEL_PATH,
             model_config_file=cfg.VISXP_EXTRACT.MODEL_CONFIG_PATH,
             output_path=output_path,
@@ -162,7 +164,7 @@
             provenance.steps.append(proc_result.provenance)
 
         validated_output: CallbackResponse = apply_desired_io_on_output(
-            input_file_path,
+            feature_extraction_input,
             proc_result,
             self.DELETE_INPUT_ON_COMPLETION,
             self.DELETE_OUTPUT_ON_COMPLETION,
@@ -177,7 +179,7 @@
         self.save_to_dane_index(
             doc,
             task,
-            get_s3_base_url(get_source_id(input_file_path)),
+            get_s3_output_file_uri(feature_extraction_input.source_id),
             provenance=provenance,
         )
         return validated_output
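-- 
Reviewer note (not part of the patch): a quick standalone check of what
save_features_to_file writes; the path and tensor shape are made up, and
torch.load is used here only to verify the round-trip:

    import torch

    features = torch.rand(5, 512)  # stand-in for the concatenated batch results
    destination = "/tmp/VISXP_FEATURES__my-source-id.pt"
    with open(destination, "wb") as f:  # mirrors save_features_to_file's happy path
        torch.save(obj=features, f=f)
    assert torch.equal(torch.load(destination), features)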