Skip to content

Commit

Permalink
basics work with tarfile as input
Browse files Browse the repository at this point in the history
  • Loading branch information
jblom committed Nov 2, 2023
1 parent 0df253f commit 591ddf4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ PATHS:
VISXP_EXTRACT:
MODEL_PATH: /model/checkpoint.tar
MODEL_CONFIG_PATH: /model/model_config.yml
TEST_INPUT_PATH: /data
TEST_INPUT_PATH: /data/testob/visxp_prep__testob.tar.gz
INPUT:
DELETE_ON_COMPLETION: True
OUTPUT:
Expand Down
17 changes: 13 additions & 4 deletions feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from data_handling import VisXPData
from models import VisXPFeatureExtractionOutput
from provenance import generate_full_provenance_chain
from io_util import get_source_id, save_features_to_file
from io_util import get_source_id, save_features_to_file, untar_input_file

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,12 +38,21 @@ def extract_features(
source_id = get_source_id(input_path)
logger.info(f"Extracting features for: {source_id}.")

# Step 2: Load spectograms + keyframes from file & preprocess
# Step 2: check the type of input (tar.gz vs a directory)
if input_path.find(".tar.gz") != -1:
logger.info("Input is an archive, uncompressing it")
untar_input_file(input_path) # extracts contents in same dir
input_path = str(
Path(input_path).parent
) # change the input path to the parent dir
logger.info(f"Changed input_path to: {input_path}")

# Step 3: Load spectograms + keyframes from file & preprocess
dataset = VisXPData(
Path(input_path), model_config_file=model_config_file, device=device
)

# Step 3: Load model from file
# Step 4: Load model from file
model = load_model_from_file(
checkpoint_file=model_path,
config_file=model_config_file,
Expand All @@ -52,7 +61,7 @@ def extract_features(
# Switch model mode: in training mode, model layers behave differently!
model.eval()

# Step 4: Apply model to data
# Step 5: Apply model to data
logger.info(f"Going to extract features for {dataset.__len__()} items. ")

result_list = []
Expand Down
9 changes: 9 additions & 0 deletions io_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
from pathlib import Path
import tarfile
from time import time
import torch

Expand Down Expand Up @@ -206,3 +208,10 @@ def _fetch_visxp_prep_s3_uri(handler, doc: Document) -> str:
return possibles[0].payload.get("s3_location", "")
logger.error("No s3_location found in VISXP_PREP result")
return ""


# untars visxp_prep__<source_id>.tar.gz into the same dir
def untar_input_file(tar_file_path: str):
logger.info(f"Uncompressing {tar_file_path}")
with tarfile.open(tar_file_path) as tar:
tar.extractall(path=str(Path(tar_file_path).parent), filter="data") # type: ignore

0 comments on commit 591ddf4

Please sign in to comment.