Skip to content

Commit

Permalink
add command to validate a BIDS folder has all the extracted features
Browse files Browse the repository at this point in the history
  • Loading branch information
alistairewj committed Oct 1, 2024
1 parent d07d229 commit d5db59b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/b2aiprep/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from streamlit.web.bootstrap import run

from b2aiprep.prepare.bids_like_data import redcap_to_bids
from b2aiprep.prepare.prepare import prepare_bids_like_data
from b2aiprep.prepare.prepare import prepare_bids_like_data, validate_bids_data


@click.group()
Expand Down Expand Up @@ -119,6 +119,28 @@ def prepbidslikedata(
)


@main.command()
@click.argument("bids_dir_path", type=click.Path(exists=True, file_okay=False))
@click.argument("fix", type=bool)
def validate(
    bids_dir_path,
    fix,
):
    """Validates that a BIDS-like directory has all the extracted features.

    bids_dir_path: path to an existing BIDS-like data directory to scan\n
    fix: whether to regenerate any missing feature files (true/false)
    """
    # click uses the docstring above as the command's --help text, so it must
    # describe exactly the two positional arguments this command accepts.
    validate_bids_data(
        bids_dir_path=Path(bids_dir_path),
        fix=fix,
    )

@main.command()
@click.argument("filename", type=click.Path(exists=True))
@click.option("-s", "--subject", type=str, default=None)
Expand Down
74 changes: 74 additions & 0 deletions src/b2aiprep/prepare/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,80 @@ def prepare_bids_like_data(

_logger.info("Process completed.")

def validate_bids_data(
    bids_dir_path: Path,
    fix: bool = True,
    transcription_model_size: str = 'medium',
) -> None:
    """Scan BIDS audio data and verify that all expected features are present.

    For every audio file returned by ``get_audio_paths(bids_dir_path)``, the
    directory ``<audio file's directory>/../audio`` is checked for one
    ``<stem>_<feature>.pt`` file per expected feature and for a
    ``<stem>_transcription.txt`` file. The number of audio files with missing
    outputs is logged. When ``fix`` is true, each missing output is
    regenerated from the audio file and written to that directory.

    Args:
        bids_dir_path: Root of the BIDS-like directory to scan.
        fix: If true, regenerate missing feature files; if false, only report.
        transcription_model_size: Suffix of the ``openai/whisper-<size>``
            model used when a transcription has to be regenerated.
    """
    _logger.info("Scanning for features in BIDS directory.")
    audio_paths = get_audio_paths(bids_dir_path)
    audio_to_reprocess = defaultdict(list)
    features = (
        'speaker_embedding',
        'specgram',
        'melfilterbank',
        'mfcc',
        'sample_rate',
        'opensmile',
    )
    for audio_path in audio_paths:
        audio_path = Path(audio_path)
        features_dir = audio_path.parent.parent / "audio"
        for feat_name in features:
            if not (features_dir / f'{audio_path.stem}_{feat_name}.pt').exists():
                audio_to_reprocess[audio_path].append(feat_name)

        # The transcription is stored as plain text rather than as a .pt file.
        if not (features_dir / f'{audio_path.stem}_transcription.txt').exists():
            audio_to_reprocess[audio_path].append('transcription')

    if not audio_to_reprocess:
        _logger.info("All audio files have been processed and all feature files are present.")
        return

    _logger.info(
        "Missing features for %d / %d audio files",
        len(audio_to_reprocess),
        len(audio_paths),
    )

    if not fix:
        return

    feature_extraction_fcns = {
        "speaker_embedding": extract_speaker_embeddings_from_audios,
        "specgram": extract_spectrogram_from_audios,
        "melfilterbank": extract_mel_filter_bank_from_audios,
        "mfcc": extract_mfcc_from_audios,
        "opensmile": extract_opensmile_features_from_audios,
    }

    # Transcription objects are created lazily, and only once, so that runs
    # with no missing transcription never construct the speech-to-text model.
    language = None
    speech_to_text_model = None

    for audio_path, missing_feats in tqdm(
        audio_to_reprocess.items(),
        total=len(audio_to_reprocess),
        desc='Reprocessing audio files',
    ):
        features_dir = audio_path.parent.parent / "audio"
        features_dir.mkdir(parents=True, exist_ok=True)

        # Load and resample once per file; every missing feature reuses it.
        audio = Audio.from_filepath(str(audio_path))
        audio = resample_audios([audio], resample_rate=RESAMPLE_RATE)[0]

        for feat_name in missing_feats:
            file_extension = "pt"
            if feat_name == "sample_rate":
                feature_value = audio.sampling_rate
            elif feat_name in feature_extraction_fcns:
                feature_value = feature_extraction_fcns[feat_name]([audio])[0]
            elif feat_name == 'transcription':
                if speech_to_text_model is None:
                    language = Language.model_validate({"language_code": "en"})
                    speech_to_text_model = HFModel(
                        path_or_uri=f"openai/whisper-{transcription_model_size}"
                    )
                feature_value = transcribe_audios(
                    audios=[audio], model=speech_to_text_model, language=language
                )[0]
                file_extension = "txt"
            else:
                _logger.warning("Unsupported feature: %s", feat_name)
                continue

            save_path = features_dir / f"{audio_path.stem}_{feat_name}.{file_extension}"
            if file_extension == "pt":
                torch.save(feature_value, save_path)
            else:
                with open(save_path, "w", encoding="utf-8") as text_file:
                    text_file.write(feature_value.text)

    _logger.info("Process completed.")

def main() -> None:
    """Module entry point; currently performs no work."""
    return None
Expand Down

0 comments on commit d5db59b

Please sign in to comment.