forked from svc-develop-team/so-vits-svc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig_editor.py
101 lines (79 loc) · 2.61 KB
/
config_editor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import json
import os.path
import sys
from typing import Any
from enum import Enum
from ruamel.yaml import YAML
from loguru import logger
class ConfigType(Enum):
    """Which configuration file ``edit_config`` should modify.

    The member value is the string accepted by the ``-t`` command-line option.
    """

    index = "index"  # config.json (config_template.json when editing templates)
    tiny = "tiny"  # config_tiny.json (config_tiny_template.json for templates)
    diffusion = "diffusion"  # diffusion.yaml (diffusion_template.yaml for templates)
# String spellings accepted for boolean values passed on the command line.
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "on", "1"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "off", "0", ""})


def _to_number(value: Any) -> Any:
    """Convert *value* to ``int``, falling back to ``float`` for non-integer literals.

    Values that are already ``int``/``float`` (but not ``bool``) are returned
    unchanged. Raises ``ValueError`` when *value* is not numeric at all.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return value
    text = str(value).strip()
    try:
        return int(text)
    except ValueError:
        # "1.5", "1e-4", "inf" ... are not int literals but are valid floats.
        return float(text)


def _to_bool(value: Any) -> bool:
    """Interpret *value* as a boolean, parsing strings instead of testing emptiness.

    ``bool("false")`` is ``True`` in Python, so CLI strings are matched against
    known spellings. Non-string values use ordinary truthiness. Raises
    ``ValueError`` for unrecognized strings.
    """
    if not isinstance(value, str):
        return bool(value)
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ValueError(f"Cannot interpret {value!r} as a boolean.")


def edit_config(
    config_type: ConfigType,
    path: str,
    value: Any,
    is_template: bool = False,
    is_num: bool = False,
    is_bool: bool = False
):
    """Set one value in a config file and write the file back in place.

    The file edited is ``configs/<name><ext>``, or
    ``configs_template/<name>_template<ext>`` when ``is_template`` is true, where
    ``<name><ext>`` is ``config.json``, ``config_tiny.json`` or ``diffusion.yaml``
    depending on ``config_type``.

    Args:
        config_type: Which config file to edit.
        path: Dot-separated key path. Every segment except the last must
            already exist; the last segment is created if missing.
        value: New value to store.
        is_template: Edit the template copy instead of the live config.
        is_num: Convert ``value`` to ``int`` (or ``float`` if it is not an
            integer literal) before storing.
        is_bool: Convert ``value`` to ``bool``; strings like ``"false"``,
            ``"0"``, ``"no"`` become ``False``.

    Raises:
        ValueError: ``path`` is empty, or ``value`` cannot be converted.
        FileNotFoundError: The selected config file does not exist.
        NameError: An intermediate segment of ``path`` is missing
            (exception type kept for backward compatibility).
    """
    if not path:
        raise ValueError("A dot-separated key path is required.")

    name, extension = {
        ConfigType.index: ("config", ".json"),
        ConfigType.tiny: ("config_tiny", ".json"),
        ConfigType.diffusion: ("diffusion", ".yaml"),
    }[config_type]
    folder = "configs_template" if is_template else "configs"
    suffix = "_template" if is_template else ""
    config_full_path = os.path.join(folder, name + suffix + extension)
    logger.info(f"Config: {config_full_path}")
    if not os.path.exists(config_full_path):
        raise FileNotFoundError(config_full_path)

    # JSON files use the stdlib module; YAML uses a round-trip ruamel instance
    # configured to keep quoting style.
    yaml = None
    if extension == ".yaml":
        yaml = YAML()
        yaml.preserve_quotes = True
        yaml.indent(offset=2)

    with open(config_full_path, "r+", encoding="utf-8") as f:
        config = json.load(f) if yaml is None else yaml.load(f)

        keys = path.split(".")
        current_dict = config
        for key in keys[:-1]:
            if key not in current_dict:
                raise NameError(f"{key} is not in config.")
            current_dict = current_dict[key]

        if is_num:
            value = _to_number(value)
            if isinstance(value, float):
                logger.info(f"{value} is a float")
            else:
                logger.info(f"{value} is a int")
        if is_bool:
            value = _to_bool(value)
        current_dict[keys[-1]] = value

        # "r+" does not truncate: rewrite from the start, then cut off any
        # bytes left over from a longer previous version of the file.
        f.seek(0)
        if yaml is None:
            json.dump(config, f, indent=2)
        else:
            yaml.dump(config, f)
        f.truncate()
    logger.info(f"Done: {path} -> {value}")
def get_value_from_args(key: str, default: Any = None) -> Any:
    """Return the command-line token that follows the first occurrence of *key*.

    Falls back to *default* when *key* is absent from ``sys.argv`` or is the
    last token (so nothing follows it).
    """
    argv = sys.argv
    try:
        position = argv.index(key)
    except ValueError:
        return default
    following = position + 1
    return argv[following] if following < len(argv) else default
if __name__ == '__main__':
    # Usage: -t index|tiny|diffusion (default index), -p <dotted.key>, -v <value>,
    # optional flags: -T/--template, -N/--num, -B/--bool
    argv = sys.argv
    selected_type = ConfigType(get_value_from_args("-t", "index"))
    key_path = get_value_from_args("-p")
    new_value = get_value_from_args("-v")
    edit_config(
        config_type=selected_type,
        path=key_path,
        value=new_value,
        is_template=any(flag in argv for flag in ("--template", "-T")),
        is_num=any(flag in argv for flag in ("--num", "-N")),
        is_bool=any(flag in argv for flag in ("--bool", "-B")),
    )