forked from jshermeyer/VDSR4Geo
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathCreate_SR_NoGEO.py
58 lines (48 loc) · 1.83 KB
/
Create_SR_NoGEO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import data
import predict
import numpy as np
import tensorflow as tf
from scipy import misc
from skimage import color
import os
import sys
import gdal
import glob
from tqdm import tqdm
# Usage: python3 Create_SR_NoGEO.py "input/data/" "/output/data/" 2
def _dataset_name(input_dir):
    """Return the last path component of *input_dir*, ignoring trailing separators."""
    return os.path.basename(os.path.normpath(input_dir))


def _write_bands(driver, path, bands):
    """Write a (bands, rows, cols) array to *path* as an 8-bit (GDT_Byte) raster.

    Each slice along the first axis becomes one raster band. No projection or
    geotransform is set on the output.

    Raises:
        OSError: if GDAL cannot create the output file.
    """
    n_bands, n_rows, n_cols = bands.shape
    # GDAL's Create() takes (xsize=cols, ysize=rows, band count).
    raster = driver.Create(path, n_cols, n_rows, n_bands, gdal.GDT_Byte)
    if raster is None:
        # Unless gdal.UseExceptions() is active, Create() reports failure by returning None.
        raise OSError("GDAL could not create output raster: %s" % path)
    for band_index, band in enumerate(bands, 1):  # GDAL band indices are 1-based
        raster.GetRasterBand(band_index).WriteArray(band)
    # Dropping the last reference makes GDAL flush and close the file.
    raster = None


def SR_it(input_dir, output_dir, scaling_factor):
    """Apply the model from ``predict.load_model`` to every image in a directory.

    Each input image is passed through ``predict.predict``. The result is written
    to ``output_dir`` under the same file name as a GTiff with one 8-bit band per
    output channel. No projection or geotransform is written.

    Args:
        input_dir: Directory of input images. Its last path component is passed
            to ``data.SR_Run`` as the set name; its listing supplies the output
            file names.
        output_dir: Destination directory. It is created, including any missing
            parents, if absent. A trailing separator is optional.
        scaling_factor: Integer scale passed to ``data.SR_Run`` and used as the
            ``border`` argument of ``predict.predict``.
    """
    set_name = _dataset_name(input_dir)
    # NOTE(review): output names come from os.listdir() and are paired
    # positionally with dataset.images. This assumes data.SR_Run enumerates the
    # same files in the same order; confirm. os.listdir order is arbitrary and
    # also includes non-image files, and zip() silently stops at the shorter list.
    file_names = os.listdir(input_dir)
    os.makedirs(output_dir, exist_ok=True)

    with tf.Session() as session:
        network = predict.load_model(session)
        driver = gdal.GetDriverByName("GTiff")
        dataset = data.SR_Run(set_name, scaling_factors=[scaling_factor])
        for image, file_name in tqdm(zip(dataset.images, file_names)):
            # predict() takes and returns a batch; we send one image and keep result 0.
            prediction = predict.predict(
                [image], session, network, targets=None, border=scaling_factor
            )[0]
            # (rows, cols, channels) -> (channels, rows, cols): one slice per GDAL band.
            bands = np.moveaxis(prediction, -1, 0)
            _write_bands(driver, os.path.join(output_dir, file_name), bands)
if __name__ == "__main__":
    # Needs at least INPUT_DIR, OUTPUT_DIR and SCALING_FACTOR; extra arguments are ignored.
    if len(sys.argv) < 4:
        sys.exit("usage: %s INPUT_DIR OUTPUT_DIR SCALING_FACTOR" % sys.argv[0])
    try:
        scale = int(sys.argv[3])
    except ValueError:
        sys.exit("SCALING_FACTOR must be an integer, got %r" % sys.argv[3])
    SR_it(sys.argv[1], sys.argv[2], scale)