Add support for Trinity in data processing
nagyrajmund committed Jul 27, 2021
1 parent 5fc9c0d commit fcc9935
Showing 4 changed files with 332 additions and 317 deletions.
32 changes: 24 additions & 8 deletions data_processing/motion_features.py
@@ -30,7 +30,9 @@ def nan_smooth(data, filt_len):
return out


def extract_joint_angles(bvh_dir, filenames, destpath, fps, use_fingers=False):
def extract_joint_angles(
dataset, bvh_dir, filenames, destpath, fps, use_mirroring, fullbody
):
bvh_parser = BVHParser()

bvh_files = [join(bvh_dir, bvh_file + ".bvh") for bvh_file in filenames]
@@ -40,22 +40,34 @@ def extract_joint_angles(bvh_dir, filenames, destpath, fps, use_fingers=False):
for file in progress_bar:
all_bvh_data.append(bvh_parser.parse(file))

data_pipe = utils.get_saga_bvh_pipeline(fps, use_fingers)
data_pipe = utils.get_bvh_pipeline(dataset, fps, use_mirroring, fullbody)
print("Processing...")

joint_angle_exp_maps = data_pipe.fit_transform(all_bvh_data)
assert len(joint_angle_exp_maps) == len(filenames)

if use_mirroring:
assert len(joint_angle_exp_maps) == 2 * len(filenames)
else:
assert len(joint_angle_exp_maps) == len(filenames)

jl.dump(data_pipe, os.path.join(destpath, "data_pipe.sav"))

for fname, motion_data in zip(filenames, joint_angle_exp_maps):
out_file = join(destpath, fname)
for file_ind in range(len(filenames)):
out_file = join(destpath, filenames[file_ind])
print(out_file)
np.savez(out_file + ".npz", clips=motion_data)
# np.savez(ff + "_mirrored.npz", clips=out_data[len(files) + fi])
np.savez(out_file + ".npz", clips=joint_angle_exp_maps[file_ind])

if use_mirroring:
# If there are mirrored files, they are stored in the second half
# of the motion array
np.savez(
out_file + "_mirrored.npz",
clips=joint_angle_exp_maps[len(filenames) + file_ind],
)
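
For reference, the indexing convention the new loop relies on, shown as a minimal standalone sketch (save_clips is a hypothetical helper, not part of this commit): the pipeline appends all mirrored clips after the originals, so clip i and clip len(filenames) + i come from the same BVH file.

import numpy as np
from os.path import join

def save_clips(joint_angle_exp_maps, filenames, destpath, use_mirroring):
    # Originals occupy indices [0, n); mirrored copies occupy [n, 2n),
    # in the same filename order.
    n = len(filenames)
    for i, fname in enumerate(filenames):
        np.savez(join(destpath, fname + ".npz"), clips=joint_angle_exp_maps[i])
        if use_mirroring:
            np.savez(
                join(destpath, fname + "_mirrored.npz"),
                clips=joint_angle_exp_maps[n + i],
            )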

def extract_hand_pos(bvh_dir, files, destpath, fps):

def extract_hand_pos(bvh_dir, files, destpath, fps):
raise NotImplementedError()
p = BVHParser()

data_all = list()
@@ -87,6 +101,8 @@ def extract_hand_pos(bvh_dir, files, destpath, fps):


def extract_style_features(hand_pos_dir, files, destpath, fps, average_secs):
raise NotImplementedError()

filt_len = int(fps * average_secs)

for f in files:
111 changes: 65 additions & 46 deletions data_processing/prepare_gesture_datasets.py
@@ -3,22 +3,16 @@
from os.path import join, exists
from argparse import ArgumentParser
import numpy as np

import glob
import os
import sys
from shutil import copyfile
from motion_features import (
extract_joint_angles,
extract_hand_pos,
extract_style_features,
)
from audio_features import extract_melspec

# from text_features import extract_text_features
import scipy.io.wavfile as wav
from pymo.parsers import BVHParser
from pymo.data import Joint, MocapData
from pymo.preprocessing import *
from pymo.writers import *
import joblib as jl
@@ -101,13 +95,14 @@ def align(data1, data2):


def load_and_concatenate(
file, motion_path, speech_path, start=0, end=None,
file, motion_path, speech_path, start=0, end=None, is_mirrored=False
):
"""Loads a file and concatenate all features to one [time, features] matrix.
NOTE: All sources will be truncated to the shortest length, i.e. we assume they
are time synchronized and has the same start time."""
motion_filename = file + "_mirrored" if is_mirrored else file

motion_data = np.load(join(motion_path, file + ".npz"))["clips"].astype(np.float32)
motion_data = np.load(join(motion_path, motion_filename + ".npz"))["clips"].astype(np.float32)
n_motion_feats = motion_data.shape[1]

load_array = lambda dir, file: np.load(join(dir, file + ".npy")).astype(np.float32)
@@ -133,7 +128,14 @@


def create_data_windows(
files, motion_path, speech_data, slice_window, slice_overlap, start=0, end=None,
files,
motion_path,
speech_data,
slice_window,
slice_overlap,
is_mirrored,
start=0,
end=None,
):
"""Imports all features and slices them to samples with equal lenth time
[samples, timesteps, features]."""
@@ -149,14 +151,13 @@ )
)
data_windows = slice_to_windows(concat_data, slice_window, slice_overlap)

# if mirror:
# concat_mirr, nmf = load_and_concatenate(
# file, motion_path, speech_data, text_path, style_path, True, start, end
# )
# sliced_mirr = slice_to_windows(concat_mirr, slice_window, slice_overlap)

# # append to the sliced dataset
# sliced = np.concatenate((sliced, sliced_mirr), axis=0)
if is_mirrored:
concat_mirr, _ = load_and_concatenate(
file, motion_path, speech_data, start, end, is_mirrored=True
)
sliced_mirr = slice_to_windows(concat_mirr, slice_window, slice_overlap)
# append to the sliced dataset
data_windows = np.concatenate((data_windows, sliced_mirr), axis=0)

out_data.append(data_windows)
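
slice_to_windows itself is outside this diff; below is a minimal sketch of the behavior assumed here, with overlap given as a fraction of the window length (an illustrative re-implementation, not the project's code):

import numpy as np

def slice_to_windows(data, window_len, overlap):
    # data: [time, features] -> windows: [samples, window_len, features].
    # Consecutive windows start window_len * (1 - overlap) frames apart.
    step = max(1, int(window_len * (1 - overlap)))
    starts = range(0, max(len(data) - window_len + 1, 0), step)
    windows = [data[s : s + window_len] for s in starts]
    if not windows:
        return np.empty((0, window_len, data.shape[1]), dtype=data.dtype)
    return np.stack(windows)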

@@ -169,32 +170,42 @@

def parse_args():
parser = ArgumentParser()
parser.add_argument("--fps", type=int, default=25)
parser.add_argument("--raw_data_dir", type=str, default="../data/SAGA/source/")
parser.add_argument("--processed_dir", type=str, default="../data/SAGA/processed/")
parser.add_argument("--dataset", choices=["saga", "trinity"], default="saga")
parser.add_argument("--source_dir", type=str, default=None)
parser.add_argument("--processed_dir", type=str, default=None)
parser.add_argument("--held_out_files", nargs="+", default=None)

parser.add_argument("--fps", type=int, default=None)
# parser.add_argument("--style", choices=[None, "MG-R", "MG-V", "MG-H" "MS-S"])
parser.add_argument("--train_window_secs", type=int, default=6)
parser.add_argument("--test_window_secs", type=float, default=20)
parser.add_argument("--window_overlap", type=float, default=0.5)
parser.add_argument("--use_fingers", action="store_true")
parser.add_argument("--held_out_files", nargs="+", default=["V1", "V13", "V17"])
return parser.parse_args()
parser.add_argument("--fullbody", action="store_true")
parser.add_argument("--use_mirror_augment", action="store_true")

args = parser.parse_args()

def gather_filenames(data_dir):
"""
Return all filenames in the dataset, e.g. "Recording_001" (in Trinity) or "V01" (in SaGA).
"""
# r=root, d=directories, f = files
files = []
for r, d, f in os.walk(data_dir):
for file in sorted(f):
if ".bvh" in file:
ff = join(r, file)
basename = os.path.splitext(os.path.basename(ff))[0]
files.append(basename)
if args.source_dir is None:
args.source_dir = join("../data", args.dataset, "source")

if args.processed_dir is None:
args.processed_dir = join("../data", args.dataset, "processed")

if args.dataset == "saga":
if args.held_out_files is None:
args.held_out_files = ["V1", "V13", "V17"]
if args.fps is None:
args.fps = 25

return files
elif args.dataset == "trinity":
if args.held_out_files is None:
args.held_out_files = ["Recording_008"]
if args.fps is None:
args.fps = 20
else:
raise ValueError(f"Unknown dataset {args.dataset}!")

return args
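
With these defaults resolved, a hypothetical invocation for the new dataset needs nothing beyond the dataset name; the values below are the ones derived by the code above:

# python prepare_gesture_datasets.py --dataset trinity
#
# resolves to:
#   args.source_dir     == "../data/trinity/source"
#   args.processed_dir  == "../data/trinity/processed"
#   args.held_out_files == ["Recording_008"]
#   args.fps            == 20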


if __name__ == "__main__":
@@ -217,11 +228,11 @@ def gather_filenames(data_dir):
if not exists(feature_dir):
os.makedirs(feature_dir)

bvh_dir = join(args.raw_data_dir, "bvh")
audio_dir = join(args.raw_data_dir, "audio")
bvh_source_dir = join(args.source_dir, "bvh")
audio_source_dir = join(args.source_dir, "audio")
speech_output_dir = join(args.processed_dir, "melspec")
motion_output_dir = join(args.processed_dir, "joint_rot")
files = gather_filenames(bvh_dir)
files = utils.gather_filenames(bvh_source_dir)

# -----------------------------
# Extract speech features
@@ -231,7 +242,7 @@
else:
print("Processing speech features...")
os.makedirs(speech_output_dir)
extract_melspec(audio_dir, files, speech_output_dir, args.fps)
extract_melspec(audio_source_dir, files, speech_output_dir, args.fps)

# -----------------------------
# Extract joint angles
@@ -242,7 +253,13 @@
print("Processing motion features...")
os.makedirs(motion_output_dir)
extract_joint_angles(
bvh_dir, files, motion_output_dir, args.fps, use_fingers=args.use_fingers
args.dataset,
bvh_source_dir,
files,
motion_output_dir,
fps=args.fps,
use_mirroring=args.use_mirror_augment,
fullbody=args.fullbody,
)

# copy pipeline for converting motion features to bvh
@@ -261,7 +278,7 @@ def gather_filenames(data_dir):
test_window_len = args.test_window_secs * args.fps

# We reserve the first 20 evaluation samples for validation, and the rest for testing.
val_test_split = 10 * args.test_window_secs * args.fps
val_test_split = 20 * args.test_window_secs * args.fps
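
With the SaGA defaults (test_window_secs=20, fps=25) this reserves the first 20 × 20 × 25 = 10,000 frames of each held-out file for validation; the split is measured in frames, since all features are extracted at args.fps.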

# ------------------------------------------------
# Create training/validation data in short windows
@@ -273,6 +290,7 @@
speech_output_dir,
train_window_len,
args.window_overlap,
is_mirrored=args.use_mirror_augment,
)

val_motion, val_ctrl = create_data_windows(
@@ -282,6 +300,7 @@
train_window_len,
args.window_overlap,
end=val_test_split,
is_mirrored=args.use_mirror_augment,
)

# ------------------------------------------------
@@ -297,6 +316,7 @@
slice_overlap=0,
start=0,
end=val_test_split,
is_mirrored=args.use_mirror_augment,
)
test_motion, test_ctrl = create_data_windows(
args.held_out_files,
@@ -305,6 +325,7 @@
test_window_len,
slice_overlap=0,
start=val_test_split,
is_mirrored=args.use_mirror_augment,
)
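
For SaGA at the default 25 fps, this makes the training/validation windows 6 × 25 = 150 frames long (with 50% overlap) and the evaluation windows 20 × 25 = 500 frames long (with no overlap).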

# ------------------------------------------------
@@ -331,8 +352,6 @@
# test_ctrl[1::3, :, -1].fill(np.quantile(train_ctrl[:, :, -1], 0.5))
# test_ctrl[2::3, :, -1].fill(np.quantile(train_ctrl[:, :, -1], 0.85))

# import pdb;pdb.set_trace()

# ------------------------------------------------
# Save the arrays into the 'processed' dir
# ------------------------------------------------
@@ -361,15 +380,15 @@ def save_array(fname, array):

for file in args.held_out_files:
cut_audio(
join(audio_dir, file + ".wav"),
join(audio_source_dir, file + ".wav"),
args.test_window_secs,
dev_vispath,
starttime=0.0,
endtime=10 * args.test_window_secs,
)

cut_audio(
join(audio_dir, file + ".wav"),
join(audio_source_dir, file + ".wav"),
args.test_window_secs,
test_vispath,
starttime=10 * args.test_window_secs,