-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathinference.py
124 lines (97 loc) · 4.43 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# --- Standard library ---
import argparse
# --- Third-party: config handling ---
import yaml
import numpy as np
import time
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm import tqdm
import sys
import os
import glob
# --- Third-party: deep learning / audio I/O ---
import torch
import soundfile as sf
import torch.nn as nn
# --- Project-local helpers ---
from utils import demix_track, get_model_from_config
import warnings
# Suppress all warnings to keep tqdm progress output readable.
warnings.filterwarnings("ignore")
def run_folder(model, args, config, device, verbose=False):
    """Run source separation on every ``.wav`` file in ``args.input_folder``.

    For each track, writes one stem per instrument plus an "instrumental"
    file (original mix minus the first instrument's estimate) into
    ``args.store_dir``.

    Args:
        model: separation model (already on ``device``); set to eval mode here.
        args: namespace with ``input_folder`` and ``store_dir``.
        config: ConfigDict with ``training.instruments``,
            ``training.target_instrument`` and ``inference.*`` settings.
        device: torch device used by ``demix_track``.
        verbose: if False, wrap the track list in a tqdm progress bar.
    """
    start_time = time.time()
    model.eval()

    all_mixtures_path = glob.glob(args.input_folder + '/*.wav')
    total_tracks = len(all_mixtures_path)
    print('Total tracks found: {}'.format(total_tracks))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    # makedirs: handles nested output paths and pre-existing dirs,
    # where the old os.mkdir would raise.
    os.makedirs(args.store_dir, exist_ok=True)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path)

    first_chunk_time = None
    for track_number, path in enumerate(all_mixtures_path, 1):
        print(f"\nProcessing track {track_number}/{total_tracks}: {os.path.basename(path)}")
        mix, sr = sf.read(path)
        # Keep the untouched mix for the instrumental subtraction below;
        # avoids re-reading the file from disk a second time.
        original_mix = mix.copy()
        original_mono = len(mix.shape) == 1
        if original_mono:
            # The model expects stereo input; duplicate the mono channel.
            mix = np.stack([mix, mix], axis=-1)
        mixture = torch.tensor(mix.T, dtype=torch.float32)

        if first_chunk_time is not None:
            # Estimate remaining time from the measured first-chunk latency.
            step = config.inference.chunk_size // config.inference.num_overlap
            num_chunks = (mixture.shape[1] + step - 1) // step  # ceil division
            estimated_total_time = first_chunk_time * num_chunks
            print(f"Estimated total processing time for this track: {estimated_total_time:.2f} seconds")
            sys.stdout.write(f"Estimated time remaining: {estimated_total_time:.2f} seconds\r")
            sys.stdout.flush()

        res, first_chunk_time = demix_track(config, model, mixture, device, first_chunk_time)

        base_name = os.path.basename(path)[:-4]
        for instr in instruments:
            estimate = res[instr].T
            if original_mono:
                # Collapse the duplicated stereo back to the original mono.
                estimate = estimate[:, 0]
            stem_path = "{}/{}_{}.wav".format(args.store_dir, base_name, instr)
            sf.write(stem_path, estimate, sr, subtype='FLOAT')

        # Instrumental = original mixture minus the first instrument's estimate.
        vocals_output = res[instruments[0]].T
        if original_mono:
            vocals_output = vocals_output[:, 0]
        # Trim to the common length in case the model output is padded
        # (no-op when lengths already match — TODO confirm demix_track never pads).
        n = min(len(original_mix), len(vocals_output))
        instrumental = original_mix[:n] - vocals_output[:n]
        instrumental_path = "{}/{}_instrumental.wav".format(args.store_dir, base_name)
        sf.write(instrumental_path, instrumental, sr, subtype='FLOAT')
        # NOTE(review): removed the per-track time.sleep(1) — it only slowed
        # batch processing and had no functional purpose.

    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def proc_folder(args):
    """Parse CLI arguments, build the model, and run folder inference.

    Args:
        args: list of CLI tokens to parse, or None to read ``sys.argv``.

    Side effects: loads the YAML config and optional checkpoint from disk,
    moves the model to GPU(s) when available, then delegates to
    :func:`run_folder`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mel_band_roformer')
    parser.add_argument("--config_path", type=str, help="path to config yaml file")
    parser.add_argument("--model_path", type=str, default='', help="Location of the model")
    parser.add_argument("--input_folder", type=str, help="folder with songs to process")
    parser.add_argument("--store_dir", default="", type=str, help="path to store model outputs")
    # nargs='+' yields a list when given on the CLI; the default stays a
    # bare int — both shapes are handled below.
    parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids')
    args = parser.parse_args() if args is None else parser.parse_args(args)

    torch.backends.cudnn.benchmark = True

    # NOTE(review): yaml.load with FullLoader can construct arbitrary Python
    # objects — only point --config_path at trusted config files.
    with open(args.config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = get_model_from_config(args.model_type, config)
    if args.model_path != '':
        print('Using model: {}'.format(args.model_path))
        # Load on CPU first; the model is moved to the target device below.
        model.load_state_dict(
            torch.load(args.model_path, map_location=torch.device('cpu'))
        )

    if torch.cuda.is_available():
        device_ids = args.device_ids
        if isinstance(device_ids, int):
            # Single GPU id (the default case).
            device = torch.device(f'cuda:{device_ids}')
            model = model.to(device)
        else:
            # List of ids: wrap in DataParallel, anchored on the first GPU.
            device = torch.device(f'cuda:{device_ids[0]}')
            model = nn.DataParallel(model, device_ids=device_ids).to(device)
    else:
        device = 'cpu'
        print('CUDA is not available. Run inference on CPU. It will be very slow...')
        model = model.to(device)

    run_folder(model, args, config, device, verbose=False)
# Script entry point: None makes proc_folder parse sys.argv itself.
if __name__ == "__main__":
    proc_folder(None)