From 565fa898d3646f14aa59602d93fe2b8b62b49237 Mon Sep 17 00:00:00 2001 From: Shinichi Shirakawa Date: Wed, 13 Oct 2021 01:41:51 +0900 Subject: [PATCH] add files --- .gitignore | 927 +++++++++++++++++++++++++++++++++++++++++ README.md | 178 +++++++- dataset.py | 346 +++++++++++++++ make_gesture_video.py | 158 +++++++ test.py | 91 ++++ train.py | 417 ++++++++++++++++++ utils/__init__.py | 0 utils/gesture_plot.py | 142 +++++++ utils/model.py | 182 ++++++++ utils/speaker_const.py | 568 +++++++++++++++++++++++++ utils/torch_dataset.py | 48 +++ 11 files changed, 3055 insertions(+), 2 deletions(-) create mode 100755 .gitignore create mode 100755 dataset.py create mode 100644 make_gesture_video.py create mode 100644 test.py create mode 100644 train.py create mode 100644 utils/__init__.py create mode 100644 utils/gesture_plot.py create mode 100755 utils/model.py create mode 100644 utils/speaker_const.py create mode 100644 utils/torch_dataset.py diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..4c1c84a --- /dev/null +++ b/.gitignore @@ -0,0 +1,927 @@ +# Prerequisites +*.d + +# Object files +*.o +*.ko +*.obj +*.elf + +# Precompiled Headers +*.gch +*.pch + +# Libraries +*.lib +*.a +*.la +*.lo + +# Shared objects (inc. Windows DLLs) +*.dll +*.so +*.so.* +*.dylib + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex + +# Debug files +*.dSYM/ +*.su + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +*.i +*.ii +*.gpu +*.ptx +*.cubin +*.fatbin + +*.class + +# Mobile Tools for Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.ear + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# IPython Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# dotenv +.env + +# virtualenv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject + +# C++ objects and libs + +*.slo +*.lo +*.o +*.a +*.la +*.lai +*.so +*.dll +*.dylib + +# Qt-es + +/.qmake.cache +/.qmake.stash +*.pro.user +*.pro.user.* +*.qbs.user +*.qbs.user.* +*.moc +moc_*.cpp +qrc_*.cpp +ui_*.h +Makefile* +*build-* + +# QtCreator + +*.autosave + +# QtCtreator Qml +*.qmlproject.user +*.qmlproject.user.* + +# QtCtreator CMake +CMakeLists.txt.user* + + +# History files +.Rhistory +.Rapp.history + +# Session Data files +.RData + +# Example code in package build process +*-Ex.R + +# Output files from R CMD build +/*.tar.gz + +# Output files from R CMD check +/*.Rcheck/ + +# RStudio files +.Rproj.user/ + +# produced vignettes +vignettes/*.html +vignettes/*.pdf + +# OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 +.httr-oauth + +# knitr and R markdown default cache directories +/*_cache/ +/cache/ + +# Temporary files created by R markdown +*.utf8.md +*.knit.md + +## Core latex/pdflatex auxiliary files: +*.aux +*.lof +*.log +*.lot +*.fls +*.out +*.toc +*.fmt +*.fot +*.cb +*.cb2 + +## Intermediate documents: +*.dvi +*-converted-to.* +# these rules might exclude image files for figures etc. 
+# *.ps +# *.eps +# *.pdf + +## Generated if empty string is given at "Please type another file name for output:" +.pdf + +## Bibliography auxiliary files (bibtex/biblatex/biber): +*.bbl +*.bcf +*.blg +*-blx.aux +*-blx.bib +*.brf +*.run.xml + +## Build tool auxiliary files: +*.fdb_latexmk +*.synctex +*.synctex(busy) +*.synctex.gz +*.synctex.gz(busy) +*.pdfsync + +## Auxiliary and intermediate files from other packages: +# algorithms +*.alg +*.loa + +# achemso +acs-*.bib + +# amsthm +*.thm + +# beamer +*.nav +*.snm +*.vrb + +# cprotect +*.cpt + +# fixme +*.lox + +#(r)(e)ledmac/(r)(e)ledpar +*.end +*.?end +*.[1-9] +*.[1-9][0-9] +*.[1-9][0-9][0-9] +*.[1-9]R +*.[1-9][0-9]R +*.[1-9][0-9][0-9]R +*.eledsec[1-9] +*.eledsec[1-9]R +*.eledsec[1-9][0-9] +*.eledsec[1-9][0-9]R +*.eledsec[1-9][0-9][0-9] +*.eledsec[1-9][0-9][0-9]R + +# glossaries +*.acn +*.acr +*.glg +*.glo +*.gls +*.glsdefs + +# gnuplottex +*-gnuplottex-* + +# gregoriotex +*.gaux +*.gtex + +# hyperref +*.brf + +# knitr +*-concordance.tex +# TODO Comment the next line if you want to keep your tikz graphics files +*.tikz +*-tikzDictionary + +# listings +*.lol + +# makeidx +*.idx +*.ilg +*.ind +*.ist + +# minitoc +*.maf +*.mlf +*.mlt +*.mtc +*.mtc[0-9] +*.mtc[1-9][0-9] + +# minted +_minted* +*.pyg + +# morewrites +*.mw + +# mylatexformat +*.fmt + +# nomencl +*.nlo + +# sagetex +*.sagetex.sage +*.sagetex.py +*.sagetex.scmd + +# scrwfile +*.wrt + +# sympy +*.sout +*.sympy +sympy-plots-for-*.tex/ + +# pdfcomment +*.upa +*.upb + +# pythontex +*.pytxcode +pythontex-files-*/ + +# thmtools +*.loe + +# TikZ & PGF +*.dpth +*.md5 +*.auxlock + +# todonotes +*.tdo + +# easy-todo +*.lod + +# xindy +*.xdy + +# xypic precompiled matrices +*.xyc + +# endfloat +*.ttt +*.fff + +# Latexian +TSWLatexianTemp* + +## Editors: +# WinEdt +*.bak +*.sav + +# Texpad +.texpadtmp + +# Kile +*.backup + +# KBibTeX +*~[0-9]* + +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. 
+ +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio 2015 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# DNX +project.lock.json +project.fragment.lock.json +artifacts/ + +*_i.c +*_p.c +*_i.h +*.ilk +*.meta +*.obj +*.pch +*.pdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding add-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# TODO: Comment the next line if you want to checkin your web deploy settings +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/packages/* +# except build/, which is used as an MSBuild target. +!**/packages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/packages/repositories.config +# NuGet v3's project.json files produces more ignoreable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.pfx +*.publishsettings +node_modules/ +orleans.codegen.cs + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. 
Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm + +# SQL Server files +*.mdf +*.ldf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# JetBrains Rider +.idea/ +*.sln.iml + + +.metadata +bin/ +tmp/ +*.tmp +*.bak +*.swp +*~.nib +local.properties +.settings/ +.loadpath +.recommenders + +# Eclipse Core +.project + +# External tool builders +.externalToolBuilders/ + +# Locally stored "Eclipse launch configurations" +*.launch + +# PyDev specific (Python IDE for Eclipse) +*.pydevproject + +# CDT-specific (C/C++ Development Tooling) +.cproject + +# JDT-specific (Eclipse Java Development Tools) +.classpath + +# Java annotation processor (APT) +.factorypath + +# PDT-specific (PHP Development Tools) +.buildpath + +# sbteclipse plugin +.target + +# Tern plugin +.tern-project + +# TeXlipse plugin +.texlipse + +# STS (Spring Tool Suite) +.springBeans + +# Code Recommenders +.recommenders/ + +# Temporary data +.ipynb_checkpoints/ + +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/workspace.xml +.idea/tasks.xml +.idea/dictionaries +.idea/vcs.xml +.idea/jsLibraryMappings.xml + +# Sensitive or high-churn files: +.idea/dataSources.ids +.idea/dataSources.xml +.idea/dataSources.local.xml +.idea/sqlDataSources.xml +.idea/dynamic.xml +.idea/uiDesigner.xml + +# Gradle: +.idea/gradle.xml +.idea/libraries + +# Mongo Explorer plugin: +.idea/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +*.tmp + +# Word temporary +~$*.doc* + +# Excel temporary +~$*.xls* + +# Excel Backup File +*.xlk + +# PowerPoint temporary +~$*.ppt* + +# Visio autosave temporary files +*.~vsdx + +*.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# cache files for sublime text +*.tmlanguage.cache +*.tmPreferences.cache +*.stTheme.cache + +# workspace files are user-specific +*.sublime-workspace + +# project 
files should be checked into the repository, unless a significant +# proportion of contributors will probably not be using SublimeText +# *.sublime-project + +# sftp configuration file +sftp-config.json + +# Package control specific files +Package Control.last-run +Package Control.ca-list +Package Control.ca-bundle +Package Control.system-ca-bundle +Package Control.cache/ +Package Control.ca-certs/ +bh_unicode_properties.cache + +# Sublime-github package stores a github token in this file +# https://packagecontrol.io/packages/sublime-github +GitHub.sublime-settings + +# Windows image file caches +Thumbs.db +ehthumbs.db + +# Folder config file +Desktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# Xcode +# +# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore + +## Build generated +build/ +DerivedData/ + +## Various settings +*.pbxuser +!default.pbxuser +*.mode1v3 +!default.mode1v3 +*.mode2v3 +!default.mode2v3 +*.perspectivev3 +!default.perspectivev3 +xcuserdata/ + +## Other +*.moved-aside +*.xccheckout +*.xcscmblueprint + diff --git a/README.md b/README.md index 5e0183a..e9d3606 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,176 @@ -# text2gesture_cnn -Text-to-Gesture Generation Model Using Convolutional Neural Network +# Evaluation of Text-to-Gesture Generation Model Using Convolutional Neural Network + +This repository contains the code for the text-to-gesture generation model using CNN. + +The demonstration video of generated gestures is available at . + +## Requirements +We used the [PyTorch](https://pytorch.org/) version 1.7.1 for neural network implementation. We tested the codes on the following environment: + +- Ubuntu 16.04 LTS +- GPU: NVIDIA GeForce GTX 1080Ti +- Python environment: anaconda3-2020.07 + - [fasttext](https://fasttext.cc/) + - cv2 (4.4.0) +- ffmpeg + +## Preparation +1. Our code uses the speech and gesture dataset provided by Ginosar et al. Download the Speech2Gesture dataset by following the instruction "Download specific speaker data" in . + +``` +Shiry Ginosar, Amir Bar, Gefen Kohavi, Caroline Chan, Andrew Owens, and Jitendra Malik, "Learning Individual Styles of Conversational Gesture," CVPR 2019. +``` + +After downloading the Speech2Gesture dataset, your dataset folder should be like: +``` +Gestures +├── frames_df_10_19_19.csv +├── almaram +    ├── frames +   ├── keypoints_all +   ├── keypoints_simple +   └── videos +... +└── shelly +    ├── frames +   ├── keypoints_all +   ├── keypoints_simple +   └── videos +``` +2. Download the text dataset from [HERE](https://drive.google.com/file/d/1OjSJ-F9hoLOfecF5FwdCGG2Mp8fBPgGb/view?usp=sharing) and unarchive the zip file. +3. Move the `words` directory in each speaker name's directory to the corresponding speaker's directory in your dataset directory. + +After this step, your dataset folder should be like: +``` +Gestures +├── frames_df_10_19_19.csv +├── almaram +    ├── frames +   ├── keypoints_all +   ├── keypoints_simple +   ├── videos +   └── words +... +└── shelly +    ├── frames +   ├── keypoints_all +   ├── keypoints_simple +   ├── videos +   └── words +``` +Note that the word data for speaker Jon is very little. Therefore, it should not use for model training. + +4. Set up the fasttext by following the instruction [HERE](https://fasttext.cc/docs/en/support.html). 
Download the pre-trained model file (wiki-news-300d-1M-subword.bin) of fasttext from [HERE](https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M-subword.bin.zip).
+
+## Create training and test data
+* Run the script as
+```shell
+python dataset.py --base_path <base_path> --speaker <speaker> --wordvec_file <wordvec_file> --dataset_type <dataset_type> --frames <frames>
+```
+
+* Options
+  * `<base_path>`: Path to the dataset folder (e.g., `/path_to_your_dataset/Gestures/`)
+  * `<speaker>`: Speaker name (directory name of the speaker) (e.g., `almaram`, `oliver`)
+  * `<wordvec_file>`: Path to the pre-trained model of `fasttext` (e.g., `/path_to_your_fasttext_dir/wiki-news-300d-1M-subword.bin`)
+  * `<dataset_type>`: Dataset type (`train` or `test`)
+  * `<frames>`: Number of frames per sample (we used 64 for training data and 192 for test data)
+
+* Example (Create Oliver's data)
+```shell
+# training data
+python dataset.py --base_path <base_path> --speaker oliver --wordvec_file <wordvec_file> --dataset_type train --frames 64
+
+# test data
+python dataset.py --base_path <base_path> --speaker oliver --wordvec_file <wordvec_file> --dataset_type test --frames 192
+```
+
+After running the script, the directories containing the training or test data are created in your dataset folder. After this step, your dataset folder should look like:
+```
+Gestures
+├── frames_df_10_19_19.csv
+├── almaram
+│   ├── frames
+│   ├── keypoints_all
+│   ├── keypoints_simple
+│   ├── test-192
+│   ├── train-64
+│   ├── videos
+│   └── words
+...
+└── shelly
+    ├── frames
+    ├── keypoints_all
+    ├── keypoints_simple
+    ├── test-192
+    ├── train-64
+    ├── videos
+    └── words
+```
+
+## Model training
+* Run the script as
+```shell
+python train.py --outdir_path <outdir_path> --speaker <speaker> --gpu_num <gpu_num> --base_path <base_path> --train_dir <train_dir>
+```
+
+* Options
+  * `<outdir_path>`: Directory for saving the training results (e.g., `./out_training/`)
+  * `<speaker>`: Speaker name (directory name of the speaker) (e.g., `almaram`, `oliver`)
+  * `<gpu_num>`: GPU ID
+  * `<base_path>`: Path to the dataset folder (e.g., `/path_to_your_dataset/Gestures/`)
+  * `<train_dir>`: Directory name containing the training data (e.g., `train-64`)
+
+The experimental settings (e.g., number of epochs, loss function) can be changed via command-line arguments. Please see `train.py` for the details.
+
+* Example (Training using Oliver's data)
+```shell
+python train.py --outdir_path ./out_training/ --speaker oliver --gpu_num 0 --base_path <base_path> --train_dir train-64
+```
+The resulting files will be created in `./out_training/oliver_YYYYMMDD-AAAAAA/`.
+
+## Evaluation
+* Predict the gesture motion for the test data using a trained model
+* Run the script as
+```shell
+python test.py --base_path <base_path> --test_speaker <test_speaker> --test_dir <test_dir> --model_dir <model_dir> --model_path <model_path> --outdir_path <outdir_path>
+```
+
+* Options
+  * `<base_path>`: Path to the dataset folder (e.g., `/path_to_your_dataset/Gestures/`)
+  * `<test_speaker>`: Speaker name for testing (directory name of the test speaker) (e.g., `almaram`, `oliver`)
+  * `<test_dir>`: Directory name containing the test data (e.g., `test-192`)
+  * `<model_dir>`: Directory name of the trained model (e.g., `oliver_YYYYMMDD-AAAAAA`)
+  * `<model_path>`: Path to the training results (e.g., `./out_training/`)
+  * `<outdir_path>`: Directory for saving the test results (e.g., `./out_test/`)
+
+* Example (Predict Oliver's test data using Oliver's trained model)
+```shell
+python test.py --base_path <base_path> --test_speaker oliver --test_dir test-192 --model_dir oliver_YYYYMMDD-AAAAAA --model_path ./out_training/ --outdir_path ./out_test/
+```
+The resulting files (`.npy` files for the predicted motion) are created in `./out_test/oliver_by_oliver_YYYYMMDD-AAAAAA_test-192/`.
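+
+Each `.npy` file stores the predicted keypoint sequence as an array of shape `(frames, 2, 49)`, i.e., the (x, y) coordinates of the 49 body and hand keypoints per frame, matching the `poses` array in the corresponding test `.npz` file. As a reference, the snippet below is a minimal sketch (not part of the repository scripts; the sample file name and paths are illustrative) for loading one prediction and comparing it with its ground truth:
+```python
+import numpy as np
+
+# Illustrative file names; pick any sample listed in test.csv
+pred = np.load('./out_test/oliver_by_oliver_YYYYMMDD-AAAAAA_test-192/test-000000.npy')
+gt = np.load('/path_to_your_dataset/Gestures/oliver/test-192/text-pose-npz/test-000000.npz')['poses']
+
+print(pred.shape, gt.shape)        # both (frames, 2, 49)
+print(np.abs(pred - gt).mean())    # mean absolute keypoint error for this sample
+```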
+
+* Example (Predict the Rock's test data using Oliver's trained model)
+```shell
+python test.py --base_path <base_path> --test_speaker rock --test_dir test-192 --model_dir oliver_YYYYMMDD-AAAAAA --model_path ./out_training/ --outdir_path ./out_test/
+```
+The resulting files (`.npy` files for the predicted motion) will be created in `./out_test/rock_by_oliver_YYYYMMDD-AAAAAA_test-192/`.
+
+## Visualization
+* Create gesture movie files
+* Run the script as
+```shell
+python make_gesture_video.py --base_path <base_path> --test_out_path <test_out_path> --test_out_dir <test_out_dir> --video_out_path <video_out_path>
+```
+
+* Options
+  * `<base_path>`: Path to the dataset folder (e.g., `/path_to_your_dataset/Gestures/`)
+  * `<test_out_path>`: Directory path of the test output (e.g., `./out_test/`)
+  * `<test_out_dir>`: Directory name of the output gestures (e.g., `oliver_by_oliver_YYYYMMDD-AAAAAA_test-192`)
+  * `<video_out_path>`: Directory path of the output videos (e.g., `./out_video/`)
+
+* Example (When using `oliver_by_oliver_YYYYMMDD-AAAAAA_test-192`)
+```shell
+python make_gesture_video.py --base_path <base_path> --test_out_path ./out_test/ --test_out_dir oliver_by_oliver_YYYYMMDD-AAAAAA_test-192 --video_out_path ./out_video/
+```
+
+The gesture videos (side-by-side videos of the ground truth and the text-to-gesture output) will be created in `./out_video/text2gesture/oliver_by_oliver_YYYYMMDD-AAAAAA_test-192/`. The left-side gesture is the ground truth, and the right-side gesture is the one generated by the text-to-gesture generation model. The original videos of the test intervals will also be created in `./out_video/original/oliver/`.
diff --git a/dataset.py b/dataset.py
new file mode 100755
index 0000000..e497ce8
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,346 @@
+import numpy as np
+import pandas as pd
+import fasttext
+import datetime
+from pathlib import Path
+from tqdm import tqdm
+import argparse
+import os
+import subprocess
+import cv2
+from decimal import Decimal
+
+WORD_CSV_DIR = 'words'
+
+
+class FasttextVectorizer:
+    def __init__(self, model_path, dim=300):
+        # Load the trained model of fasttext
+        self.model = fasttext.load_model(model_path)
+        self.dim = dim
+
+    def get_vec(self, corpus):
+        # Return embedded feature vectors from word list
+        vec = []
+        if len(corpus) != 0:
+            for p in corpus:
+                vec.append(self.model.get_word_vector(p))
+        else:
+            vec.append(np.zeros(self.dim))
+        return vec
+
+
+class WordEmb:
+    def __init__(self, model_path, df, speaker_path, dim=300):
+        self.vectorizer = FasttextVectorizer(model_path, dim)
+
+        self.speaker_path = speaker_path
+        video_names = df['video_fn'].unique()
+        self.word_df_dict = {}
+        for vn in video_names:
+            video_name = vn.split('.')[0]
+            # Read utterance information for each frame
+            csv_name = '15FPS-{}.csv'.format(video_name)
+            fps_df = pd.read_csv(speaker_path / WORD_CSV_DIR / csv_name)
+            self.word_df_dict[video_name] = fps_df
+
+    def get_wordemb(self, df, i, frames):
+        zero = datetime.datetime.strptime('0', '%S')
+        start = datetime.datetime.strptime(df['start'][i], '%H:%M:%S.%f') - zero
+        # Get start time of pronounce
+        start_sec = start.total_seconds()
+        # Calculate starting frame number
+        start_frame = round(start_sec * 15)
+
+        video_name = df['video_fn'][i].split('.')[0]
+
+        # Get utterance words for each frame
+        fps_df = self.word_df_dict[video_name]
+        word_list = fps_df['word'][start_frame:start_frame+frames].copy(deep=True).to_list()
+        # Pad the data if the number of frames is insufficient (when extracting the data near the end of video)
+        if not len(word_list) == frames:
+            print('not length = {} ({})'.format(frames, len(word_list)))
+            word_list.extend([''] * (frames - len(word_list)))
+        # Transform ''
+        words = ['' if w == '' else w for w in
word_list] + # Vectorize the words + vec = self.vectorizer.get_vec(words) + return word_list, vec + + +def save_voice(speaker_path, i, df, fname, frames, voice_path): + zero = datetime.datetime.strptime('0', '%S') + start = datetime.datetime.strptime(df['start'][i], '%H:%M:%S.%f') - zero + + # Get start time and end time of pronounce + start_sec = start.total_seconds() + # Get total time (sec) of pronounce + total_sec = 1./15 * frames + + # Path + video_path = speaker_path / 'videos' / df['video_fn'][i] + voice_name = fname + '.wav' + voice_path = voice_path / voice_name + + # Save 'wav' file with the same name with 'npz' by ffmpeg + cmd = 'ffmpeg -i "{}" -ss {} -t {} -ab 160k -ac 2 -ar {} -vn "{}" -y -loglevel warning'.format(str(video_path), start_sec, total_sec, 44100, str(voice_path)) + subprocess.call(cmd, shell=True) + + +class KptExtractor: + def __init__(self, df, speaker_path): + video_names = df['video_fn'].unique() + self.speaker_path = speaker_path + self.kpt_face_dict = {} + for vn in video_names: + video_name = vn.split('.')[0] + keypoints_all_path = self.speaker_path / 'keypoints_all' / video_name + # Get file names + proc = subprocess.run(['ls', str(keypoints_all_path)], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + fstr = proc.stdout.decode('utf8') + flist = fstr.split('\n') + # Extract keypoint data for face + face_kpts_name = [[f] for f in flist if 'face' in f] + # Calculate time (sec) from extracted file name + face_kpts_time = [[float(Decimal(float(t[0].split('_')[-3]) * 60.0 + float(t[0].split('_')[-2])).quantize(Decimal('0.000001')))] for t in face_kpts_name] + if len(face_kpts_name) != 0: + # Concatenate + face_kpts_list = np.concatenate([face_kpts_name, face_kpts_time], 1) + # Sorting based on the frame number + face_kpts_list = face_kpts_list[np.argsort(face_kpts_list[:, 1].astype(float))] + self.kpt_face_dict[video_name] = face_kpts_list + + def get_kpts(self, df, i, frames, face_calc=True): + pose_arr = np.empty((0, 2, 49), int) + face_arr = np.empty((0, 2, 70), int) + + video_name = df['video_fn'][i].split('.')[0] + # Get start time + start = df['start'][i].split(':') + start_sec = float(start[-2]) * 60.0 + float(start[-1]) # Transform to [sec] + # Get data after the starting time + frame_tf = self.kpt_face_dict[video_name][:, 1].astype(float) > start_sec + start_face_kpts = self.kpt_face_dict[video_name][frame_tf] + + # Path for keypoint data + keypoints_simple_path = self.speaker_path / 'keypoints_simple' / video_name + keypoints_all_path = self.speaker_path / 'keypoints_all' / video_name + + for s in start_face_kpts[:frames]: + # Keypoint for body + pose_kpt = s[0][:-9] + '.txt' + np_pose = np.loadtxt(keypoints_simple_path / pose_kpt) + # Deletes two eyes and nose + np_pose = np.delete(np_pose, [7, 8, 9], 1) # 7 is nose, 8,9 are eyes + base_p = np_pose[:, 0] + np_pose = np_pose - np.reshape(base_p, (2, 1)) + pose_arr = np.append(pose_arr, np.reshape(np_pose, [1, 2, 49]), axis=0) + + # Keypoint for face + if face_calc: + fs = cv2.FileStorage(str(keypoints_all_path / s[0]), cv2.FILE_STORAGE_READ) + face = fs.getNode("face_0").mat() + x, y, _ = np.split(face, [1, 2], axis=2) + # Use first person's data + if not x.shape[0] == 1: + x, y = x[0], y[0] + # Make (x, y)-array and arrange coordinate (the origin becomes base keypoint) + face_2d = np.vstack((x - base_p[0], y - base_p[1])) + face_arr = np.append(face_arr, np.reshape(face_2d, [1, 2, 70]), axis=0) + + return pose_arr, face_arr + + +def save_dataset(df, speaker_path, embed_path, save_path, voice=True, 
dataset='train', save_option=True, frames=64): + + npz_path = save_path / 'text-pose-npz' + os.makedirs(npz_path, exist_ok=True) + df['npz_fn'] = '' + df['min_sh_width'] = 0. + df['max_sh_width'] = 0. + voice_path = save_path / 'audio' + if voice: + os.makedirs(voice_path, exist_ok=True) + df['audio_fn'] = '' + df['words'] = '' + df['words_per_frame'] = '' + + # Load trained Word2Vec model + print('loading word2vec') + word_emb = WordEmb(embed_path, df, speaker_path) + print('completed loading word2vec') + + kpt_ext = KptExtractor(df, speaker_path) + + for i in tqdm(range(len(df))): + try: + # word data + words, wvec = word_emb.get_wordemb(df, i, frames) # Get text vector + df.at[i, 'words_per_frame'] = ' '.join(words) # save words + df.at[i, 'words'] = ' '.join(sorted(set(words), key=words.index)) + + # keypoint data + poses, face = kpt_ext.get_kpts(df, i, frames, face_calc=save_option) + + # Exception + if poses.shape != (frames, 2, 49): + continue + elif save_option and face.shape != (frames, 2, 70): + continue + + # Save + fname = '{}-{:06d}'.format(df['dataset'][i], i) + if save_option: + np.savez(npz_path / fname, words=words, poses=poses, wvec=wvec, face=face) + else: + np.savez(npz_path / fname, poses=poses, wvec=wvec) + df.at[i, 'npz_fn'] = fname + '.npz' + shoulder_width = np.sqrt((poses[:, 0, 4] - poses[:, 0, 1])**2 + (poses[:, 1, 4] - poses[:, 1, 1])**2) + df.at[i, 'min_sh_width'] = np.min(shoulder_width) + df.at[i, 'max_sh_width'] = np.max(shoulder_width) + + # Save audio data + if voice and os.path.isfile(speaker_path / 'videos' / df['video_fn'][i]): + save_voice(speaker_path, i, df, fname, frames, voice_path) + df.at[i, 'audio_fn'] = fname + '.wav' + + except Exception as e: + print(e) + continue + + df.to_csv(save_path / '{}.csv'.format(dataset)) + return df + + +def video_samples_train(df, base_path, speaker, num_frames=64): + df = df[df['speaker'] == speaker] + df = df[(df['dataset'] == 'train') | (df['dataset'] == 'dev')] + speaker_path = base_path / speaker + + data_dict = {'dataset': [], 'start': [], 'end': [], 'interval_id': [], 'video_fn': [], 'speaker': []} + intervals = df['interval_id'].unique() + + total_frames = 0 + total_train_frames = 0 + total_dev_frames = 0 + for interval in tqdm(intervals): + try: + df_interval = df[df['interval_id'] == interval].sort_values('frame_id', ascending=True) + video_fn = df_interval.iloc[0]['video_fn'] + speaker_name = df_interval.iloc[0]['speaker'] + if len(df_interval) < num_frames: + print("interval: %s, num frames: %s. skipped" % (interval, len(df_interval))) + continue + + # word file exist? 
+ word_csv_name = '15FPS-{}.csv'.format(video_fn.split('.')[0]) + if not os.path.isfile(speaker_path / WORD_CSV_DIR / word_csv_name): + continue + total_frames += len(df_interval) + if df_interval.iloc[0]['dataset'] == 'train': + total_train_frames += len(df_interval) + elif df_interval.iloc[0]['dataset'] == 'dev': + total_dev_frames += len(df_interval) + + for idx in range(0, len(df_interval) - num_frames, 5): + sample = df_interval[idx:idx + num_frames] + data_dict["dataset"].append(df_interval.iloc[0]['dataset']) + data_dict["start"].append(sample.iloc[0]['pose_dt']) + data_dict["end"].append(sample.iloc[-1]['pose_dt']) + data_dict["interval_id"].append(interval) + data_dict["video_fn"].append(video_fn) + data_dict["speaker"].append(speaker_name) + except Exception as e: + print(e) + continue + return pd.DataFrame.from_dict(data_dict), [total_frames, total_train_frames, total_dev_frames] + + +def video_samples_test(df, base_path, speaker, num_samples, num_frames=64): + df = df[df['speaker'] == speaker] + df = df[df['dataset'] == 'test'] + speaker_path = base_path / speaker + + df['ones'] = 1 + grouped = df.groupby('interval_id').agg({'ones': sum}).reset_index() + grouped = grouped[grouped['ones'] >= num_frames][['interval_id']] + df = df.merge(grouped, on='interval_id') + + data_dict = {'dataset': [], 'start': [], 'end': [], 'interval_id': [], 'video_fn': [], 'speaker': []} + intervals = df['interval_id'].unique() + + i = 0 + pbar = tqdm(total=num_samples) + while i < num_samples: + try: + interval = intervals[np.random.randint(0, len(intervals))] + df_interval = df[df['interval_id'] == interval].sort_values('frame_id', ascending=True) + video_fn = df_interval.iloc[0]['video_fn'] + speaker_name = df_interval.iloc[0]['speaker'] + if len(df_interval) < num_frames: + continue + + # video file exist? 
+ if not os.path.isfile(speaker_path / 'videos' / video_fn): + continue + + idx = np.random.randint(0, len(df_interval) - num_frames + 1) + sample = df_interval[idx:idx + num_frames] + data_dict["dataset"].append(df_interval.iloc[0]['dataset']) + data_dict["start"].append(sample.iloc[0]['pose_dt']) + data_dict["end"].append(sample.iloc[-1]['pose_dt']) + data_dict["interval_id"].append(interval) + data_dict["video_fn"].append(video_fn) + data_dict["speaker"].append(speaker_name) + i += 1 + pbar.update(1) + except Exception as e: + print(e) + continue + pbar.close() + return pd.DataFrame.from_dict(data_dict) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-base_path', '--base_path', help='base folder path of dataset', required=True) + parser.add_argument('-speaker', '--speaker', default='specific speaker name', required=True) + parser.add_argument('-wordvec_file', '--wordvec_file', default='./wiki-news-300d-1M-subword.bin', + help='word vector file of firstText', required=True) + parser.add_argument('-dataset_type', '--dataset_type', default='train', help='dataset type (train / test)', + required=True) + parser.add_argument('-fs', '--frames', default=64, help='number of frames per sample', type=int) + parser.add_argument('-s', '--samples', default=4096, help='number of samples for test data', type=int) + args = parser.parse_args() + + speaker = args.speaker + base_path = Path(args.base_path) + save_path = base_path / speaker / (args.dataset_type + '-' + str(args.frames)) + speaker_path = base_path / speaker # Path of speaker + + df = pd.read_csv(base_path / 'frames_df_10_19_19.csv') + + if args.dataset_type == 'train': + df_samples, frames = video_samples_train(df, base_path, speaker, num_frames=args.frames) + elif args.dataset_type == 'test': + df_samples = video_samples_test(df, base_path, speaker, args.samples, num_frames=args.frames) + del df + + if args.dataset_type == 'train': + df_word = save_dataset(df_samples, speaker_path, args.wordvec_file, save_path, voice=False, + dataset=args.dataset_type, save_option=False, frames=args.frames) + elif args.dataset_type == 'test': + df_word = save_dataset(df_samples, speaker_path, args.wordvec_file, save_path, voice=True, + dataset=args.dataset_type, save_option=True, frames=args.frames) + + print('speaker: ', speaker) + print('dataset_type: ', args.dataset_type) + if args.dataset_type == 'train': + print('Number of total train/dev frames: ', frames[0]) + print('Number of total train frames: ', frames[1]) + print('Number of total dev frames: ', frames[2]) + print('Number of train/dev samples: ', len(df_word)) + print('Number of train samples: ', len(df_word[(df_word['dataset'] == 'train')])) + print('Number of dev samples: ', len(df_word[(df_word['dataset'] == 'dev')])) + elif args.dataset_type == 'test': + print('Number of test samples: ', len(df_word)) diff --git a/make_gesture_video.py b/make_gesture_video.py new file mode 100644 index 0000000..7a019f3 --- /dev/null +++ b/make_gesture_video.py @@ -0,0 +1,158 @@ +import os +import subprocess +import datetime +import argparse +import numpy as np +import pandas as pd +from pathlib import Path +from tqdm import tqdm + +from test import get_test_datalist +from utils.gesture_plot import save_video +from utils.speaker_const import SPEAKERS_CONFIG + + +def make_gesture_video(trg_speaker, i, face, pose_gt, pose_fake, audio_fn, video_out_path, tmp_path): + """ + # Decide position randomly + if np.random.randint(2) == 0: + video_name = trg_speaker + 
'_{:03d}_right.mp4'.format(i) + poses = np.array([pose_gt, pose_fake]) # Pose keypoints + info = [video_name, 'real', 'fake'] + else: + video_name = trg_speaker + '_{:03d}_left.mp4'.format(i) + poses = np.array([pose_fake, pose_gt]) # Pose keypoints + info = [video_name, 'fake', 'real'] + """ + + video_name = trg_speaker + '_{:03d}_right.mp4'.format(i) + poses = np.array([pose_gt, pose_fake]) # Pose keypoints + info = [video_name, 'real', 'fake'] + + # Save video + save_video(poses, face, audio_fn, str(video_out_path / video_name), str(tmp_path), delete_tmp=False) + + return info + + +def get_test_df(df_path, min_ratio=0.75, max_ratio=1.25): + df = pd.read_csv(df_path) + speaker = df['speaker'][0] + + shoulder_w = np.sqrt((SPEAKERS_CONFIG[speaker]['median'][4] - SPEAKERS_CONFIG[speaker]['median'][1]) ** 2 + + (SPEAKERS_CONFIG[speaker]['median'][53] - SPEAKERS_CONFIG[speaker]['median'][50]) ** 2) + min_w = shoulder_w * min_ratio + max_w = shoulder_w * max_ratio + shoulder_cond = (min_w < df['min_sh_width']) & (df['max_sh_width'] < max_w) + + file_exist = df['npz_fn'].notnull() + test_df = df[(df['dataset'] == 'test') & shoulder_cond & file_exist] + return test_df + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Making gesture videos') + parser.add_argument('--base_path', type=str, default='~/Gestures/', help='gesture dataset base path') + parser.add_argument('--test_out_path', type=str, default='./out_test/', help='directory path of test output') + parser.add_argument('--test_out_dir', type=str, default='oliver_by_oliver_YYYYMMDD-AAAAAA_test-192', + help='directory name of output gestures') + parser.add_argument('--video_out_path', type=str, default='./out_video/', help='directory path of output videos') + + parser.add_argument('--video_num', type=int, default=50, help='number of videos') + parser.add_argument('--tmp_path', type=str, default='./tmp', help='temporary directory path') + args = parser.parse_args() + + base_path = Path(args.base_path) + test_out_path = Path(args.test_out_path) + test_out_dir = args.test_out_dir + video_out_path = Path(args.video_out_path) / 'text2gesture' / test_out_dir + os.makedirs(video_out_path, exist_ok=True) + + NUM = 50 + tmp_path = Path(args.tmp_path) + os.makedirs(tmp_path, exist_ok=True) + + # Extract information + trg_speaker = test_out_dir.split('_')[0] + model_speaker = test_out_dir.split('_')[2] + frame_num = int(test_out_dir.split('-')[-1]) + test_dir_name = test_out_dir.split('_')[-1] + + # Video of Text2Gesture and GroundTruth + cols = ['test_fn', 'video_name', 'left', 'right'] + ['word-' + str(i + 1) for i in range(frame_num)] + data_path = base_path / trg_speaker / test_dir_name + audio_path = Path(data_path) / 'audio' + df_path = data_path / 'test.csv' + + # Test file list + test_list = get_test_datalist(df_path, min_ratio=0.75, max_ratio=1.25) + + # Test gesture file path + t2g_test_pose_path = test_out_path / test_out_dir + + df_record = pd.DataFrame(index=[], columns=cols) + + for i, npz_fn in tqdm(enumerate(test_list.values[:NUM])): + # Ground Truth + npz_gt = np.load(Path(data_path) / 'text-pose-npz' / npz_fn) + pose_gt = npz_gt['poses'] + word_gt = list(npz_gt['words']) + + # Face keypoints + face = np.array([npz_gt['face'], npz_gt['face']]) + + # Audio file names + audio_name = npz_fn[:-3] + 'wav' + + # Text2Gesture Prediction + pose_t2g = np.load(t2g_test_pose_path / (npz_fn[:-1] + 'y')) + pose_t2g = SPEAKERS_CONFIG[model_speaker]['scale_factor'] / SPEAKERS_CONFIG[trg_speaker][ + 'scale_factor'] * 
pose_t2g + + info = make_gesture_video(trg_speaker, i, face, pose_gt, pose_t2g, str(audio_path / audio_name), + video_out_path, tmp_path) + + # Save record + record = pd.Series([npz_fn] + info + word_gt, index=cols) + df_record = df_record.append(record, ignore_index=True) + + df_record.to_csv(video_out_path / 't2g.csv', index=False) + + # -------------------------------------------------- + # Save original video + cols = ['test_fn', 'video_name'] + ['word-' + str(i + 1) for i in range(frame_num)] + + # Test file list + test_df = get_test_df(df_path) + + # Save directory + video_out_path = Path(args.video_out_path) / 'original' / trg_speaker + os.makedirs(video_out_path, exist_ok=True) + + df_record = pd.DataFrame(index=[], columns=cols) + + for i in tqdm(range(NUM)): + npz_fn = test_df.iloc[i]['npz_fn'] + npz_gt = np.load(Path(data_path) / 'text-pose-npz' / npz_fn) + word = list(npz_gt['words']) + # print(npz_fn) + + zero = datetime.datetime.strptime('0', '%S') + start = datetime.datetime.strptime(test_df.iloc[i]['start'], '%H:%M:%S.%f') - zero + end = datetime.datetime.strptime(test_df.iloc[i]['end'], '%H:%M:%S.%f') - zero + video_fn = base_path / trg_speaker / 'videos' / test_df.iloc[i]['video_fn'] + # print(video_fn, start, end) + + video_out_fn = video_out_path / (trg_speaker + '_{:03d}.mp4'.format(i)) + # print(video_out_fn) + + # 動画の切り出し + cmd = 'ffmpeg -i {} -ss {} -to {} -y {}'.format(str(video_fn), start, end, str(video_out_fn)) + # print(cmd) + subprocess.call(cmd, shell=True) + + # Save record + record = pd.Series([npz_fn, video_out_fn] + word, index=cols) + df_record = df_record.append(record, ignore_index=True) + + df_record.to_csv(video_out_path / 'original.csv', index=False) diff --git a/test.py b/test.py new file mode 100644 index 0000000..801a63c --- /dev/null +++ b/test.py @@ -0,0 +1,91 @@ + +import torch +import numpy as np +import pandas as pd +import argparse +import os +from pathlib import Path +from tqdm import tqdm + +from utils.speaker_const import SPEAKERS_CONFIG +from utils.model import UnetDecoder + + +def get_test_datalist(df_path, min_ratio=-np.inf, max_ratio=np.inf): + df = pd.read_csv(df_path) + speaker = df['speaker'][0] + + shoulder_w = np.sqrt((SPEAKERS_CONFIG[speaker]['median'][4] - SPEAKERS_CONFIG[speaker]['median'][1]) ** 2 + + (SPEAKERS_CONFIG[speaker]['median'][53] - SPEAKERS_CONFIG[speaker]['median'][50]) ** 2) + min_w = shoulder_w * min_ratio + max_w = shoulder_w * max_ratio + shoulder_cond = (min_w < df['min_sh_width']) & (df['max_sh_width'] < max_w) + + file_exist = df['npz_fn'].notnull() + test_list = df[(df['dataset'] == 'test') & shoulder_cond & file_exist]['npz_fn'] + return test_list + + +def prediction(base_path, test_speaker, test_dir, model_dir, model_path, outdir_path): + # Save setting + out_path = Path(outdir_path) / (test_speaker + '_by_' + model_dir + '_' + test_dir) + + os.makedirs(out_path, exist_ok=True) + with open(Path(out_path).joinpath('setting.txt'), mode='w') as f: + f.write('Base path: {}\n'.format(base_path)) + f.write('Test speaker: {}\n'.format(test_speaker)) + f.write('Test dir: {}\n'.format(test_dir)) + f.write('Model dir: {}\n'.format(model_dir)) + f.write('Model path: {}\n'.format(model_path)) + + # Get data list + df_path = Path(base_path) / test_speaker / test_dir / 'test.csv' + test_list = get_test_datalist(df_path) + + # Load model speaker + exp_df = pd.read_csv(Path(model_path) / model_dir / 'experimental_settings.csv', index_col=0, header=None) + model_speaker = exp_df.loc['speaker', 1] + + # Load trained 
generator model + weights = torch.load(Path(model_path) / model_dir / 'trained-generator.pth', + map_location=lambda storage, loc: storage) + model = UnetDecoder(300, 300) + # Set trained weight + model.load_state_dict(weights) + + # Prediction + dataset_path = Path(base_path) / test_speaker / test_dir / 'text-pose-npz' + for fn in tqdm(test_list): + # Load word vectors + npz = np.load(dataset_path / fn) + wvec = npz['wvec'] # shape = (frames, 300) + wvec = np.transpose(wvec, (1, 0)) # shape = (300, frames) + inputs = torch.Tensor([wvec]) + + # Model prediction + with torch.no_grad(): + outputs = model(inputs) + gesture = np.transpose(outputs.numpy()[0], (1, 0)) # shape = (frames, 98) + + # De-normalizing gestures using SPEAKERS_CONFIG + gesture = (SPEAKERS_CONFIG[model_speaker]['std'] + np.finfo(float).eps) * gesture + SPEAKERS_CONFIG[model_speaker]['mean'] + gesture = np.reshape(gesture, (-1, 2, 49)) # shape = (frames, 2, 49) + + # Saving + np.save(out_path / fn[:-4], gesture) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Text to Gesture Generation by PyTorch') + parser.add_argument('--base_path', type=str, default='~/Gestures/', help='gesture base path') + parser.add_argument('--test_speaker', type=str, default='oliver', help='speaker name for test') + parser.add_argument('--test_dir', type=str, default='test-192', help='test file directory name') + parser.add_argument('--model_dir', type=str, default='oliver_YYYYMMDD-AAAAAA', + help='directory name of trained model') + parser.add_argument('--model_path', type=str, default='./out/', + help='directory path to training result') + parser.add_argument('--outdir_path', type=str, default='~/test_gesture_out/text2gesture/', + help='directory path of outputs') + args = parser.parse_args() + + prediction(args.base_path, args.test_speaker, args.test_dir, args.model_dir, args.model_path, args.outdir_path) diff --git a/train.py b/train.py new file mode 100644 index 0000000..f723c3b --- /dev/null +++ b/train.py @@ -0,0 +1,417 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import torch +import torch.utils.data as data_utils +import torch.nn as nn +import torch.optim as optim + +import time +import os +import csv +import argparse +import datetime +import numpy as np +import pandas as pd +from pathlib import Path +from tqdm import tqdm + +from utils.torch_dataset import get_datalist, wp_dataset +from utils.model import UnetDecoder, PatchGan + + +def get_argument(): + """ + Experimental setting + + Returns + ------- + args: Namespace + Experimental parameters from command line + """ + parser = argparse.ArgumentParser(description='Text to Gesture Generation by PyTorch') + parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training') + parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train') + parser.add_argument('--speaker', type=str, default='oliver', help='choose speaker name') + parser.add_argument('--no_screening', action='store_true', help='Not use data screening') + parser.add_argument('--gan_loss', action='store_true', help='Use GAN loss') + parser.add_argument('--lam_p', type=float, default=1., help='coefficient of pose loss') + parser.add_argument('--lam_m', type=float, default=1., help='coefficient of motion loss') + parser.add_argument('--lam_g', type=float, default=1., help='coefficient of GAN loss') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for training') + parser.add_argument('--device', type=str, 
default='cuda', help='cpu or cuda') + parser.add_argument('--gpu_num', type=int, default='0', help='GPU number') + parser.add_argument('--base_path', type=str, default='~/Gestures/', help='gesture base path') + parser.add_argument('--train_dir', type=str, default='train-64', help='training file directory') + parser.add_argument('--outdir_path', type=str, default='./out/', help='directory path of outputs') + parser.add_argument('--model_save_interval', type=int, default='10', help='Interval for saving model') + args = parser.parse_args() + return args + + +def write_parameters(args, outdir_path): + """ + Write hyperparameter settings to csv file + + Parameters + ---------- + args: Namespace + Experimental Settings + outdir_path: string + Output path + """ + fout = open(Path(outdir_path).joinpath('experimental_settings.csv'), "wt") + csvout = csv.writer(fout) + print('*' * 50) + print('Parameters') + print('*' * 50) + for arg in dir(args): + if not arg.startswith('_'): + csvout.writerow([arg, str(getattr(args, arg))]) + print('%-25s %-25s' % (arg, str(getattr(args, arg)))) + + +def train(args, outdir_path): + """ + Main function for training + + Returns + ------- + net: type(model) + Trained model at final iteration + """ + + # Load the dataset + df_path = Path(args.base_path) / args.speaker / args.train_dir / 'train.csv' + dataset_path = Path(args.base_path) / args.speaker / args.train_dir / 'text-pose-npz' + if args.no_screening: + train_list, dev_list = get_datalist(df_path, min_ratio=-np.inf, max_ratio=np.inf) + else: + train_list, dev_list = get_datalist(df_path) + train_num, val_num = len(train_list), len(dev_list) + print('Dataset size: {} (train), {} (validation)'.format(train_num, val_num)) + + train_dataset = wp_dataset(dataset_path, train_list, args.speaker) + val_dataset = wp_dataset(dataset_path, dev_list, args.speaker) + + # DataLoaders + train_loader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, + drop_last=True) + val_loader = data_utils.DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2) + print('Complete the preparing dataset...') + + # Set the GPU usage + device = torch.device('cuda:' + str(args.gpu_num) if args.device == 'cuda' else 'cpu') + print('Device: ', device) + + # Set the network + # Gesture Generator + g_net = UnetDecoder(300, 300) + g_net.to(device) + g_optim = optim.Adam(g_net.parameters(), lr=args.lr) + + # Define Loss function + l1_criterion = nn.L1Loss() + + # Training loop + start_time = time.time() + ite = 0 + history = [] + g_net.train(True) + + for epoch in range(args.epochs): + + running_sample = train_frame_num = 0 + train_pose_sum, train_pose_sq_sum = np.zeros(98), np.zeros(98) + train_g_pose_l1 = train_g_motion_l1 = 0.0 + + # ********** Training Phase ********** + # Sample minibatch from DataLoader + for (x_train, t_train) in tqdm(train_loader): + ite += 1 + inputs, corrects = x_train.to(device), t_train.to(device) + + # Generator training + g_optim.zero_grad() + g_out = g_net(inputs) # shape = (batch, 98, frames=64) + g_pose_l1 = l1_criterion(g_out, corrects) # L1 Loss of each coordinate + # L1 loss of temporal difference + g_motion_l1 = l1_criterion(g_out[:, :, 1:] - g_out[:, :, :-1], corrects[:, :, 1:] - corrects[:, :, :-1]) + g_loss = args.lam_p * g_pose_l1 + args.lam_m * g_motion_l1 + g_loss.backward() + g_optim.step() + + # Record running loss and prediction + poses = np.reshape(g_out.detach().cpu().numpy().transpose(0, 2, 1), (-1, 98)) + train_pose_sum += poses.sum(axis=0) + 
train_pose_sq_sum += (poses ** 2).sum(axis=0) + train_frame_num += len(poses) + train_g_pose_l1 += g_pose_l1.item() * len(inputs) + train_g_motion_l1 += g_motion_l1.item() * len(inputs) + running_sample += len(inputs) + + # ********** Logging and Validation Phase ********** + g_net.train(False) + val_frame_num = 0 + val_pose_sum, val_pose_sq_sum = np.zeros(98), np.zeros(98) + val_g_pose_l1 = val_g_motion_l1 = 0.0 + + for x_val, t_val in val_loader: + inputs, corrects = x_val.to(device), t_val.to(device) + + with torch.no_grad(): + # Generator Calculation + g_out = g_net(inputs) + g_pose_l1 = l1_criterion(g_out, corrects) + g_motion_l1 = l1_criterion(g_out[:, :, 1:] - g_out[:, :, :-1], + corrects[:, :, 1:] - corrects[:, :, :-1]) + + # Record running loss and prediction + poses = np.reshape(g_out.detach().cpu().numpy().transpose(0, 2, 1), (-1, 98)) + val_pose_sum += poses.sum(axis=0) + val_pose_sq_sum += (poses ** 2).sum(axis=0) + val_frame_num += len(poses) + val_g_pose_l1 += g_pose_l1.item() * len(inputs) + val_g_motion_l1 += g_motion_l1.item() * len(inputs) + + g_net.train(True) + + # Record training log + train_pose_std = np.mean(np.sqrt(train_pose_sq_sum / train_frame_num + - (train_pose_sum / train_frame_num) ** 2)) + val_pose_std = np.mean(np.sqrt(val_pose_sq_sum / val_frame_num - (val_pose_sum / val_frame_num) ** 2)) + record = {'epoch': epoch + 1, 'iteration': ite, + 'train_pose_std': train_pose_std, 'val_pose_std': val_pose_std, + 'train_g_pose_l1': train_g_pose_l1 / running_sample, + 'train_g_motion_l1': train_g_motion_l1 / running_sample, + 'train_g_loss': (args.lam_p * train_g_pose_l1 + args.lam_m * train_g_motion_l1) / running_sample, + 'val_g_pose_l1': val_g_pose_l1 / val_num, + 'val_g_motion_l1': val_g_motion_l1 / val_num, + 'val_g_loss': (args.lam_p * val_g_pose_l1 + args.lam_m * val_g_motion_l1) / val_num} + history.append(record) + print(record, flush=True) + + # Save models + if (epoch + 1) % args.model_save_interval == 0: + torch.save(g_net.state_dict(), Path(outdir_path).joinpath('generator-{}.pth'.format(epoch + 1))) + pd.DataFrame.from_dict(history).to_csv(Path(outdir_path).joinpath('history.csv')) + + pd.DataFrame.from_dict(history).to_csv(Path(outdir_path).joinpath('history.csv')) + + # Training Time + elapsed_time = time.time() - start_time + print('Training complete in {:.0f}m {:.0f}s'.format(elapsed_time // 60, elapsed_time % 60)) + + # Save training time and dataset size + with open(Path(outdir_path).joinpath('train_summary.txt'), mode='w') as f: + f.write('Training size: {}, Val size: {}\n'.format(train_num, val_num)) + f.write('Training complete in {:.0f}m {:.0f}s\n'.format(elapsed_time // 60, elapsed_time % 60)) + + return g_net + + +def train_gan(args, outdir_path): + """ + Main function for training with GAN loss + + Returns + ------- + net: type(model) + Trained model at final iteration + """ + + # Load the dataset + df_path = Path(args.base_path) / args.speaker / args.train_dir / 'train.csv' + dataset_path = Path(args.base_path) / args.speaker / args.train_dir / 'text-pose-npz' + if args.no_screening: + train_list, dev_list = get_datalist(df_path, min_ratio=0.7, max_ratio=1.3) + else: + train_list, dev_list = get_datalist(df_path) + train_num, val_num = len(train_list), len(dev_list) + print('Dataset size: {} (train), {} (validation)'.format(train_num, val_num)) + train_dataset = wp_dataset(dataset_path, train_list, args.speaker) + val_dataset = wp_dataset(dataset_path, dev_list, args.speaker) + + # DataLoaders + train_loader = 
data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, + drop_last=True) + val_loader = data_utils.DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2) + print('Complete the preparing dataset...') + + # Set the GPU usage + device = torch.device('cuda:' + str(args.gpu_num) if args.device == 'cuda' else 'cpu') + print('Device: ', device) + + # Set the network + # Gesture Generator + g_net = UnetDecoder(300, 300) + g_net.to(device) + g_optim = optim.Adam(g_net.parameters(), lr=args.lr) + + # Discriminator + d_net = PatchGan(ndf=64) + d_net.to(device) + d_optim = optim.Adam(d_net.parameters(), lr=args.lr) + + # Define Loss function + l1_criterion = nn.L1Loss() + mse_criterion = nn.MSELoss() + + # Training loop + start_time = time.time() + ite = 0 + history = [] + g_net.train(True) + d_net.train(True) + + for epoch in range(args.epochs): + + running_sample = train_frame_num = 0 + train_pose_sum, train_pose_sq_sum = np.zeros(98), np.zeros(98) + train_g_pose_l1 = train_g_motion_l1 = train_g_gan = train_d_real = train_d_fake = 0.0 + + # ********** Training Phase ********** + # Sample minibatch from DataLoader + for (x_train, t_train) in tqdm(train_loader): + ite += 1 + inputs, corrects = x_train.to(device), t_train.to(device) + + # Generator training + g_optim.zero_grad() + g_out = g_net(inputs) # shape = (batch, 98, frames=64) + d_fake_out = d_net(g_out[:, :, 1:] - g_out[:, :, :-1]) + g_gan_loss = mse_criterion(torch.ones(d_fake_out.shape).to(device), d_fake_out) + g_pose_l1 = l1_criterion(g_out, corrects) # L1 Loss of each coordinate + # L1 loss of temporal difference + g_motion_l1 = l1_criterion(g_out[:, :, 1:] - g_out[:, :, :-1], corrects[:, :, 1:] - corrects[:, :, :-1]) + g_loss = args.lam_p * g_pose_l1 + args.lam_m * g_motion_l1 + args.lam_g * g_gan_loss + g_loss.backward() + g_optim.step() + + # Discriminator training + d_optim.zero_grad() + fake_d_input = g_out[:, :, 1:] - g_out[:, :, :-1] + real_d_input = corrects[:, :, 1:] - corrects[:, :, :-1] + d_real_out, d_fake_out = d_net(real_d_input), d_net(fake_d_input.detach()) + d_real_loss = mse_criterion(torch.ones(d_real_out.shape).to(device), d_real_out) + d_fake_loss = mse_criterion(torch.zeros(d_fake_out.shape).to(device), d_fake_out) + d_loss = d_real_loss + d_fake_loss + d_loss.backward() + d_optim.step() + + # Record running loss and prediction + poses = np.reshape(g_out.detach().cpu().numpy().transpose(0, 2, 1), (-1, 98)) + train_pose_sum += poses.sum(axis=0) + train_pose_sq_sum += (poses**2).sum(axis=0) + train_frame_num += len(poses) + train_g_pose_l1 += g_pose_l1.item() * len(inputs) + train_g_motion_l1 += g_motion_l1.item() * len(inputs) + train_g_gan += g_gan_loss.item() * len(inputs) + train_d_real += d_real_loss.item() * len(inputs) + train_d_fake += d_fake_loss.item() * len(inputs) + running_sample += len(inputs) + + # ********** Logging and Validation Phase ********** + g_net.train(False) + d_net.train(False) + val_frame_num = 0 + val_pose_sum, val_pose_sq_sum = np.zeros(98), np.zeros(98) + val_g_pose_l1 = val_g_motion_l1 = val_g_gan = val_d_real = val_d_fake = 0.0 + + for x_val, t_val in val_loader: + inputs, corrects = x_val.to(device), t_val.to(device) + + with torch.no_grad(): + # Generator Calculation + g_out = g_net(inputs) + g_pose_l1 = l1_criterion(g_out, corrects) + g_motion_l1 = l1_criterion(g_out[:, :, 1:] - g_out[:, :, :-1], + corrects[:, :, 1:] - corrects[:, :, :-1]) + d_fake_out = d_net(g_out[:, :, 1:] - g_out[:, :, :-1]) + g_gan_loss = 
mse_criterion(torch.ones(d_fake_out.shape).to(device), d_fake_out) + + # Discriminator Calculation + fake_d_input = g_out[:, :, 1:] - g_out[:, :, :-1] + real_d_input = corrects[:, :, 1:] - corrects[:, :, :-1] + d_real_out, d_fake_out = d_net(real_d_input), d_net(fake_d_input) + d_real_loss = mse_criterion(torch.ones(d_real_out.shape).to(device), d_real_out) + d_fake_loss = mse_criterion(torch.zeros(d_fake_out.shape).to(device), d_fake_out) + + # Record running loss and prediction + poses = np.reshape(g_out.detach().cpu().numpy().transpose(0, 2, 1), (-1, 98)) + val_pose_sum += poses.sum(axis=0) + val_pose_sq_sum += (poses ** 2).sum(axis=0) + val_frame_num += len(poses) + val_g_pose_l1 += g_pose_l1.item() * len(inputs) + val_g_motion_l1 += g_motion_l1.item() * len(inputs) + val_g_gan += g_gan_loss.item() * len(inputs) + val_d_real += d_real_loss.item() * len(inputs) + val_d_fake += d_fake_loss.item() * len(inputs) + + g_net.train(True) + d_net.train(True) + + # Record training log + train_pose_std = np.mean(np.sqrt(train_pose_sq_sum / train_frame_num + - (train_pose_sum / train_frame_num)**2)) + val_pose_std = np.mean(np.sqrt(val_pose_sq_sum / val_frame_num - (val_pose_sum / val_frame_num)**2)) + record = {'epoch': epoch + 1, 'iteration': ite, + 'train_pose_std': train_pose_std, 'val_pose_std': val_pose_std, + 'train_g_pose_l1': train_g_pose_l1 / running_sample, + 'train_g_motion_l1': train_g_motion_l1 / running_sample, + 'train_g_gan': train_g_gan / running_sample, + 'train_g_loss': (args.lam_p * train_g_pose_l1 + args.lam_m * train_g_motion_l1 + args.lam_g * train_g_gan) / running_sample, + 'train_d_real': train_d_real / running_sample, 'train_d_fake': train_d_fake / running_sample, + 'train_d_loss': (train_d_real + train_d_fake) / running_sample, + 'val_g_pose_l1': val_g_pose_l1 / val_num, + 'val_g_motion_l1': val_g_motion_l1 / val_num, + 'val_g_gan': val_g_gan / val_num, + 'val_g_loss': (args.lam_p * val_g_pose_l1 + args.lam_m * val_g_motion_l1 + args.lam_g * val_g_gan) / val_num, + 'val_d_real': val_d_real / val_num, 'val_d_fake': val_d_fake / val_num, + 'val_d_loss': (val_d_real + val_d_fake) / val_num} + history.append(record) + print(record, flush=True) + + # Save models + if (epoch + 1) % args.model_save_interval == 0: + torch.save(g_net.state_dict(), Path(outdir_path).joinpath('generator-{}.pth'.format(epoch + 1))) + torch.save(d_net.state_dict(), Path(outdir_path).joinpath('discriminator-{}.pth'.format(epoch + 1))) + pd.DataFrame.from_dict(history).to_csv(Path(outdir_path).joinpath('history.csv')) + + pd.DataFrame.from_dict(history).to_csv(Path(outdir_path).joinpath('history.csv')) + + # Training Time + elapsed_time = time.time() - start_time + print('Training complete in {:.0f}m {:.0f}s'.format(elapsed_time // 60, elapsed_time % 60)) + + # Save training time and dataset size + with open(Path(outdir_path).joinpath('train_summary.txt'), mode='w') as f: + f.write('Training size: {}, Val size: {}\n'.format(train_num, val_num)) + f.write('Training complete in {:.0f}m {:.0f}s\n'.format(elapsed_time // 60, elapsed_time % 60)) + + return [g_net, d_net] + + +if __name__ == '__main__': + time_stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + args = get_argument() + + # Make directory to save results + outdir_path = Path(args.outdir_path) / (args.speaker + '_' + time_stamp) + os.makedirs(outdir_path, exist_ok=True) + write_parameters(args, outdir_path) + + # Check GPU / CPU + if not torch.cuda.is_available(): + args.device = 'cpu' + + # Swith the training function according to GAN 
loss usage + if args.gan_loss: + nets = train_gan(args, outdir_path) + # Save trained network + for name, net in zip(['generator', 'discriminator'], nets): + torch.save(net.state_dict(), Path(outdir_path).joinpath('trained-{}.pth'.format(name))) + else: + net = train(args, outdir_path) + # Save trained network + torch.save(net.state_dict(), Path(outdir_path).joinpath('trained-{}.pth'.format('generator'))) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/gesture_plot.py b/utils/gesture_plot.py new file mode 100644 index 0000000..768afb0 --- /dev/null +++ b/utils/gesture_plot.py @@ -0,0 +1,142 @@ +import matplotlib +import subprocess +import os +import numpy as np +from matplotlib import cm, pyplot as plt +from PIL import Image + +matplotlib.use("Agg") + + +# Keypoint consts +BASE_KEYPOINT = [0] +RIGHT_BODY_KEYPOINTS = [1, 2, 3, 28] +LEFT_BODY_KEYPOINTS = [4, 5, 6, 7] +LEFT_HAND_KEYPOINTS = lambda x: [7] + [8 + (x * 4) + j for j in range(4)] +RIGHT_HAND_KEYPOINTS = lambda x: [28] + [29 + (x * 4) + j for j in range(4)] +LINE_WIDTH = 1.5 +MARKER_SIZE = 1.5 + + +def plot_body_right_keypoints(keypoints, color, alpha=None, line_width=LINE_WIDTH): + _keypoints = np.array(BASE_KEYPOINT + RIGHT_BODY_KEYPOINTS) + plt.plot(keypoints[0][_keypoints], keypoints[1][_keypoints], linewidth=line_width, alpha=alpha, color=color) + + +def plot_body_left_keypoints(keypoints, color, alpha=None, line_width=LINE_WIDTH): + _keypoints = np.array(BASE_KEYPOINT + LEFT_BODY_KEYPOINTS) + plt.plot(keypoints[0][_keypoints], keypoints[1][_keypoints], linewidth=line_width, alpha=alpha, color=color) + + +def plot_left_hand_keypoints(keypoints, color, alpha=None, line_width=LINE_WIDTH): + for i in range(5): + _keypoints = np.array(LEFT_HAND_KEYPOINTS(i)) + plt.plot(keypoints[0][_keypoints], keypoints[1][_keypoints], linewidth=line_width, alpha=alpha, color=color) + + +def plot_right_hand_keypoints(keypoints, color, alpha=None, line_width=LINE_WIDTH): + for i in range(5): + _keypoints = np.array(RIGHT_HAND_KEYPOINTS(i)) + plt.plot(keypoints[0][_keypoints], keypoints[1][_keypoints], linewidth=line_width, alpha=alpha, color=color) + + +def plot_face(keypoints, color, alpha=None, marker_size=MARKER_SIZE): + plt.plot(keypoints[0], keypoints[1], color, marker='.', lw=0, markersize=marker_size, alpha=alpha) + + +def draw_poses(img, frame_body_kpts, frame_face_kpts, img_size, output=None, show=None, title=None, sub_size=None, color=None): + # Number of persons to draw + persons = len(frame_body_kpts) + + plt.close('all') + fig = plt.figure(figsize=(persons * 2, 1 * 2), dpi=200) + + # Draw title + if title is not None: + plt.title(title) + + if img is not None: + img_ = Image.open(img) + else: + img_ = Image.new(mode='RGB', size=img_size, color='white') + plt.imshow(img_, alpha=0.5) + + if color is None: + color = ['dodgerblue'] * persons + + for i in range(len(frame_body_kpts)): + ax = plt.subplot(1, persons, i+1) + + plot_body_right_keypoints(frame_body_kpts[i], color[i]) + plot_body_left_keypoints(frame_body_kpts[i], color[i]) + plot_left_hand_keypoints(frame_body_kpts[i], color[i]) + plot_right_hand_keypoints(frame_body_kpts[i], color[i]) + if frame_face_kpts is not None: + plot_face(frame_face_kpts[i], color[i]) + # Plotting size specification + ax.set_xlim(sub_size[0], sub_size[1]) + ax.set_ylim(sub_size[1], sub_size[0]) + # Remove axis + plt.axis('off') + + # Remove axis + plt.axis('off') + + if show: + plt.show() + if output is not None: + plt.savefig(output) + + 
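+# Hypothetical usage sketch for draw_poses: one person with 49 random (x, y)
+# body keypoints rendered onto a blank 720x480 canvas. The keypoint values,
+# the output filename, and the canvas size are illustrative assumptions, not
+# project data; the function is defined here only so nothing runs on import.
+def _demo_draw_poses(output='demo_pose.jpg'):
+    rng = np.random.default_rng(0)
+    # Shape convention expected by draw_poses: (persons, xy, keypoints).
+    demo_kpts = rng.uniform(-100.0, 100.0, size=(1, 2, 49))
+    draw_poses(None, demo_kpts, None, img_size=(720, 480), output=output,
+               show=False, sub_size=[demo_kpts.min(), demo_kpts.max()])
+
+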
+def create_mute_video_from_images(output_fn, temp_folder): + """ + :param output_fn: output video file name + :param temp_folder: contains images in the format 0001.jpg, 0002.jpg.... + :return: + """ + subprocess.call('ffmpeg -r 30000/2002 -f image2 -i "%s" -r 30000/1001 "%s" -y' % ( + os.path.join(temp_folder, '%04d.jpg'), output_fn), shell=True) + + +def create_voice_video_from_voice_and_video(audio_input_path, input_video_path, output_video_path): + subprocess.call('ffmpeg -i "%s" -i "%s" -strict -2 "%s" -y' % (audio_input_path, input_video_path, + output_video_path), shell=True) + + +def save_video(body_kpts, face_kpts, voice_path, output_fn, temp_folder, delete_tmp=False, color=None, + img_size=(720, 480), img=None): + # Create temporary directory + os.makedirs(temp_folder, exist_ok=True) + # Number of frames + frames = len(body_kpts[0]) + # Temporary file name pattern + output_fn_pattern = os.path.join(temp_folder, '%04d.jpg') + + # Size for drawing + if face_kpts is None: + sub_max = np.max(body_kpts) + sub_min = np.min(body_kpts) + else: + sub_max = max([np.max(body_kpts), np.max(face_kpts)]) + sub_min = min([np.min(body_kpts), np.min(face_kpts)]) + + # Drawing per frame + for j in range(frames): + frame_body_kpts = body_kpts[:, j] + if face_kpts is None: + frame_face_kpts = None + else: + frame_face_kpts = face_kpts[:, j] + draw_poses(None, frame_body_kpts, frame_face_kpts, output=output_fn_pattern % j, show=False, color=color, + img_size=img_size, sub_size=[sub_min, sub_max]) + plt.close() + + # Create mute video + create_mute_video_from_images(output_fn, temp_folder) + # Create video with voice + if voice_path is not None: + create_voice_video_from_voice_and_video(voice_path, output_fn, str(output_fn)[:-4] + '-voice.mp4') + subprocess.call('rm "%s"' % output_fn, shell=True) + # Delete temporary files + if delete_tmp: + subprocess.call('rm -R "%s"' % temp_folder, shell=True) diff --git a/utils/model.py b/utils/model.py new file mode 100755 index 0000000..9827fe5 --- /dev/null +++ b/utils/model.py @@ -0,0 +1,182 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def weights_init(m): + if isinstance(m, nn.Conv1d): + torch.nn.init.xavier_uniform_(m.weight) + torch.nn.init.zeros_(m.bias) + + +class ConvNormRelu1d(nn.Module): + """(conv => BN => ReLU)""" + def __init__(self, in_channels, out_channels, k, s, p): + super(ConvNormRelu1d, self).__init__() + self.conv = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=p), + nn.BatchNorm1d(out_channels), + nn.LeakyReLU(0.2, inplace=True) + ) + self.conv.apply(weights_init) + + def forward(self, x): + x = self.conv(x) + return x + + +class DoubleConv1d(nn.Module): + """(conv => BN => ReLU) * 2""" + def __init__(self, in_channels, out_channels, k, s, p): + super(DoubleConv1d, self).__init__() + self.block = nn.Sequential( + ConvNormRelu1d(in_channels, out_channels, k, s, p), + ConvNormRelu1d(out_channels, out_channels, k, s, p), + ) + + def forward(self, x): + x = self.block(x) + return x + + +class Down1d(nn.Module): + def __init__(self, in_channels, out_channels): + super(Down1d, self).__init__() + self.block = nn.Sequential( + # nn.MaxPool1d(2, 2), + ConvNormRelu1d(in_channels, out_channels, k=4, s=2, p=1) + ) + + def forward(self, x): + x = self.block(x) + return x + + +class Up1d(nn.Module): + """Up sampling => add => Double Conv""" + def __init__(self, in_channels, out_channels): + super(Up1d, self).__init__() + self.block = 
nn.Sequential( + DoubleConv1d(in_channels, out_channels, k=3, s=1, p=1) + ) + + def forward(self, x, y): + """Following the implementation in PoseGAN""" + x = torch.repeat_interleave(x, 2, dim=2) + x = x + y + x = self.block(x) + return x + + +class UNet1d(nn.Module): + """ + Text Encoder + """ + def __init__(self, in_channels, out_channels): + super(UNet1d, self).__init__() + self.inconv = DoubleConv1d(in_channels, out_channels, k=3, s=1, p=1) + self.down1 = Down1d(out_channels, out_channels) + self.down2 = Down1d(out_channels, out_channels) + self.down3 = Down1d(out_channels, out_channels) + self.down4 = Down1d(out_channels, out_channels) + self.down5 = Down1d(out_channels, out_channels) + self.up1 = Up1d(out_channels, out_channels) + self.up2 = Up1d(out_channels, out_channels) + self.up3 = Up1d(out_channels, out_channels) + self.up4 = Up1d(out_channels, out_channels) + self.up5 = Up1d(out_channels, out_channels) + + def forward(self, x): + x0 = self.inconv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + x4 = self.down4(x3) + x5 = self.down5(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + x = self.up5(x, x0) + return x + + +class Decoder(nn.Module): + """ + CNN Decoder + """ + def __init__(self, in_channels, out_channels): + super(Decoder, self).__init__() + self.layers = nn.Sequential( + DoubleConv1d(in_channels, out_channels, k=3, s=1, p=1), + DoubleConv1d(out_channels, out_channels, k=3, s=1, p=1), + DoubleConv1d(out_channels, out_channels, k=3, s=1, p=1), + DoubleConv1d(out_channels, out_channels, k=3, s=1, p=1), + nn.Conv1d(out_channels, 98, kernel_size=1, stride=1, padding=0) + ) + self.layers.apply(weights_init) + + # x: shape = (batch, channels, frames) + def forward(self, x): + x = self.layers(x) + return x + + +class PatchGan(nn.Module): + """ + Motion Discriminator + default forward input shape = (batch_size, 98, 64) + """ + + def __init__(self, in_channel=98, ndf=64): + """ + Parameter + ---------- + in_channel: int + Size of input channels + ndf: int (default=64) + Size of feature maps in discriminator + """ + super(PatchGan, self).__init__() + self.layer1 = nn.Conv1d(in_channel, ndf, kernel_size=4, stride=2, padding=1) + self.layer2 = nn.LeakyReLU(0.2, inplace=True) + self.layer3 = ConvNormRelu1d(ndf, ndf * 2, k=4, s=2, p=1) + self.layer4 = ConvNormRelu1d(ndf * 2, ndf * 4, k=4, s=1, p=0) + self.layer5 = nn.Conv1d(ndf * 4, 1, kernel_size=4, stride=1, padding=0) + + self.layer1.apply(weights_init) + self.layer5.apply(weights_init) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(F.pad(x, [1, 2], "constant", 0)) + x = self.layer5(F.pad(x, [1, 2], "constant", 0)) + return x + + +class UnetDecoder(nn.Module): + """ + unet => cnn_decoder + """ + def __init__(self, in_channels, out_channels): + """ + Parameter + ---------- + in_channels : int + input channel size + out_channels: int + output channel size + """ + super(UnetDecoder, self).__init__() + self.unet = UNet1d(in_channels, out_channels) + self.decoder = Decoder(out_channels, out_channels) + + def forward(self, x): + x = self.unet(x) + x = self.decoder(x) + return x diff --git a/utils/speaker_const.py b/utils/speaker_const.py new file mode 100644 index 0000000..837763e --- /dev/null +++ b/utils/speaker_const.py @@ -0,0 +1,568 @@ +import numpy as np + +SPEAKERS_CONFIG = { + 'almaram': + {'median': np.array([0., -106., -194., -104., 106., 162., 137., 134., + 144., 126., 118., 104., 119., 111., 105., 
101., + 125., 122., 116., 111., 133., 131., 129., 124., + 140., 139., 140., 140., -101., -87., -82., -74., + -75., -86., -83., -81., -79., -85., -85., -82., + -81., -83., -85., -83., -82., -84., -81., -80., + -79., 0., -6., 117., 126., 5., 139., 139., + 116., 131., 144., 136., 135., 134., 131., 130., + 132., 135., 132., 133., 135., 137., 135., 136., + 137., 140., 138.5, 138., 139., 109., 110., 114., + 117., 117., 114., 116., 119., 119., 117., 120., + 123., 122., 122., 127., 126., 125., 125.5, 130., + 130., 128.]), + 'mean': np.array([0., -102.831, -187.225, -105.001, 103.684, 159.654, + 126.302, 122.931, 128.711, 112.718, 106.145, 96.845, + 108.125, 101.695, 97.327, 93.299, 112.474, 108.298, + 103.437, 99.548, 117.103, 113.895, 110.956, 106.925, + 121.72, 119.542, 118.971, 118.283, -96.014, -87.539, + -80.057, -73.948, -71.795, -82.32, -79.115, -76.73, + -74.388, -82.105, -80.296, -77.741, -76.255, -82.054, + -81.601, -79.339, -77.542, -83.858, -79.748, -79.177, + -78.283, 0., -5.559, 113.522, 114.089, 4.688, + 134.756, 128.186, 105.581, 113.738, 127.147, 112.519, + 113.059, 119.44, 113.858, 114.358, 116.299, 120.788, + 116.685, 117.822, 119.544, 126.314, 122.744, 122.36, + 122.702, 131.012, 127.488, 126.576, 125.979, 102.175, + 101.177, 100.914, 99.242, 96.896, 99.399, 99.966, + 100.655, 100.153, 102.595, 104.855, 107.204, 106.748, + 107.49, 111.118, 111.368, 110.997, 112.068, 114.87, + 115.63, 114.635]), + 'scale_factor': 1.518504709101034, + 'std': np.array([0., 15.77442357, 31.74083135, 38.76473912, 16.00981399, + 28.00046935, 40.03541927, 46.47948191, 52.37930392, 47.88657929, + 48.78288609, 48.42624263, 46.59846966, 48.26843663, 50.59365643, + 52.47084523, 48.17116694, 51.37568682, 53.37495696, 54.50430897, + 50.88302655, 54.47836245, 56.48290063, 57.29840639, 54.71323058, + 58.40441966, 60.72342348, 62.18302752, 46.9228708, 44.15867388, + 46.51470467, 49.62792859, 52.8391046, 49.06342426, 52.13875502, + 55.0459726, 57.37028374, 47.79499948, 51.69669606, 54.54852811, + 56.73739486, 48.16782208, 51.27987714, 53.97625477, 56.09597344, + 49.14352283, 51.41441914, 53.19296637, 54.5650521, 0., + 6.80856218, 25.98910379, 51.67654283, 6.51771862, 27.4053729, + 48.12067543, 44.45340751, 52.2424287, 71.22975074, 65.97734186, + 72.21212861, 66.43695056, 66.68810866, 70.34944091, 74.74944548, + 62.9067171, 64.10094988, 67.69410843, 71.30516155, 63.17366068, + 63.74567016, 65.48774236, 67.96067389, 63.03325992, 62.88430532, + 63.75716606, 64.36825739, 54.93795023, 60.33219432, 67.52062355, + 69.96675951, 72.42787574, 67.3515538, 68.93075398, 70.52320168, + 71.7720112, 64.30512402, 66.268982, 68.01139893, 69.14937813, + 62.07604933, 64.26042387, 64.83709259, 65.68755583, 59.83643853, + 61.15464905, 62.18182291, 62.71329823])}, + 'angelica': { + 'median': np.array([0., -112., -170., -51., 112., 169., 89., 65., + 49.5, 35., 37., 40., 40., 41., 41., 42., + 45., 45., 45., 45., 50., 49., 49., 48., + 55., 54., 53., 53., -39., -20., 4., 19., + 28., 4.5, 21., 28., 32., -1., 17., 23., + 25., -6., 9., 15., 18., -11., 2., 7., + 10., 0., 4., 153., 154., 0., 155., 158., + 157., 156., 157., 159., 160., 158., 170., 179., + 188., 163., 178., 190., 197., 167., 183., 194., + 200., 173., 187., 195., 200., 153., 146., 139., + 137., 134., 139., 148., 157., 163., 146., 158., + 168., 173., 154., 166., 174., 179., 163., 172., + 177., 180.]), + 'mean': np.array([0., -109.943, -170.558, -58.044, 109.827, 168.036, + 94.467, 58.97, 52.736, 45.407, 44.249, 45.665, + 51.639, 49.735, 48.165, 46.499, 56.536, 53.554, + 51.313, 
48.981, 60.293, 57.029, 54.887, 53.198, + 62.166, 59.695, 58.707, 57.117, -43.247, -27.359, + -8.778, 3.673, 12.136, -8.683, 7.022, 14.388, + 18.369, -13.301, 3.559, 10.722, 14.248, -18.083, + -2.826, 4.011, 7.615, -21.951, -9.82, -4.64, + -0.922, 0., 3.994, 150.072, 139.343, -0.505, + 152.459, 153.141, 150.053, 147.183, 146.11, 148.093, + 149.845, 146.275, 155.367, 162.994, 169.658, 152.608, + 163.311, 173.255, 179.806, 158.304, 168.954, 177.697, + 184.216, 163.439, 172.377, 178.879, 183.864, 137.297, + 129.411, 121.703, 118.802, 115.731, 120.401, 126.067, + 131.436, 136.041, 127.648, 134.987, 142.396, 146.487, + 135.083, 142.758, 148.748, 152.58, 143.117, 149.36, + 153.495, 156.382]), + 'scale_factor': 1.4269190517359058, + 'std': np.array([0., 16.41102529, 31.97449978, 46.67209085, 17.08171745, + 27.65228208, 37.84749544, 64.7177804, 56.76208509, 51.92178108, + 51.05505851, 53.54873271, 53.50798706, 53.75848561, 55.10932566, + 57.11749293, 53.06833994, 51.87146695, 53.27754716, 55.13937467, + 53.40742599, 51.39184915, 51.61963029, 53.751026, 54.54957785, + 52.46150946, 52.15593112, 52.43355139, 55.59562924, 58.58003174, + 62.68405472, 66.61520901, 72.01281486, 63.33601275, 68.34514991, + 69.40195571, 71.46270943, 61.3120086, 65.63071323, 66.73936407, + 68.31931276, 59.29323832, 62.04278946, 63.52757574, 65.27939012, + 57.65251598, 59.16748769, 60.96555093, 62.12406873, 0., + 8.30312977, 26.11135416, 53.26602436, 11.54088277, 29.07432405, + 38.48177645, 44.61889948, 46.82268159, 51.29107037, 55.75426756, + 61.06446573, 52.26399693, 56.81973522, 59.92917456, 62.63712187, + 52.8356351, 56.60500224, 59.76263026, 61.87583021, 53.49947275, + 56.04178695, 58.49736055, 60.91419657, 54.07829767, 55.98037934, + 57.79308228, 59.49702097, 56.72689654, 60.67602557, 67.52811852, + 73.58046477, 78.81067592, 71.40644368, 79.9644578, 85.45076889, + 89.87280634, 71.75058255, 79.95804419, 84.27783329, 87.16994798, + 71.19769737, 78.32231761, 82.11062353, 85.16347574, 70.63294777, + 75.82125296, 78.49320974, 80.60754354]) + }, + 'chemistry': {'median': np.array([0., -149., -196., -90., 145., 208., 235., 246., + 230.5, 216., 210., 211., 246., 243., 234., 227., + 254., 245., 231., 223., 258., 248., 235., 227., + 258., 249., 241., 234., -78., -57., -32., -10., + 7., -48., -22., -7., 1., -54., -26., -10., + -3.5, -58., -31., -17., -11., -57., -35., -25., + -19., 0., 6., 218., 261., -5., 163., 145., + 133., 120.5, 108., 104., 103., 101., 99., 102., + 105., 115., 115., 119., 121., 129., 131., 131., + 132., 142., 143., 144., 144., 264., 244., 229., + 225., 217., 226., 225., 232., 237., 241., 241., + 249., 252., 257., 258., 264., 266., 275., 278., + 280., 279.]), + 'mean': np.array([0., -143.897, -190.058, -86.903, 140.926, 217.721, + 264.637, 268.052, 258.891, 250.801, 250.009, 253.757, + 274.376, 274.907, 272.892, 271.535, 283.333, 280.266, + 273.551, 270.236, 287.871, 283.827, 275.362, 270.974, + 288.984, 285.165, 278.785, 274.5, -77.25, -57.731, + -34.007, -13.901, 0.662, -47.15, -22.265, -8.555, + -1.11, -52.623, -24.446, -11.432, -6.058, -55.034, + -29.605, -18.433, -13.892, -54.819, -35.755, -27.038, + -22.352, 0., 6.167, 206.923, 216.101, -6.115, + 152.987, 128.342, 124.441, 115.255, 103.173, 99.066, + 98.378, 89.811, 85.753, 88.11, 90.737, 99.924, + 99.372, 105.05, 107.36, 111.877, 113.49, 117.972, + 119.45, 124.973, 126.641, 128.84, 129.162, 217.45, + 200.878, 185.423, 180.274, 176.279, 180.945, 177.688, + 181.782, 185.668, 194.886, 193.136, 196.672, 198.439, + 209.824, 208.904, 210.14, 210.388, 224.81, 
224.579, + 223.847, 222.846]), + 'scale_factor': 1.1010136119625171, + 'std': np.array([0., 17.54703368, 35.7581688, 69.24384154, + 16.8056099, 53.67166067, 105.93793103, 109.97522128, + 119.78379322, 134.17275207, 147.06528115, 156.34138272, + 135.87825663, 152.2457827, 163.87231107, 172.43461014, + 134.22994491, 152.57799725, 160.81889006, 165.36223361, + 132.95311339, 148.92147283, 155.09689538, 158.15342337, + 132.44675815, 145.00957822, 149.17857344, 151.5445677, + 70.96735517, 74.05220212, 79.28031881, 86.12266368, + 91.10061337, 85.69535285, 94.60962306, 97.50350237, + 98.95066397, 87.13493485, 96.15947735, 97.23624518, + 97.90133112, 87.86163465, 96.21616795, 96.25821269, + 96.51474673, 88.89298194, 95.43154078, 95.6565134, + 95.6844193, 0., 9.75823299, 41.59960422, + 85.24175502, 9.68275658, 49.8952586, 94.65167212, + 102.899779, 105.34503299, 110.68926358, 117.52109446, + 123.02262847, 116.44959974, 125.51849263, 131.95820513, + 136.88519946, 120.07525234, 129.27803223, 132.21424848, + 134.24947821, 122.7018332, 129.69374657, 130.27591956, + 130.28394951, 124.37462873, 129.23850092, 129.14637587, + 128.23497868, 92.32908263, 93.77521589, 97.21511236, + 102.35374406, 107.13176541, 95.68562052, 103.19170827, + 108.58317768, 113.0156174, 96.96912397, 104.87073712, + 109.90668049, 113.97662163, 98.84680584, 106.12451547, + 109.89005597, 112.6861813, 101.29443173, 106.64136045, + 109.06442862, 110.48723132])}, + 'conan': {'mean': np.array([0., -109.583, -156.881, -120.313, 109.631, 152.443, + 108.144, 101.458, 92.114, 81.513, 76.847, 74.993, + 91.03, 83.972, 81.598, 81.622, 95.874, 87.808, + 85.537, 85.997, 99.22, 92.344, 90.057, 89.867, + 102.278, 97.839, 95.828, 94.729, -112.329, -107.723, + -101.084, -98.403, -97.372, -113.236, -106.546, -100.633, + -97.867, -115.78, -107.062, -100.376, -97.709, -115.577, + -107.58, -101.694, -98.566, -113.886, -107.761, -103.555, + -100.291, 0., -3.122, 154.074, 210.401, 3.051, + 167.82, 232.961, 237.28, 231.731, 227.719, 228.673, + 230.681, 232.244, 237.448, 240.62, 242.516, 240.195, + 246.415, 249.652, 249.489, 247.002, 253.542, 255.679, + 254.841, 254.024, 258.919, 260.762, 260.172, 216.536, + 210.82, 205.289, 204.407, 205.708, 212.984, 215.676, + 216.734, 217.611, 222.036, 225.523, 225.602, 225.524, + 230.145, 233.169, 232.708, 231.662, 236.648, 239.776, + 239.635, 238.414]), + 'scale_factor': 1.4305381955532037, + 'std': np.array([0., 10.59231377, 32.5747577, 87.34243545, + 11.07947828, 28.23531036, 81.80955485, 94.92227471, + 96.95924404, 102.26105726, 108.35750824, 113.98373985, + 105.4720489, 115.84258809, 120.26970689, 122.74060907, + 105.35736388, 116.27625353, 119.70779687, 121.70631451, + 104.85591829, 114.45731809, 117.19618488, 118.52391873, + 104.49567798, 111.82849851, 114.05433975, 114.97168155, + 99.95218236, 100.40852688, 105.43183079, 111.24512839, + 116.71825742, 109.49891462, 119.41624632, 124.78480801, + 128.55843539, 111.29522721, 121.82060645, 127.12367452, + 129.74770256, 112.66641945, 122.448853, 126.91806161, + 128.77232484, 114.58103248, 122.30987646, 125.25131925, + 126.38526939, 0., 6.75537682, 32.26199814, + 97.00566065, 6.26788633, 24.78805357, 94.33954356, + 105.77244254, 112.9200719, 123.72302146, 132.94217567, + 139.79612741, 125.72067636, 134.83989505, 139.963108, + 143.67274531, 124.0620771, 133.07969332, 137.87443888, + 140.63537208, 122.2001882, 130.13695953, 134.74150051, + 137.58544152, 120.85215523, 126.25547291, 130.08463151, + 131.9334545, 105.97377366, 113.70617222, 125.14332375, + 132.79858189, 
140.62497195, 128.00962364, 137.0962473, + 141.59214401, 144.98927436, 126.15767398, 134.63195561, + 138.48896561, 140.98466379, 124.44862384, 131.4354459, + 134.99263956, 137.01364077, 122.28814373, 127.69337424, + 130.25399716, 131.22054185]), + 'median': np.array([0., -110., -153., -106., 110., 149., 106., 97., + 83., 74., 71., 70., 85., 82., 81., 81., + 90., 89., 87., 86., 93., 92., 91., 90., + 95., 93., 94., 92., -101., -91., -78., -74., + -74., -90., -80., -73., -69., -93., -80., -72., + -69., -93., -80., -72., -69.5, -92., -81., -76., + -73., 0., -4., 159., 207., 4., 171., 230., + 235., 225., 213., 212., 216., 219., 226., 233., + 238., 231., 241., 247., 248., 242., 252., 256., + 256., 252., 260., 262., 262., 215., 202., 191., + 189., 193., 207., 210., 209., 210., 220., 222., + 220., 220., 230., 232., 229., 228., 238., 240., + 238., 236.]), + }, + 'ellen': {'median': np.array([0., -129., -172., -147., 129., 171., 141., 133., + 128., 119., 107.5, 101., 124., 121., 118., 118., + 125., 119., 114., 114., 123., 116., 113., 112., + 121., 113., 111., 109., -143., -137., -130., -120., + -116., -139., -137., -135., -135., -142., -138., -134., + -133., -141., -136., -132., -130., -139., -134., -131.5, + -129., 0., -2., 198., 229., 0., 199., 253., + 261., 251., 243., 238., 235., 263., 266., 265., + 265., 273., 277., 277., 274., 280., 284., 283., + 282., 285., 287., 288., 286., 228., 222., 217., + 218., 219., 230., 236., 238., 240., 242., 249., + 251., 250., 251., 257., 257., 255., 257., 262., + 262., 260.]), + 'mean': np.array([0., -118.909, -164.154, -147.246, 118.925, 160.707, + 133.202, 126.031, 122.365, 116.286, 108.719, 106.006, + 118.932, 113.969, 113.305, 115.011, 118.991, 112.665, + 109.438, 109.385, 117.644, 111.169, 108.96, 107.772, + 116.547, 110.923, 109.41, 107.968, -144.238, -140.508, + -135.596, -132.849, -131.928, -145.429, -143.271, -141.212, + -141.904, -146.314, -142.76, -138.772, -137.964, -145.332, + -140.897, -137.975, -136.323, -143.97, -139.602, -138.326, + -136.509, 0., -1.674, 186.215, 222.679, 1.261, + 186.058, 252.238, 256.284, 254.343, 250.461, 249.989, + 248.952, 262.385, 264.888, 265.272, 265.825, 268.538, + 271.047, 272.562, 272.724, 273.782, 276.197, 277.758, + 278.16, 276.863, 279.468, 282.137, 281.792, 221.492, + 219.701, 217.669, 217.751, 219.06, 225.513, 228.427, + 229.216, 229.918, 232.378, 236.486, 238.776, 239.216, + 238.599, 242.803, 244.281, 244.736, 242.657, 246.497, + 248.185, 248.794]), + 'scale_factor': 1.3185415037379011, + 'std': np.array([0., 25.8187668, 41.89625621, 77.94319396, + 26.03765302, 36.60665446, 60.07478003, 68.73073577, + 71.76289971, 78.7986688, 85.17248405, 91.8425063, + 81.57343548, 90.51250764, 96.94054866, 103.783818, + 80.62698629, 90.21175519, 96.14923898, 101.66000578, + 80.05647547, 88.31894723, 94.51180032, 98.51348139, + 80.68749464, 88.48727067, 91.83456811, 94.89260759, + 85.40677582, 91.29359198, 100.86450706, 110.52797021, + 117.97610273, 102.74902899, 113.07612285, 119.03282344, + 124.60098228, 101.47065292, 111.75965462, 117.53972101, + 121.85632812, 99.92104771, 109.12218102, 114.81892864, + 118.45598622, 99.84089893, 107.37479032, 112.15367905, + 113.79890122, 0., 5.87807145, 42.05188194, + 94.98783058, 5.69902439, 41.39882409, 85.89787748, + 92.3838262, 100.51717938, 109.75256935, 116.35522712, + 122.94281474, 110.31635769, 115.28223391, 118.60499153, + 121.26193292, 105.71004, 110.80258477, 115.78894661, + 118.8850614, 103.46256558, 108.43064231, 113.94664293, + 117.36026755, 102.35001823, 107.34869806, 
111.92460065, + 114.05449897, 103.00641697, 112.17914066, 121.99203023, + 127.82240414, 133.5572177, 124.74606138, 130.2611403, + 133.63072006, 136.23373032, 120.91488377, 127.17007433, + 131.01745618, 132.5830432, 119.50055313, 125.03061302, + 127.97185643, 129.91735952, 117.63924239, 121.72976625, + 124.69851954, 126.28790743])}, + 'jon': {'mean': np.array([0., -153.331, -220.516, -210.796, 160.263, 215.725, + 113.459, 106.156, 81.586, 56.863, 41.095, 30.188, + 70.704, 50.813, 37.224, 28.324, 78.826, 54.748, + 40.502, 32.23, 84.64, 62.526, 49.402, 42.049, + 87.848, 71.01, 61.826, 55.67, -205.849, -185.743, + -168.707, -161.135, -152.156, -190.978, -188.077, -185.31, + -183.466, -206.802, -204.832, -201.539, -199.346, -218.691, + -218.876, -215.868, -212.868, -227.754, -229.963, -229.562, + -227.931, 0., -2.292, 190.993, 231.52, -1.083, + 220.65, 257.155, 261.606, 251.047, 243.471, 243.151, + 243.737, 237.272, 240.32, 248.198, 255.309, 249.663, + 256.813, 265.245, 269.657, 262.648, 268.498, 275.582, + 278.131, 275.337, 280.445, 284.139, 285.841, 241.15, + 242.714, 245.433, 247.933, 247.489, 234.907, 242.81, + 254.631, 265.567, 239.662, 249.287, 264.716, 275.327, + 246.93, 256.517, 270.548, 280.467, 255.458, 263.393, + 273.444, 281.315]), + 'median': np.array([0., -155., -220., -208., 162., 214., 102., 94., + 66., 37., 19., 6., 53., 32., 16.5, 6., + 64., 38., 22., 13., 71., 48., 33., 25., + 76., 59., 49., 41., -201., -177., -157., -147., + -135., -180., -177., -173., -171., -200., -198., -193., + -191., -214.5, -215., -212., -208., -227., -229., -229., + -226., 0., -4., 192., 253., 0., 231., 283., + 290., 281., 277., 279., 283., 269., 276., 289., + 298., 282., 295., 309., 314., 295., 307., 318., + 321., 309., 318., 324., 327., 268., 273., 282., + 290., 290., 269., 283., 300., 315., 274., 289., + 308., 322., 281., 295., 312., 324., 289., 300., + 313., 322.]), + 'scale_factor': 1.0, + 'std': np.array([0., 13.60674241, 33.3032993, 58.68593003, + 14.80816771, 41.0566849, 86.83690643, 91.52707613, + 96.46133217, 103.81552018, 111.29798729, 116.84988941, + 108.953423, 118.5386858, 121.7592864, 123.60705896, + 109.22557267, 118.82313115, 122.45853174, 125.15311862, + 109.25032906, 117.86855952, 121.46853253, 123.86147342, + 109.55747759, 116.71509714, 119.36929976, 120.8869021, + 69.0435674, 73.14218312, 81.05269367, 88.65717554, + 97.30981278, 84.00480651, 92.37692932, 96.86237608, + 100.46773036, 82.47765028, 91.27185643, 96.26633097, + 99.31366615, 80.67351188, 89.02701064, 93.30791272, + 96.34672063, 80.04087383, 86.67089264, 90.58648992, + 93.14206482, 0., 10.88231299, 30.72401261, + 64.20069782, 11.80025894, 38.96731323, 67.39309293, + 73.62073596, 79.57721276, 89.74364133, 98.04464391, + 105.0374687, 90.90715052, 99.81535754, 105.95001084, + 110.26701011, 91.25035579, 100.25840629, 105.86755393, + 108.97285603, 91.38375182, 99.14092997, 103.26581853, + 106.88325331, 91.70622351, 97.17438436, 99.73720308, + 102.19299251, 75.27206321, 82.64365798, 93.43973197, + 101.8475258, 107.69945162, 95.75379027, 106.19971704, + 113.93437953, 121.08109477, 94.56100547, 104.48614564, + 110.77791, 116.16508973, 94.30819212, 101.95049637, + 106.9177146, 111.37067348, 93.46864841, 99.52853134, + 103.17545669, 105.90626882])}, + 'oliver': {'mean': np.array([0., -163.658, -211.057, -134.649, 164.739, 209.073, + 143.511, 127.913, 116.269, 103.244, 96.397, 93.487, + 127.25, 111.415, 99.029, 90.083, 132.054, 113.831, + 99.465, 91.78, 130.525, 111.719, 97.237, 89.394, + 125.449, 109.247, 100.154, 93.562, 
-117.402, -100.158, + -82.374, -73.054, -67.696, -111.641, -100.162, -88.474, + -79.601, -119.885, -105.883, -91.469, -83.153, -122.363, + -108.076, -95.898, -88.328, -121.041, -110.486, -103.011, + -97.96, 0., 1.885, 226.397, 229.072, -4.987, + 213.13, 243.52, 255.207, 231.348, 204.754, 188.757, + 177.801, 206.763, 198.08, 197.119, 196.928, 223.176, + 215.778, 214.844, 213.122, 239.312, 233.451, 230.949, + 228.755, 253.464, 248.937, 246.076, 243.138, 232.561, + 211.604, 184.101, 165.544, 151.386, 178.313, 165.155, + 163.223, 162.372, 192.993, 181.789, 181.059, 180.34, + 208.741, 200.467, 198.325, 196.636, 224.591, 218.906, + 215.748, 213.304]), + 'scale_factor': 0.9549234615419752, + 'std': np.array([0., 11.99879311, 29.82817043, 54.03489427, + 12.11713163, 35.70626935, 100.127588, 114.55109529, + 117.61694027, 122.42752331, 129.34672547, 135.35726737, + 120.82736238, 130.10660542, 136.48968517, 139.7453903, + 122.16887936, 132.40670844, 138.38685911, 141.79375727, + 123.35014947, 133.45671223, 139.08317954, 142.11474506, + 126.82329202, 133.81198747, 138.99624558, 141.54416327, + 61.03361693, 62.88763818, 68.25358689, 73.29697868, + 78.72461867, 68.02637811, 76.81368209, 81.46072258, + 84.88038524, 67.26606704, 77.41239766, 82.52517821, + 85.17540485, 67.84581955, 77.28375136, 81.87896919, + 84.03993346, 69.44094843, 77.42084864, 80.58891288, + 82.31571174, 0., 9.47827912, 29.89226306, + 65.29199657, 9.53880658, 33.56934763, 59.26234555, + 71.23782809, 69.40388243, 72.38679081, 78.91335724, + 86.99427222, 74.79921678, 82.52043141, 88.62611827, + 93.39188838, 77.27343026, 84.96602095, 90.18950972, + 93.68292863, 79.41305092, 85.97462183, 89.45524243, + 92.88002463, 80.98173068, 85.93479523, 88.12695515, + 89.9635646, 74.11004169, 73.49017066, 76.71515365, + 82.82322177, 89.53648979, 78.33528599, 87.32146915, + 95.1113835, 101.48254833, 81.44670006, 91.03851097, + 96.9321181, 101.12747599, 84.49668585, 93.10840408, + 97.03436183, 99.7947569, 87.07782564, 93.60984544, + 96.28646061, 97.48070365]), + 'median': np.array([0., -165., -211., -128., 167., 205., 118., 96.5, + 82., 66., 56., 50., 94., 77., 62., 51., + 99., 82., 63., 53., 99., 77., 59., 50., + 92., 72., 60., 53., -110., -93., -74., -63., + -56., -104., -94., -81., -71., -112., -99., -83., + -73., -114., -101., -86., -77., -112., -101., -93., + -86., 0., 2., 227., 257., -4., 219., 267., + 279., 253., 220., 200., 182., 227., 219., 221., + 222., 247., 241., 242., 242., 267., 262., 260., + 258., 284., 280., 277., 273., 264., 241., 210., + 188., 167., 211., 196., 192., 190., 226., 215., + 214., 212., 243., 235., 235., 232., 260., 255., + 253., 251.]), + }, + 'median': np.array([0., -165., -211., -128., 167., 205., 118., 96.5, + 82., 66., 56., 50., 94., 77., 62., 51., + 99., 82., 63., 53., 99., 77., 59., 50., + 92., 72., 60., 53., -110., -93., -74., -63., + -56., -104., -94., -81., -71., -112., -99., -83., + -73., -114., -101., -86., -77., -112., -101., -93., + -86., 0., 2., 227., 257., -4., 219., 267., + 279., 253., 220., 200., 182., 227., 219., 221., + 222., 247., 241., 242., 242., 267., 262., 260., + 258., 284., 280., 277., 273., 264., 241., 210., + 188., 167., 211., 196., 192., 190., 226., 215., + 214., 212., 243., 235., 235., 232., 260., 255., + 253., 251.]), + 'rock': {'mean': np.array([0., -50.691, -71.908, -59.31, 52.451, 78.76, 39.172, + 35.689, 30.193, 21.265, 14.115, 9.367, 21.133, 11.618, + 6.135, 2.322, 20.812, 10.702, 4.949, 1.614, 20.973, + 11.516, 6.355, 3.296, 20.881, 12.962, 9.191, 6.544, + -56.623, -53.151, -46.877, 
-42.133, -38.616, -51.007, -45.297, + -41.204, -38.231, -51.745, -45.244, -40.455, -37.947, -51.901, + -45.412, -40.849, -38.384, -51.166, -45.603, -42.475, -40.21, + 0., 2.919, 71.638, 62.753, -2.236, 68.846, 61.088, + 62.271, 55.003, 47.461, 44.187, 41.541, 47.717, 45.33, + 44.813, 44.858, 53.181, 51.377, 51.07, 50.835, 58.485, + 57.221, 56.71, 56.121, 63.687, 62.794, 62.234, 61.799, + 62.994, 56.582, 49.842, 46.872, 44.653, 49.173, 46.988, + 47.094, 47.456, 54.367, 52.627, 52.811, 52.692, 59.522, + 58.338, 58.022, 57.695, 64.378, 63.897, 63.478, 62.855]), + 'median': np.array([0., -53., -74., -61., 54., 82., 39., 36., 29., + 18., 9., 3., 19., 9., 3., -1., 19., 8., + 2., -2., 19., 9., 3., 0., 19., 10., 6., + 3., -60., -54., -47., -41., -36., -52., -45., -39.5, + -36., -53., -45., -39., -35., -54., -44., -39., -36., + -52., -45., -41., -38., 0., 3., 75., 67., -2., + 71., 63., 65., 57., 49., 46., 44., 50., 49., + 50., 50., 57., 56., 57., 57., 63., 63., 63., + 62., 69., 69., 68., 68., 67., 60., 52., 49., + 47., 52., 50., 51., 52., 58., 57., 58., 58., + 64., 64., 64., 64., 70., 70., 69., 69.]), + 'scale_factor': 3.0404103081189042, + 'std': np.array([0., 10.4203416, 17.22636166, 24.77684201, 11.10331478, + 17.75664383, 18.53861958, 20.16805095, 20.55582037, 23.09192012, + 25.92936125, 28.61632945, 23.87838585, 26.28121146, 27.0651949, + 28.08466336, 23.62597418, 25.974626, 26.90238649, 28.14425348, + 22.98891626, 25.03772641, 26.18272283, 27.28751333, 22.26856167, + 23.93914276, 24.86223077, 25.93993184, 29.16729112, 28.9929681, + 30.65772123, 32.47656557, 34.01532807, 33.28412461, 35.77885955, + 36.57866023, 37.38547364, 33.74483627, 36.30537789, 36.96885142, + 37.77928256, 34.01107465, 36.46546662, 37.06834497, 37.90272476, + 34.39137165, 36.44109481, 36.96086275, 37.45362332, 0., + 2.89973085, 18.24277819, 20.42860717, 2.76048981, 13.92466459, + 17.1719031, 19.71708799, 19.53716947, 21.61500588, 24.15897413, + 26.41163227, 24.7449977, 27.6888985, 29.30245776, 30.29448524, + 25.62081652, 28.73668163, 30.19561392, 31.28564807, 25.91670841, + 28.75941861, 30.06825402, 31.11173989, 26.03480422, 28.16124223, + 29.02597533, 29.7304322, 23.28072087, 22.67441898, 24.5621464, + 26.42123419, 28.12181699, 26.97838155, 29.05460129, 30.11018373, + 30.98118887, 27.92701758, 30.10780415, 30.95133727, 31.63882324, + 28.46049044, 30.59816589, 30.90853468, 31.19961498, 28.92246732, + 30.4676614, 30.74221066, 30.69778453])}, + 'seth': {'mean': np.array([0.00000e+00, -1.56603e+02, -1.92734e+02, -1.32233e+02, + 1.60156e+02, 1.93264e+02, 7.84210e+01, 6.13590e+01, + 3.63920e+01, 1.03240e+01, 3.24000e+00, 5.41000e-01, + 2.59350e+01, 7.05000e+00, -5.32800e+00, -1.25190e+01, + 3.17200e+01, 8.77700e+00, -6.74700e+00, -1.21180e+01, + 3.12780e+01, 1.12180e+01, -3.30400e+00, -8.38000e+00, + 2.93330e+01, 1.41880e+01, 3.79600e+00, -4.10200e+00, + -1.22910e+02, -1.04130e+02, -8.88500e+01, -7.73330e+01, + -6.66150e+01, -1.06764e+02, -9.15200e+01, -7.91270e+01, + -7.09780e+01, -1.12913e+02, -9.24190e+01, -7.87410e+01, + -7.28720e+01, -1.11232e+02, -9.46040e+01, -8.03730e+01, + -7.46850e+01, -1.05766e+02, -9.38100e+01, -8.45570e+01, + -7.68080e+01, 0.00000e+00, -1.52000e-01, 2.18895e+02, + 2.68892e+02, 3.11000e-01, 2.37138e+02, 2.60438e+02, + 2.62680e+02, 2.46628e+02, 2.28206e+02, 2.22251e+02, + 2.19554e+02, 2.28401e+02, 2.20632e+02, 2.23361e+02, + 2.28428e+02, 2.42971e+02, 2.34319e+02, 2.39472e+02, + 2.41920e+02, 2.55531e+02, 2.52430e+02, 2.51910e+02, + 2.51151e+02, 2.67728e+02, 2.67377e+02, 2.66409e+02, + 2.64099e+02, 
2.76742e+02, 2.61222e+02, 2.46399e+02, + 2.44755e+02, 2.41459e+02, 2.54082e+02, 2.51658e+02, + 2.56695e+02, 2.62400e+02, 2.68415e+02, 2.68020e+02, + 2.72175e+02, 2.74369e+02, 2.80791e+02, 2.83345e+02, + 2.84913e+02, 2.83823e+02, 2.92152e+02, 2.94754e+02, + 2.95650e+02, 2.94466e+02]), + 'median': np.array([0., -160., -188., -90., 164., 188., 74., 61., + 30., -5., -13., -15., 14., -17., -37., -47., + 22., -15., -39., -48., 21., -12., -35., -45., + 15., -9., -26., -40., -82., -55., -35., -19., + -1., -50., -19., 3., 17., -54., -12.5, 12., + 23., -50., -14., 12., 23., -38., -14., 6., + 20., 0., 0., 231., 280.5, 0., 243.5, 284., + 288., 270., 243., 237., 237., 250., 236., 243., + 253., 271., 260., 269., 273., 286., 285., 286., + 285., 303., 306., 305., 302., 287., 268., 244., + 241., 234., 260., 251., 258., 266., 279., 277., + 283., 284., 295., 299., 301., 297., 309., 314., + 314., 312.]), + 'scale_factor': 0.9900081765632547, + 'std': np.array([0., 18.55767741, 35.88235282, 101.25383307, + 19.41946611, 38.02917701, 54.45006666, 61.64462766, + 64.83266411, 72.31580065, 77.08544869, 83.25050342, + 74.10669858, 79.17970384, 85.60273603, 88.84267915, + 71.75098327, 79.86841222, 86.6311664, 90.07606828, + 73.40176235, 79.49821681, 84.53137633, 89.01422134, + 72.98420453, 77.71971858, 83.04868683, 87.2696946, + 111.3822423, 118.55779645, 126.86364925, 137.85477181, + 149.29311027, 136.38734657, 153.90880287, 164.70497525, + 173.08522616, 139.26315173, 160.31123304, 173.69513499, + 183.07759452, 143.54942764, 162.06102303, 175.3751404, + 183.9359502, 148.61482175, 163.27092178, 173.38245803, + 181.60665499, 0., 9.34959336, 37.72078969, + 49.82852934, 9.37562153, 35.3828342, 61.75611837, + 69.91869278, 69.9318641, 72.31696595, 75.38982689, + 78.29895966, 76.43157855, 80.22123519, 83.73378457, + 87.34744882, 76.80201924, 79.60171631, 84.10195727, + 87.62829223, 78.27895655, 82.02017496, 84.37008889, + 86.20353936, 80.41539664, 84.47919786, 85.72448728, + 86.26808911, 54.1489006, 55.01033281, 59.79178705, + 62.95996327, 68.53929033, 60.30863351, 64.76271332, + 66.97766773, 69.2184224, 60.48514508, 63.85550564, + 65.756052, 68.09354477, 61.19543544, 64.2909634, + 65.19778701, 67.21045805, 62.52371467, 65.29180258, + 66.37902907, 65.94090418])}, + 'shelly': {'median': np.array([0., -44., -60., -53., 44., 65., 40., 34., 33., 32., 32., + 32., 31., 28., 26., 25., 31., 28., 26., 25., 31., 28., + 26., 25., 30., 28., 27., 26., -49., -49., -52., -55., -56., + -55., -57., -57., -58., -56., -57., -57., -57., -56., -56., -57., + -56., -55., -56., -56., -56., 0., 0., 70., 99., 0., 69., + 84., 86., 84., 81., 79., 77., 85., 85., 84., 83., 88., + 88., 87., 86., 91., 91., 90., 89., 92., 92., 91., 91., + 104., 102., 100., 99., 99., 107., 109., 110., 111., 111., 114., + 115., 115., 115., 118., 119., 119., 117., 120., 121., 121.]), + 'mean': np.array([0.00000e+00, -4.39670e+01, -6.16120e+01, -5.66750e+01, + 4.38510e+01, 6.52460e+01, 4.23140e+01, 3.73650e+01, + 3.57040e+01, 3.45580e+01, 3.44570e+01, 3.45120e+01, + 3.52910e+01, 3.31350e+01, 3.19850e+01, 3.09850e+01, + 3.51300e+01, 3.28280e+01, 3.14670e+01, 3.08340e+01, + 3.51450e+01, 3.27940e+01, 3.16110e+01, 3.05940e+01, + 3.47090e+01, 3.29040e+01, 3.22040e+01, 3.11530e+01, + -5.37470e+01, -5.32070e+01, -5.42620e+01, -5.62220e+01, + -5.78420e+01, -5.81480e+01, -5.93390e+01, -5.98870e+01, + -6.00440e+01, -5.89260e+01, -5.98730e+01, -5.95250e+01, + -5.97740e+01, -5.86410e+01, -5.93170e+01, -5.91710e+01, + -5.94890e+01, -5.80360e+01, -5.85560e+01, -5.88470e+01, + -5.87110e+01, 
0.00000e+00, -9.70000e-02, 6.69060e+01, + 9.19000e+01, -4.16000e-01, 6.65610e+01, 8.19670e+01, + 8.38610e+01, 8.18550e+01, 7.90090e+01, 7.72360e+01, + 7.61760e+01, 8.11380e+01, 8.04890e+01, 7.99180e+01, + 7.96400e+01, 8.42210e+01, 8.41600e+01, 8.36870e+01, + 8.31360e+01, 8.67250e+01, 8.67590e+01, 8.62470e+01, + 8.56190e+01, 8.85240e+01, 8.86890e+01, 8.82870e+01, + 8.75840e+01, 9.63500e+01, 9.51850e+01, 9.30020e+01, + 9.14720e+01, 9.06760e+01, 9.54290e+01, 9.63560e+01, + 9.65460e+01, 9.66780e+01, 9.88890e+01, 9.99080e+01, + 1.00404e+02, 1.00404e+02, 1.01832e+02, 1.02755e+02, + 1.03448e+02, 1.03503e+02, 1.04009e+02, 1.05167e+02, + 1.05725e+02, 1.06154e+02]), + 'scale_factor': 3.570953563050855, + 'std': np.array([0., 5.97075464, 18.08920827, 33.06550128, 6.05266875, + 18.51673524, 32.79398427, 37.3746408, 38.17355608, 41.0028857, + 43.89186885, 46.5736176, 42.3892005, 46.75329694, 48.91665131, + 50.78065355, 42.70645267, 46.58178202, 48.63481172, 50.08577087, + 42.50759903, 45.92044821, 47.71441794, 49.30368307, 42.41971616, + 45.19445524, 46.42829293, 47.59075111, 38.04543325, 38.97140171, + 40.60807008, 42.62166956, 45.15082542, 42.1350934, 44.92540572, + 47.14573396, 49.22062641, 42.10518405, 45.64207347, 47.96439695, + 49.88419513, 42.26076335, 45.48921313, 47.56071655, 49.44677825, + 42.74354576, 45.27401975, 46.8043544, 48.27524706, 0., + 4.71376612, 19.68423643, 46.76840814, 4.8071763, 17.32877027, + 37.98475893, 44.62642355, 44.27660754, 46.27339321, 49.00555381, + 51.54084811, 48.50106139, 51.8787999, 54.01876781, 55.82720126, + 49.17186349, 53.1482869, 55.37935564, 56.9398411, 49.60537647, + 53.15278844, 55.01163505, 56.67173757, 49.79045515, 52.37627592, + 53.64364483, 54.5685527, 53.91743225, 54.13352727, 55.90663642, + 58.86213737, 61.06556332, 58.86913418, 62.15002224, 64.03989291, + 65.85472129, 59.38530693, 62.81211297, 64.88760116, 66.08761445, + 59.60905784, 62.46751936, 64.10454973, 65.54253574, 59.33884831, + 61.90099443, 63.20263741, 64.26842369])}} diff --git a/utils/torch_dataset.py b/utils/torch_dataset.py new file mode 100644 index 0000000..c5dd7fc --- /dev/null +++ b/utils/torch_dataset.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch +import torch.utils.data as data_utils +import numpy as np +import pandas as pd +from utils.speaker_const import SPEAKERS_CONFIG + + +class wp_dataset(torch.utils.data.Dataset): + def __init__(self, path, fnames, speaker): + self.path = path + self.fnames = fnames + self.speaker = speaker + + def __len__(self): + return len(self.fnames) + + def __getitem__(self, idx): + npz = np.load(self.path / self.fnames[idx]) + wvec = npz['wvec'] # shape = (frames, 300) + poses = npz['poses'] # shape = (frames, 2, 49) + poses = np.reshape(poses, (poses.shape[0], poses.shape[1] * poses.shape[2])) # shape = (frames, 98) + # Standardization using SPEAKERS_CONFIG + poses = (poses - SPEAKERS_CONFIG[self.speaker]['mean']) / (SPEAKERS_CONFIG[self.speaker]['std'] + np.finfo(float).eps) + wvec = np.transpose(wvec, (1, 0)) # shape = (300, frames) + poses = np.transpose(poses, (1, 0)) # shape = (98, frames) + return torch.Tensor(wvec), torch.Tensor(poses) + + +def get_datalist(df_path, min_ratio=0.75, max_ratio=1.25): + df = pd.read_csv(df_path) + speaker = df['speaker'][0] + + shoulder_w = np.sqrt((SPEAKERS_CONFIG[speaker]['median'][4] - SPEAKERS_CONFIG[speaker]['median'][1]) ** 2 + + (SPEAKERS_CONFIG[speaker]['median'][53] - SPEAKERS_CONFIG[speaker]['median'][50]) ** 2) + min_w = shoulder_w * min_ratio + max_w = shoulder_w 
* max_ratio
+    # Keep only clips whose detected shoulder width stays within
+    # [min_ratio, max_ratio] of the speaker's median shoulder width
+    # (a rough scale / outlier filter on the pose detections).
+    shoulder_cond = (min_w < df['min_sh_width']) & (df['max_sh_width'] < max_w)
+
+    # Drop rows without an associated .npz feature file.
+    file_exist = df['npz_fn'].notnull()
+    train_list = df[(df['dataset'] == 'train') & shoulder_cond & file_exist]['npz_fn']
+    dev_list = df[(df['dataset'] == 'dev') & shoulder_cond & file_exist]['npz_fn']
+
+    # Report how many clips survive the filtering in each split.
+    print('train: ', len(train_list), ' / ', len(df[df['dataset'] == 'train']))
+    print('dev: ', len(dev_list), ' / ', len(df[df['dataset'] == 'dev']))
+    return train_list.to_list(), dev_list.to_list()
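+
+
+# Hypothetical loading sketch: wires get_datalist and wp_dataset into a
+# standard DataLoader. The CSV path 'data/frames.csv', the .npz directory
+# 'data/npz', the speaker key 'oliver', and batch_size=32 are illustrative
+# assumptions; the speaker key should match the CSV's speaker column. It also
+# assumes all clips share a fixed frame count so the default collate can
+# stack them (the discriminator's docstring suggests 64-frame clips).
+if __name__ == '__main__':
+    from pathlib import Path
+
+    train_files, dev_files = get_datalist('data/frames.csv')
+    train_set = wp_dataset(Path('data/npz'), train_files, speaker='oliver')
+    # data_utils is the torch.utils.data alias imported at the top of this module.
+    loader = data_utils.DataLoader(train_set, batch_size=32, shuffle=True)
+
+    wvec, poses = next(iter(loader))
+    # e.g. torch.Size([32, 300, 64]) and torch.Size([32, 98, 64]) for 64-frame clips.
+    print(wvec.shape, poses.shape)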