forked from acrarshin/RPNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_datasets.py
75 lines (66 loc) · 2.86 KB
/
evaluate_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import wfdb as wf
import numpy as np
from glob import glob
from tqdm import tqdm
from matplotlib import pyplot as plt
import torch.nn as nn
import pandas as pd
import scipy.signal
import matplotlib.pyplot as plt
import pickle
from sklearn.preprocessing import StandardScaler
import argparse
import torch
import torch.nn as n
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from py_ecg.ecgdetectors import Detectors
import py_ecg._tester_utils
from utils import score,load_model_CNN,obtain_data
def _flatten_predictions(y_pred):
    """Flatten per-batch model outputs into one NumPy array, one row per record.

    ``y_pred`` is whatever ``load_model_CNN`` returns. It is iterated as a
    sequence of batches, each batch is iterated record by record, and every
    record is converted with ``.cpu().numpy()``. The result is the stacked array
    in loader order.
    """
    return np.asarray([record.cpu().numpy() for batch in y_pred for record in batch])


def _find_valleys(pred_array, n_samples, distance=45, height=-0.2, prominence=0.035):
    """Resample each prediction row to ``n_samples`` and return its valley indices.

    Valleys are located as peaks of the negated signal, so ``height=-0.2`` keeps
    only minima whose predicted value is <= 0.2. ``distance`` is measured in
    resampled samples. Returns one index array per row.
    """
    valleys = []
    for row in pred_array:
        resampled = scipy.signal.resample(row, n_samples)
        valleys.append(
            scipy.signal.find_peaks(
                -resampled, distance=distance, height=height, prominence=prominence
            )[0]
        )
    return valleys


def main(args, *, batch_size=64, fs=360, thr=0.075, n_samples=3600):
    """Run the saved CNN over a dataset's ECG windows, detect R-peaks and score them.

    Args:
        args: parsed CLI namespace (see ``argparse_func``). It is passed whole to
            ``obtain_data``, and ``args.model_path`` / ``args.device`` are read here.
        batch_size: number of records per inference batch.
        fs: sampling-rate value forwarded to ``score`` (presumably Hz; confirm
            against ``utils.score``).
        thr: tolerance forwarded to ``score`` (presumably seconds; confirm
            against ``utils.score``).
        n_samples: length each prediction is resampled to before peak picking.
            With the default ``fs`` this is 10 s worth of samples.

    Returns:
        The ``(rec_acc, all_FP, all_FN, all_TP)`` tuple produced by ``score``.
        It is returned rather than discarded so callers can inspect or report it.
    """
    patient_ecg, windowed_beats = obtain_data(args)

    # (records, samples) -> (records, 1, samples): one input channel per record.
    ecg = torch.from_numpy(patient_ecg).float()
    ecg = ecg.reshape(ecg.shape[0], 1, ecg.shape[1])
    # DataLoader's shuffle defaults to False, so predictions stay aligned with
    # windowed_beats for scoring.
    loader = DataLoader(TensorDataset(ecg), batch_size=batch_size)

    y_pred = load_model_CNN(args.model_path, loader, args.device)
    peak_locs = _find_valleys(_flatten_predictions(y_pred), n_samples)

    rec_acc, all_FP, all_FN, all_TP = score(windowed_beats, peak_locs, fs, thr)
    return rec_acc, all_FP, all_FN, all_TP
def argparse_func():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``dataset``, ``datapath``, ``db`` (default 99),
        ``evaluate_nstdb`` (flag), ``device`` (default 'cpu') and ``model_path``.
    """
    # (flag, add_argument keyword arguments), registered in --help order.
    option_specs = (
        ('--dataset', {'type': str, 'help': 'Choose one out of three datasets'}),
        ('--datapath', {'type': str, 'help': 'Path to the dataset'}),
        ('--db', {'default': 99, 'type': int,
                  'help': 'Decibel level to consider for NSTDB(Not required here)'}),
        ('--evaluate_nstdb', {'action': 'store_true',
                              'help': 'To be used to evaluate nstdb decibel-wise'}),
        ('--device', {'type': str, 'default': 'cpu', 'help': 'cuda / cpu'}),
        ('--model_path', {'type': str, 'help': 'Path to the model'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI options, then run the evaluation with them.
    args = argparse_func()
    main(args)