Commit aed6016

more prepared to download and properly use file from 1st worker as input
jblom committed Nov 2, 2023
1 parent 2e22cd1 commit aed6016
Showing 4 changed files with 92 additions and 52 deletions.
19 changes: 15 additions & 4 deletions feature_extraction.py
@@ -9,7 +9,7 @@
 from data_handling import VisXPData
 from models import VisXPFeatureExtractionOutput
 from provenance import generate_full_provenance_chain
-from io_util import get_source_id, export_features
+from io_util import get_source_id, save_features_to_file
 
 logger = logging.getLogger(__name__)
 
@@ -59,18 +59,29 @@ def extract_features(
     for i, batch in enumerate(dataset.batches(batch_size=256)):
         batch_result = apply_model(batch=batch, model=model, device=device)
         result_list.append(batch_result)
-    result = torch.cat(result_list)
 
+    # concatenate results and save to file
+    result = torch.cat(result_list)
     destination = os.path.join(output_path, f"{source_id}.pt")
-    export_features(result, destination=destination)
+    file_saved = save_features_to_file(result, destination=destination)
+
+    if not file_saved:
+        return VisXPFeatureExtractionOutput(
+            500,
+            f"Could not save extracted features to {destination}",
+            destination,
+            None,
+        )
 
+    # generate provenance, since all went well
     provenance = generate_full_provenance_chain(
         start_time=start_time,
         input_path=input_path,
         provenance_chain=[],
         output_path=destination,
     )
     return VisXPFeatureExtractionOutput(
-        200, "Successfully extracted features", provenance
+        200, "Successfully extracted features", destination, provenance
     )
 
     # Binarize resulting feature matrix
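
For reference, a quick round-trip of the .pt file that save_features_to_file writes. torch.save and torch.load are standard PyTorch APIs; the tensor shape and destination path below are made up for illustration:

import torch

features = torch.rand(10, 512)  # stand-in for the concatenated batch results
destination = "/tmp/VISXP_FEATURES__demo.pt"  # hypothetical local path
with open(destination, "wb") as f:
    torch.save(obj=features, f=f)
restored = torch.load(destination)
assert torch.equal(features, restored)
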
84 changes: 54 additions & 30 deletions io_util.py
@@ -2,21 +2,26 @@
 import os
 from time import time
 import torch
 from typing import List, Tuple, Optional
 
 from dane import Document
 from dane.config import cfg
 from dane.s3_util import S3Store, parse_s3_uri, validate_s3_uri
-from models import CallbackResponse, Provenance, VisXPFeatureExtractionOutput
+from models import (
+    CallbackResponse,
+    Provenance,
+    VisXPFeatureExtractionOutput,
+    VisXPFeatureExtractionInput,
+)
 
 
 logger = logging.getLogger(__name__)
 DANE_VISXP_PREP_TASK_KEY = "VISXP_PREP"
+OUTPUT_FILE_BASE_NAME = "VISXP_FEATURES"
 
 
+# assesses the output and makes sure input & output is handled properly
 def apply_desired_io_on_output(
-    input_file: str,
+    feature_extraction_input: VisXPFeatureExtractionInput,
     proc_result: VisXPFeatureExtractionOutput,
     delete_input_on_completion: bool,
     delete_output_on_completetion: bool,
@@ -29,9 +34,7 @@ def apply_desired_io_on_output(
         return {"state": proc_result.state, "message": proc_result.message}
 
     # step 3: process returned successfully, generate the output
-    source_id = get_source_id(
-        input_file
-    )  # TODO: this worker does not necessarily work per source, so consider how to capture output group
+    source_id = feature_extraction_input.source_id
     output_path = get_base_output_dir(source_id)  # TODO actually make sure this works
 
     # step 4: transfer the output to S3 (if configured so)
@@ -78,9 +81,39 @@ def get_base_output_dir(source_id: str = "") -> str:
     return os.path.join(*path_elements)
 
 
-def export_features(features: torch.Tensor, destination: str):
-    with open(os.path.join(destination), "wb") as f:
-        torch.save(obj=features, f=f)
+# output file name of the final .pt file that will be uploaded to S3
+# TODO decide whether to tar.gz this as well
+def get_output_file_name(source_id: str) -> str:
+    return f"{OUTPUT_FILE_BASE_NAME}__{source_id}.pt"
+
+
+# e.g. s3://<bucket>/assets/<source_id>
+def get_s3_base_uri(source_id: str) -> str:
+    return f"s3://{os.path.join(cfg.OUTPUT.S3_BUCKET, cfg.OUTPUT.S3_FOLDER_IN_BUCKET, source_id)}"
+
+
+# e.g. s3://<bucket>/assets/<source_id>/VISXP_FEATURES__<source_id>.pt
+def get_s3_output_file_uri(source_id: str) -> str:
+    return f"{get_s3_base_uri(source_id)}/{get_output_file_name(source_id)}"
+
+
+# e.g. s3://<bucket>/assets/<source_id>/visxp_prep__<source_id>.tar.gz
+# TODO add validation of 1st VisXP worker's S3 URI
+def source_id_from_s3_uri(s3_uri: str) -> str:
+    fn = os.path.basename(s3_uri)
+    source_id = fn[: -len(".tar.gz")].split("__", 1)[1]
+    return source_id
+
+
+# saves the features to a local file, so it can be uploaded to S3
+def save_features_to_file(features: torch.Tensor, destination: str) -> bool:
+    try:
+        with open(os.path.join(destination), "wb") as f:
+            torch.save(obj=features, f=f)
+        return True
+    except Exception:
+        logger.exception("Failed to save features to file")
+        return False
 
 
 def delete_local_output(source_id: str) -> bool:
@@ -94,21 +127,8 @@ def transfer_output(output_dir: str) -> bool:
     return True
 
 
-def get_s3_base_url(source_id: str) -> str:
-    return f"s3://{os.path.join(cfg.OUTPUT.S3_BUCKET, cfg.OUTPUT.S3_FOLDER_IN_BUCKET, source_id)}"
-
-
-def obtain_files_to_upload_to_s3(output_dir: str) -> List[str]:
-    s3_file_list = []
-    for root, dirs, files in os.walk(output_dir):
-        for f in files:
-            s3_file_list.append(os.path.join(root, f))
-    return s3_file_list
-
-
-# TODO: implement or replace function calls
 def get_download_dir():
-    return ""
+    return os.path.join(cfg.FILE_SYSTEM.BASE_MOUNT, cfg.FILE_SYSTEM.INPUT_DIR)
 
 
 # NOTE: untested
@@ -141,14 +161,12 @@ def delete_input_file(input_file: str, actually_delete: bool) -> bool:
     return True  # return True even if empty dirs were not removed
 
 
-def obtain_input_file(
-    handler, doc: Document
-) -> Tuple[Optional[str], Optional[Provenance]]:
-    # TODO make sure this is a valid S3 URI
+def obtain_input_file(handler, doc: Document) -> VisXPFeatureExtractionInput:
+    # first fetch and validate the obtained S3 URI
     s3_uri = _fetch_visxp_prep_s3_uri(handler, doc)
     if not validate_s3_uri(s3_uri):
-        return None, None
+        return VisXPFeatureExtractionInput(500, f"Invalid S3 URI: {s3_uri}")
 
     start_time = time()
     output_folder = get_download_dir()
@@ -161,17 +179,23 @@ def obtain_input_file(
     if success:
         # TODO uncompress the visxp_prep.tar.gz
 
-        download_provenance = Provenance(
+        provenance = Provenance(
             activity_name="download",
             activity_description="Download VISXP_PREP data",
             start_time_unix=start_time,
             processing_time_ms=time() - start_time,
             input_data={},
             output_data={"file_path": input_file_path},
         )
-        return input_file_path, download_provenance
+        return VisXPFeatureExtractionInput(
+            200,
+            f"Successfully downloaded {s3_uri}",
+            source_id_from_s3_uri(s3_uri),  # source_id
+            input_file_path,  # locally downloaded .tar.gz
+            provenance,
+        )
     logger.error("Failed to download VISXP_PREP data from S3")
-    return None, None
+    return VisXPFeatureExtractionInput(500, f"Failed to download: {s3_uri}")
 
 
 def _fetch_visxp_prep_s3_uri(handler, doc: Document) -> str:
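
A standalone sketch of the S3 naming convention the new io_util helpers encode, with the cfg lookups replaced by hypothetical constants (example-bucket, assets) and a made-up source_id:

import os

S3_BUCKET = "example-bucket"  # stands in for cfg.OUTPUT.S3_BUCKET
S3_FOLDER_IN_BUCKET = "assets"  # stands in for cfg.OUTPUT.S3_FOLDER_IN_BUCKET
OUTPUT_FILE_BASE_NAME = "VISXP_FEATURES"

def get_output_file_name(source_id: str) -> str:
    return f"{OUTPUT_FILE_BASE_NAME}__{source_id}.pt"

def get_s3_output_file_uri(source_id: str) -> str:
    base = f"s3://{os.path.join(S3_BUCKET, S3_FOLDER_IN_BUCKET, source_id)}"
    return f"{base}/{get_output_file_name(source_id)}"

def source_id_from_s3_uri(s3_uri: str) -> str:
    # visxp_prep__<source_id>.tar.gz -> <source_id>
    fn = os.path.basename(s3_uri)
    return fn[: -len(".tar.gz")].split("__", 1)[1]

prep_uri = "s3://example-bucket/assets/prog1__carrier1/visxp_prep__prog1__carrier1.tar.gz"
source_id = source_id_from_s3_uri(prep_uri)  # prog1__carrier1
print(get_s3_output_file_uri(source_id))
# s3://example-bucket/assets/prog1__carrier1/VISXP_FEATURES__prog1__carrier1.pt
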
15 changes: 9 additions & 6 deletions models.py
@@ -37,13 +37,16 @@ def to_json(self):
 
 @dataclass
 class VisXPFeatureExtractionInput:
-    state: int
-    message: str
-    provenance: Optional[Provenance]
+    state: int  # HTTP status code
+    message: str  # error/success message
+    source_id: str = ""  # <program ID>__<carrier ID>
+    input_file_path: str = ""  # where the visxp_prep.tar.gz was downloaded
+    provenance: Optional[Provenance] = None  # mostly: how long did it take to download
 
 
 @dataclass
 class VisXPFeatureExtractionOutput:
-    state: int
-    message: str
-    provenance: Optional[Provenance]
+    state: int  # HTTP status code
+    message: str  # error/success message
+    output_file_path: str = ""  # where to store the extracted features
+    provenance: Optional[Provenance] = None  # feature extraction provenance
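
The defaults on the new fields are what allow the short error-path constructors in io_util.py (state and message only). A minimal self-contained sketch, with Provenance stubbed out:

from dataclasses import dataclass
from typing import Optional

Provenance = dict  # stand-in for models.Provenance

@dataclass
class VisXPFeatureExtractionInput:
    state: int  # HTTP status code
    message: str  # error/success message
    source_id: str = ""
    input_file_path: str = ""
    provenance: Optional[Provenance] = None

err = VisXPFeatureExtractionInput(500, "Invalid S3 URI: s3://oops")
assert err.source_id == "" and err.provenance is None
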
26 changes: 14 additions & 12 deletions worker.py
@@ -12,9 +12,8 @@
     apply_desired_io_on_output,
     obtain_input_file,
     get_base_output_dir,
-    get_source_id,
     get_download_dir,
-    get_s3_base_url,
+    get_s3_output_file_uri,
 )
 from pika.exceptions import ChannelClosedByBroker  # type: ignore
 from feature_extraction import extract_features
@@ -139,20 +138,23 @@ def callback(self, task: Task, doc: Document) -> CallbackResponse:
 
         # obtain the input file
         # TODO make sure to download the output from S3
-        input_file_path, download_provenance = obtain_input_file(self.handler, doc)
-        if not input_file_path:
+        feature_extraction_input = obtain_input_file(self.handler, doc)
+        if not feature_extraction_input.state == 200:
             return {
-                "state": 500,
-                "message": "Could not download the input from S3",
+                "state": feature_extraction_input.state,
+                "message": feature_extraction_input.message,
             }
-        if download_provenance and provenance.steps:
-            provenance.steps.append(download_provenance)
+        if feature_extraction_input.provenance and provenance.steps:
+            provenance.steps.append(feature_extraction_input.provenance)
 
-        output_path = "TODO"  # TODO think of this
+        # e.g. /mnt/dane-fs/output-files/<source_id>
+        output_path = get_base_output_dir(
+            feature_extraction_input.source_id
+        )  # TODO think of this
 
         # step 1: apply model to extract features
         proc_result = extract_features(
-            input_file_path,
+            feature_extraction_input.input_file_path,
             model_path=cfg.VISXP_EXTRACT.MODEL_PATH,
             model_config_file=cfg.VISXP_EXTRACT.MODEL_CONFIG_PATH,
             output_path=output_path,
@@ -162,7 +164,7 @@ def callback(self, task: Task, doc: Document) -> CallbackResponse:
             provenance.steps.append(proc_result.provenance)
 
         validated_output: CallbackResponse = apply_desired_io_on_output(
-            input_file_path,
+            feature_extraction_input,
             proc_result,
             self.DELETE_INPUT_ON_COMPLETION,
             self.DELETE_OUTPUT_ON_COMPLETION,
@@ -177,7 +179,7 @@ def callback(self, task: Task, doc: Document) -> CallbackResponse:
         self.save_to_dane_index(
             doc,
             task,
-            get_s3_base_url(get_source_id(input_file_path)),
+            get_s3_output_file_uri(feature_extraction_input.source_id),
             provenance=provenance,
         )
         return validated_output
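
Net effect of the worker changes: the callback now branches on feature_extraction_input.state instead of a None check, derives the output directory from the source_id carried by the input object, and registers the full output file URI (e.g. s3://<bucket>/assets/<source_id>/VISXP_FEATURES__<source_id>.pt) in the DANE index rather than the base folder URI.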
