-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbcgrun.py
126 lines (101 loc) · 3.88 KB
/
bcgrun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
from scipy.io import savemat
import h5py
from unet import bcgunet
import argparse
import platform
import os
import time
import glob
import os
def remove_common_substrings(strings):
    """Shorten paths by dropping the path components that all of them share.

    Each path is split on ``os.sep``. A component of the shortest path counts
    as "common" if it appears *anywhere* in every split path (its position is
    not compared). Every occurrence of a common component is then removed from
    each path, and the remaining components are joined with ``"_"``.

    Example (POSIX): ``["data/s1/a.mat", "data/s2/a.mat"] -> ["s1", "s2"]``.

    Args:
        strings: sequence of path strings.

    Returns:
        One shortened name per input, in input order. An empty input yields
        ``[]``; a single input yields ``[""]`` because all of its components
        are trivially common.
    """
    split_strings = [s.split(os.sep) for s in strings]
    if not split_strings:
        # min() below raises ValueError on an empty sequence.
        return []

    shortest = min(split_strings, key=len)
    # Membership test against every path's component list; a set is used only
    # for O(1) lookups in the filtering step below.
    common = {
        part for part in shortest
        if all(part in parts for parts in split_strings)
    }

    # Join the surviving components with "_" (not os.sep) to form flat names.
    return [
        "_".join(part for part in parts if part not in common)
        for parts in split_strings
    ]
def _expand_inputs(inputs):
    """Expand command-line inputs into a flat list of file paths.

    Each item is handled independently: a directory contributes the ``*.mat``
    files directly inside it, an item containing ``*`` is treated as a glob
    pattern, and anything else is kept verbatim. Glob results are sorted
    because ``glob.glob`` returns them in arbitrary order.
    """
    files = []
    for item in inputs:
        if os.path.isdir(item):
            files.extend(sorted(glob.glob(os.path.join(item, "*.mat"))))
        elif "*" in item:
            files.extend(sorted(glob.glob(item)))
        else:
            files.append(item)
    return files


def _output_name(name):
    """Return the output file name derived from *name*.

    A trailing ``.mat`` extension (any case) is replaced by ``_unet.mat``;
    otherwise ``_unet.mat`` is appended, so the result always differs from
    *name* itself.
    """
    root, ext = os.path.splitext(name)
    if ext.lower() != ".mat":
        root = name
    return root + "_unet.mat"


def main():
    """Command-line entry point: run BCGunet on one or more input files.

    Each input is opened with h5py and must contain the datasets ``"ECG"`` and
    ``"EEG_before_bcg"``. The array returned by ``bcgunet.run`` is saved under
    the key ``"EEG_clean"`` in a compressed ``<name>_unet.mat`` file. Output
    goes to ``--output`` (created if missing) or, by default, to the folder of
    each input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", nargs="+", type=str, help="Input mat file")
    parser.add_argument("-o", "--output", default=None, help="Output Path")
    parser.add_argument(
        "-i", "--iteration", default=5000, type=int, help="Number of iterations"
    )
    parser.add_argument(
        "-l", "--learning-rate", default=1e-3, type=float, help="Learning rate"
    )
    parser.add_argument(
        "-w", "--window-size", default=2, type=int, help="Window size (seconds)"
    )
    parser.add_argument(
        "-noc",
        "--no-one-cycle",
        action="store_true",
        help="Disable one cycle scheduler",
    )
    args = parser.parse_args()
    print(f"Settings: {args}")
    print("Starting BCGunet.....")

    # Directories and "*" patterns are expanded here (needed e.g. on Windows,
    # where the shell does not expand wildcards).
    ffs = _expand_inputs(args.input)
    if not ffs:
        parser.error("no input .mat files found")

    # Shorter names built from the components that differ between inputs;
    # used for output names when several files are processed.
    short_ffs = remove_common_substrings(ffs)
    print(f"Total files: {len(ffs)}")

    if args.output is not None:
        os.makedirs(args.output, exist_ok=True)

    for ii, (f, short_f) in enumerate(zip(ffs, short_ffs), start=1):
        # Default output directory is the folder containing the input file.
        f_output_dir = args.output
        if f_output_dir is None:
            f_output_dir = os.path.dirname(os.path.abspath(f))
        # With several inputs, the shortened name keeps e.g. a/x.mat and
        # b/x.mat from producing the same output name.
        base_name = short_f if len(ffs) > 1 else os.path.basename(f)
        ff_output = os.path.join(f_output_dir, _output_name(base_name))

        print(f"{ii}: Processing {f}.....")
        t = time.time()
        # Copy the arrays out and close the HDF5 file right away.
        with h5py.File(f, "r") as mat:
            ECG = np.array(mat["ECG"]).flatten()
            # NOTE(review): transposed presumably because h5py exposes MATLAB
            # v7.3 arrays in reversed axis order — confirm expected layout.
            EEG = np.array(mat["EEG_before_bcg"]).T

        # Noted signature of bcgunet.run:
        # (input_eeg, input_ecg, sfreq=5000, iter_num=5000, winsize_sec=2,
        #  lr=1e-3, onecycle=True)
        # sfreq is not passed, so its default applies — TODO confirm it matches
        # the sampling rate of the input data.
        EEG_unet = bcgunet.run(
            EEG,
            ECG,
            iter_num=args.iteration,
            winsize_sec=args.window_size,
            lr=args.learning_rate,
            onecycle=not args.no_one_cycle,
        )
        savemat(ff_output, {"EEG_clean": EEG_unet}, do_compression=True)
        print("Writing output:", ff_output)
        print("Processing time: %d seconds" % (time.time() - t))
if __name__ == "__main__":
    main()
    # On Windows, wait for a keypress before exiting — presumably so a console
    # window opened just for this script stays visible until it is read.
    running_on_windows = platform.system() == "Windows"
    if running_on_windows:
        os.system("pause")