Skip to content

Commit

Permalink
implemented way to pick one samplerate from the VISXP_PREP dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jblom committed Nov 8, 2023
1 parent 9f14aba commit faea8e8
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 8 deletions.
5 changes: 3 additions & 2 deletions config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ VISXP_EXTRACT:
MODEL_BASE_MOUNT: /model
MODEL_CHECKPOINT_FILE: checkpoint.tar # should be in MODEL_BASE_MOUNT
MODEL_CONFIG_FILE: model_config.yml # should be in MODEL_BASE_MOUNT
EXPECTED_SPECTOGRAM_SAMPLERATE_HZ: 24000 # -1 means undefined; otherwise one of: 24000 | 48000
TEST_INPUT_PATH: /data/input-files/testob/visxp_prep__testob.tar.gz
INPUT:
S3_ENDPOINT_URL: https://s3-host
S3_ENDPOINT_URL: https://s3.eu-west-1.amazonaws.com/
MODEL_CHECKPOINT_S3_URI: s3://beng-daan-visxp/model/checkpoint.tar
MODEL_CONFIG_S3_URI: s3://beng-daan-visxp/model/model_config.yml
DELETE_ON_COMPLETION: False
OUTPUT:
DELETE_ON_COMPLETION: True
TRANSFER_ON_COMPLETION: True
S3_ENDPOINT_URL: https://s3-host
S3_ENDPOINT_URL: https://s3.eu-west-1.amazonaws.com/
S3_BUCKET: beng-daan-visxp # bucket reserved for 1 type of output
S3_FOLDER_IN_BUCKET: assets # folder within the bucket
DANE_DEPENDENCIES:
Expand Down
22 changes: 19 additions & 3 deletions data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,38 @@
logger = logging.getLogger(__name__)


KEYFRAME_INPUT_DIR = "keyframes"
SPECTOGRAM_INPUT_DIR = "spectograms"


class VisXPData(Dataset):
def __init__(
self,
datapath: Path,
model_config_file: str,
device: torch.device,
expected_sample_rate=-1, # 24000 | 48000
check_spec_dim=False,
):
if type(datapath) is not Path:
datapath = Path(datapath)
# Sorting is not strictly necessary, but it is a (poor) way of making sure specs and frames are aligned.
self.spec_paths = sorted(list(datapath.glob("spectograms/*.npz")))
self.frame_paths = sorted(list(datapath.glob("keyframes/*.jpg")))
self.frame_paths = sorted(list(datapath.glob(f"{KEYFRAME_INPUT_DIR}/*.jpg")))

# first determine if/which spectogram files to select
spectogram_suffix = (
"" if expected_sample_rate == -1 else f"_{expected_sample_rate}"
)
self.spec_paths = sorted(
list(datapath.glob(f"{SPECTOGRAM_INPUT_DIR}/*{spectogram_suffix}.npz"))
) # one samplerate is used
self.device = device
self.set_config(model_config_file=model_config_file)
self.list_of_shots = self.ListOfShots(datapath)

# NOTE use the keyframe list to determine __len__, since there can be
# multiple spectograms per keyframe
self.data_set_size = len(self.frame_paths)
if check_spec_dim:
all_ok = self.check_spec_dim()
if all_ok:
Expand Down Expand Up @@ -70,7 +86,7 @@ def set_config(self, model_config_file: str):
)

def __len__(self):
return len(self.spec_paths)
return self.data_set_size

def __getitem__(self, index):
item_dict = dict()
Expand Down
4 changes: 3 additions & 1 deletion feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def run(

# Step 4: Load spectograms + keyframes from file & preprocess
dataset = VisXPData(
Path(input_file_path),
datapath=Path(input_file_path),
model_config_file=os.path.join(model_base_mount, model_config_file),
device=device,
expected_sample_rate=feature_extraction_input.expected_sample_rate,
check_spec_dim=False,
)

# Step 5: Load model from file
Expand Down
1 change: 1 addition & 0 deletions io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def obtain_input_file(s3_uri: str) -> VisXPFeatureExtractionInput:
f"Failed to download: {s3_uri}",
source_id_from_s3_uri(s3_uri), # source_id
input_file_path, # locally downloaded .tar.gz
cfg.VISXP_EXTRACT.EXPECTED_SPECTOGRAM_SAMPLERATE_HZ,
provenance,
)
logger.error("Failed to download VISXP_PREP data from S3")
Expand Down
1 change: 1 addition & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class VisXPFeatureExtractionInput:
message: str # error/success message
source_id: str = "" # <program ID>__<carrier ID>
input_file_path: str = "" # where the visxp_prep.tar.gz was downloaded
expected_sample_rate: int = -1 # VISXP_EXTRACT.EXPECTED_SPECTOGRAM_SAMPLERATE_HZ
provenance: Optional[Provenance] = None # mostly: how long did it take to download


Expand Down
6 changes: 4 additions & 2 deletions tests/unit/data_handling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
def test_batches():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataset = VisXPData(
"data/input-files/test_source_id",
device=device,
datapath="data/input-files/test_source_id",
model_config_file="model/model_config.yml",
device=device,
expected_sample_rate=-1, # not specified in the test data
check_spec_dim=False,
)
for i, item in enumerate(dataset.batches(1)):
index = int(item["timestamp"][0])
Expand Down
1 change: 1 addition & 0 deletions tests/unit/feature_extraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_extract_features():
f"Thank you for unit testing: let's process {UNIT_TEST_INPUT_PATH}",
UNIT_TEST_SOURCE_ID,
UNIT_TEST_INPUT_PATH,
-1,
None, # no provenance needed in test
),
model_base_mount="model",
Expand Down
1 change: 1 addition & 0 deletions worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def process_configured_input_file() -> bool:
f"Thank you for running us: let's test {cfg.VISXP_EXTRACT.TEST_INPUT_PATH}",
get_source_id_from_tar(cfg.VISXP_EXTRACT.TEST_INPUT_PATH),
cfg.VISXP_EXTRACT.TEST_INPUT_PATH,
cfg.VISXP_EXTRACT.EXPECTED_SPECTOGRAM_SAMPLERATE_HZ, # make sure this matches VISXP_PREP config
None, # no provenance needed in test
)

Expand Down

0 comments on commit faea8e8

Please sign in to comment.