-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_test_data.py
65 lines (57 loc) · 2.17 KB
/
generate_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import _pickle as pickle
import h5py
import numpy as np
import argparse
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from tqdm import trange
# Build a multi-label test set by summing the signals of several modulation
# classes and L2-normalising the result.
#
# Usage: python generate_test_data.py --num_sample N --num_type K
#
# For every set of K modulation classes and every SNR in ``SNR_range`` the
# first N samples of each chosen class are added element-wise.  The output is
# a pickle holding ``(feature_dict, logit_dict)``, both keyed by SNR.

data_dir = "./data/RML2016.10a_dict.pkl"
SNR_range = range(-12, 12, 2)
# Per-sample array shape used when allocating and reshaping mixtures.
SAMPLE_SHAPE = (2, 128)


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with integer attributes ``num_sample`` and ``num_type``.
    """
    parser = argparse.ArgumentParser(
        description="Generate mixed-modulation test data."
    )
    # type=int / required=True turn a missing or malformed value into a clear
    # usage error instead of a TypeError from int(None) further down.
    parser.add_argument("--num_sample", type=int, required=True,
                        help="number of samples per (class subset, SNR) pair")
    parser.add_argument("--num_type", type=int, required=True,
                        help="number of modulation classes mixed per sample")
    return parser.parse_args(argv)


def subset_indices(state, num_classes):
    """Return the ascending positions of the bits set in ``state``.

    Only the lowest ``num_classes`` bits are inspected.  Bit ``j`` is tested
    in place (the mask is not shifted before the test), so bit 0 maps to
    class 0 and class ``num_classes - 1`` is reachable.
    """
    return [j for j in range(num_classes) if (state >> j) & 1]


def multi_hot_labels(indices, num_classes, repeats):
    """Return a ``(repeats, num_classes)`` integer array with ones at ``indices``.

    Every row is identical: 1 in each selected class column, 0 elsewhere.
    """
    # Built-in ``int`` is what the removed ``np.int`` alias referred to.
    label = np.zeros((1, num_classes), dtype=int)
    label[0, np.asarray(indices, dtype=int)] = 1
    return np.repeat(label, repeats, axis=0)


def mix_signals(data, mods, snr, sample_per_pair):
    """Sum the first ``sample_per_pair`` samples of each modulation in ``mods``.

    Args:
        data: Mapping from ``(mod, snr)`` to an array of samples.
            NOTE(review): each sample is assumed to have shape
            ``SAMPLE_SHAPE`` -- confirm against the dataset.
        mods: Modulation names to add together.
        snr: SNR value selecting which entries of ``data`` are used.
        sample_per_pair: Number of leading samples taken from each entry.

    Returns:
        Float array of shape ``(sample_per_pair,) + SAMPLE_SHAPE``.

    Raises:
        ValueError: If an entry holds fewer than ``sample_per_pair`` samples.
    """
    mixed = np.zeros((sample_per_pair,) + SAMPLE_SHAPE)
    for mod in mods:
        samples = data[(mod, snr)][:sample_per_pair]
        # Check explicitly: a one-sample entry would otherwise be broadcast
        # silently across every row of ``mixed``.
        if len(samples) < sample_per_pair:
            raise ValueError(
                f"data[({mod!r}, {snr})] has only {len(samples)} samples, "
                f"{sample_per_pair} requested"
            )
        mixed += samples
    return mixed


def generate(data, num_type, sample_per_pair, snr_range=SNR_range):
    """Build normalised mixtures and their multi-hot labels for every SNR.

    Args:
        data: Mapping from ``(mod, snr)`` to an array of samples.
        num_type: Number of modulation classes combined in each mixture.
        sample_per_pair: Number of mixtures produced per class subset and SNR.
        snr_range: Iterable of SNR values to generate data for.

    Returns:
        ``(feature_dict, logit_dict)``: two dicts keyed by SNR.  Features have
        shape ``(rows, 128, 2)`` after the final transpose; labels have shape
        ``(rows, num_classes)``.

    Raises:
        ValueError: If ``sample_per_pair`` is not positive or ``num_type`` is
            outside ``0..num_classes``.
    """
    if sample_per_pair < 1:
        raise ValueError("num_sample must be a positive integer")

    # Sort the class names: iteration order of a set of strings can differ
    # between interpreter runs, which would shuffle the label columns.
    # NOTE(review): confirm the consumer of these labels uses the same order.
    mod_list = sorted({key[0] for key in data})
    num_classes = len(mod_list)
    if not 0 <= num_type <= num_classes:
        raise ValueError(
            f"num_type must be between 0 and {num_classes}, got {num_type}"
        )

    feature_dict = {snr: [] for snr in snr_range}
    logit_dict = {snr: [] for snr in snr_range}

    # Walk every bitmask over the classes and keep those with num_type bits.
    for state in trange(1 << num_classes):
        idx_list = subset_indices(state, num_classes)
        if len(idx_list) != num_type:
            continue
        label = multi_hot_labels(idx_list, num_classes, sample_per_pair)
        mods = [mod_list[idx] for idx in idx_list]
        for snr in snr_range:
            logit_dict[snr].append(label)
            feature_dict[snr].append(
                mix_signals(data, mods, snr, sample_per_pair)
            )

    for snr in snr_range:
        features = np.vstack(feature_dict[snr])
        logit_dict[snr] = np.vstack(logit_dict[snr])
        # Flatten each sample so the L2 norm is taken over the whole sample,
        # then restore the per-sample shape and swap the last two axes.
        flat = normalize(features.reshape((features.shape[0], -1)), norm="l2")
        features = flat.reshape((flat.shape[0],) + SAMPLE_SHAPE)
        feature_dict[snr] = features.transpose([0, 2, 1])

    return feature_dict, logit_dict


def main():
    """Read the dataset, generate the test set and write it to ``./data``."""
    # Parse first so ``--help`` and usage errors do not need the dataset.
    args = parse_args()
    sample_per_pair = args.num_sample
    num_type = args.num_type

    # SECURITY: unpickling can execute arbitrary code; only load a dataset
    # file obtained from a trusted source.
    with open(data_dir, "rb") as f:
        data = pickle.load(f, encoding="iso-8859-1")

    feature_dict, logit_dict = generate(data, num_type, sample_per_pair)

    with open(f"./data/test_data_{num_type}_{sample_per_pair}.pkl", "wb") as f:
        pickle.dump((feature_dict, logit_dict), f, protocol=4)


if __name__ == "__main__":
    main()
# with h5py.File(f"./data/test_data_{num_type}_{sample_per_pair}.h5", "w") as f:
# f.create_dataset("logit_dict", data=logit_dict)
# f.create_dataset('feature_dict', data=feature_dict)