diff --git a/bert_gen.py b/bert_gen.py index 25cd7d97b..9e02ac00e 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -14,6 +14,8 @@ def process_line(line): if torch.cuda.is_available(): gpu_id = rank % torch.cuda.device_count() device = torch.device(f"cuda:{gpu_id}") + else: + device = torch.device("cpu") wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|") phone = phones.split(" ") tone = [int(i) for i in tone.split(" ")]