-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMain_SampleSplit.py
37 lines (31 loc) · 1.36 KB
/
Main_SampleSplit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import pickle
import random
import numpy as np
from DataUtils import LoadLIBSDataforNetworkTraining_StandardSplit, DownLoadDatasets
from Options import *
def main():
    """Generate and save train/validation/test sample splits for every trial.

    Steps, in order:

    1. Parse the command-line options.
    2. Make sure the output directory ``<cwd>/<args.DataType>/SampleSplits/``
       exists.
    3. Trigger dataset acquisition through ``DownLoadDatasets``.
    4. For each trial index ``i`` in ``range(args.TrialNum)``: seed the
       ``numpy`` and ``random`` generators with ``i``, obtain the three
       sample-name splits from ``LoadLIBSDataforNetworkTraining_StandardSplit``
       and pickle each of them into the output directory.

    The seed is 0-based (``i``) while the output file names are 1-based
    (``Seed<i+1>_...``).
    """
    # Set up the environment.
    # NOTE(review): `parser` is not defined in this file; presumably it is
    # provided by `from Options import *` -- confirm.
    args = parser.parse_args()
    result_dir = os.getcwd() + '/%s/SampleSplits/' % (args.DataType)
    # exist_ok=True replaces the former exists()-then-makedirs() pair, which
    # could raise if the directory appeared between the check and the create
    # (e.g. several runs launched at the same time).
    os.makedirs(result_dir, exist_ok=True)

    # Download SuperCam or ChemCam datasets if they do not exist yet
    # (behaviour of the helper is defined in DataUtils, not visible here).
    downloader = DownLoadDatasets(args)
    downloader.Parallel_AcquireData()

    for i in range(args.TrialNum):
        # Seed both generators so each trial's split is reproducible.
        np.random.seed(i)
        random.seed(i)

        # Get the training, validation and testing splits of the original data.
        TrSampleName, ValSampleName, TeSampleName = (
            LoadLIBSDataforNetworkTraining_StandardSplit(args)
        )

        # Save the sample splits; the write order (Tr, Te, Val) is kept as
        # in the original implementation.
        outputs = (
            ('Seed%d_TrSampleName.pickle', TrSampleName),
            ('Seed%d_TeSampleName.pickle', TeSampleName),
            ('Seed%d_ValSampleName.pickle', ValSampleName),
        )
        for name_template, sample_names in outputs:
            with open(result_dir + name_template % (i + 1), 'wb') as file:
                pickle.dump(sample_names, file, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    # Only generate the splits when run as a script, not when imported.
    main()