Merge pull request #27 from FREVA-CLINT/dev-jm
Option of multiple input files in JSON format
eplesiat authored Jul 10, 2024
2 parents f37081c + 381d976 commit 2dbcd01
Showing 10 changed files with 115 additions and 43 deletions.
7 changes: 6 additions & 1 deletion climatereconstructionai/config.py
@@ -137,6 +137,10 @@ def global_args(parser, arg_file=None, prog_func=None):

     assert len(time_steps) == 2

+    if all('.json' in data_name for data_name in data_names) and (lstm_steps or channel_steps):
+        print('Warning: Each input file defined in your ".json" files will be considered individually.'
+              ' This means the defined timesteps will not go beyond each file\'s boundaries.')
+
     return args
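This warning fires only when every entry in --data-names is a JSON file and temporal context is requested (lstm_steps or channel_steps), because the loader clamps each time window at the boundaries of the individual file instead of crossing into a neighbouring file (see netcdfloader.py below). A minimal standalone sketch of that clamping, assuming time_steps = (2, 2) (illustrative values, not taken from the commit):

import numpy as np

# a sample at local index 0 reuses the first frame rather than
# reaching into the previous file of the JSON list
img_indices = np.arange(0 - 2, 0 + 2 + 1)   # [-2, -1, 0, 1, 2]
img_indices[img_indices < 0] = 0            # [ 0,  0, 0, 1, 2]
print(img_indices)                          # [0 0 0 1 2]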


@@ -147,7 +151,8 @@ def set_common_args():
     arg_parser.add_argument('--mask-dir', type=str, default='masks/', help="Directory containing the mask datasets")
     arg_parser.add_argument('--log-dir', type=str, default='logs/', help="Directory where the log files will be stored")
     arg_parser.add_argument('--data-names', type=str_list, default='train.nc',
-                            help="Comma separated list of netCDF files (climate dataset) for training/infilling")
+                            help="Comma separated list of netCDF files (climate dataset) or JSON files"
+                                 " containing a list of paths to netCDF files for training/infilling")
     arg_parser.add_argument('--mask-names', type=str_list, default=None,
                             help="Comma separated list of netCDF files (mask dataset). "
                                  "If None, it extracts the masks from the climate dataset")
2 changes: 1 addition & 1 deletion climatereconstructionai/evaluate.py
@@ -38,7 +38,7 @@ def evaluate(arg_file=None, prog_func=None):
     dataset_val = NetCDFLoader(cfg.data_root_dir, cfg.data_names, cfg.mask_dir, cfg.mask_names, "infill",
                                cfg.data_types, cfg.time_steps, data_stats)

-    n_samples = dataset_val.img_length
+    n_samples = len(dataset_val)

     if data_stats is None:
         if cfg.normalize_data:
21 changes: 13 additions & 8 deletions climatereconstructionai/model/conv_configs.py
@@ -21,8 +21,9 @@ def init_enc_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st
         conv_config['kernel'] = (3, 3)
         conv_config['out_channels'] = conv_factor // (2 ** (enc_dec_layers - i - 1))
         conv_config['skip_channels'] = 0
-        conv_config['img_size'] = [size // (2 ** i) for size in img_size]
-        conv_config['rec_size'] = [size // 2 for size in conv_config['img_size']]
+        conv_config['img_size'] = [size // (2 ** i) if size % (2 ** i) == 0 else
+                                   size // (2 ** i) + 1 for size in img_size]
+        conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']]

         conv_configs.append(conv_config)
     for i in range(pool_layers):
@@ -32,8 +33,9 @@ def init_enc_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st
         conv_config['kernel'] = (3, 3)
         conv_config['out_channels'] = conv_factor
         conv_config['skip_channels'] = 0
-        conv_config['img_size'] = [size // (2 ** (enc_dec_layers + i)) for size in img_size]
-        conv_config['rec_size'] = [size // 2 for size in conv_config['img_size']]
+        conv_config['img_size'] = [size // (2 ** (enc_dec_layers + i)) if size % (2 ** (enc_dec_layers + i)) == 0
+                                   else size // (2 ** (enc_dec_layers + i)) + 1 for size in img_size]
+        conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']]
         conv_configs.append(conv_config)

     return conv_configs
@@ -48,8 +50,10 @@ def init_dec_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st
         conv_config['kernel'] = (3, 3)
         conv_config['out_channels'] = conv_factor
         conv_config['skip_channels'] = cfg.skip_layers * conv_factor
-        conv_config['img_size'] = [size // (2 ** (enc_dec_layers + pool_layers - i - 1)) for size in img_size]
-        conv_config['rec_size'] = [size // 2 for size in conv_config['img_size']]
+        conv_config['img_size'] = [size // (2 ** (enc_dec_layers + pool_layers - i - 1))
+                                   if size % (2 ** (enc_dec_layers + pool_layers - i - 1)) == 0
+                                   else size // (2 ** (enc_dec_layers + pool_layers - i - 1)) + 1 for size in img_size]
+        conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']]
         conv_configs.append(conv_config)
     for i in range(1, enc_dec_layers + 1):
         conv_config = {}
@@ -63,8 +67,9 @@ def init_dec_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st
         else:
             conv_config['out_channels'] = conv_factor // (2 ** i)
         conv_config['skip_channels'] = cfg.skip_layers * conv_factor // (2 ** i)
-        conv_config['img_size'] = [size // (2 ** (enc_dec_layers - i)) for size in img_size]
-        conv_config['rec_size'] = [size // 2 for size in conv_config['img_size']]
+        conv_config['img_size'] = [size // (2 ** (enc_dec_layers - i)) if size % (2 ** (enc_dec_layers - i)) == 0
+                                   else size // (2 ** (enc_dec_layers - i)) + 1 for size in img_size]
+        conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']]
         conv_configs.append(conv_config)
     return conv_configs
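The repeated conditional in these comprehensions is plain ceiling division: layer sizes now round up instead of truncating whenever an axis is not divisible by the pooling factor. A standalone sketch of the equivalence (not part of the commit):

import math

def ceil_div(size, divisor):
    # same value as: size // divisor if size % divisor == 0 else size // divisor + 1
    return -(-size // divisor)

for size in (7, 8, 9):
    assert ceil_div(size, 2) == math.ceil(size / 2)
    assert ceil_div(size, 2) == (size // 2 if size % 2 == 0 else size // 2 + 1)
print(ceil_div(7, 4))  # 2: a 7-cell axis downsampled by 2**2 keeps 2 cells rather than 1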

24 changes: 12 additions & 12 deletions climatereconstructionai/utils/evaluation.py
@@ -103,7 +103,7 @@ def infill(model, dataset, eval_path, output_names, data_stats, xr_dss, i_model)
     data_dict = {'image': [], 'mask': [], 'gt': [], 'output': [], 'infilled': []}

     for split in tqdm(range(dataset.__len__())):
-
+        # TODO: implement evaluation for multiple data paths
         data_dict["image"], data_dict["mask"], data_dict["gt"], index = next(dataset)

         if split == 0 and cfg.create_graph:
@@ -140,7 +140,7 @@ def infill(model, dataset, eval_path, output_names, data_stats, xr_dss, i_model)
     return output_names


-def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_model, split, index):
+def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_model, split, index, ds_index=0):

     m_label = "." + str(i_model)
     suffix = m_label + "-" + str(split + 1)
@@ -168,23 +168,23 @@ def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_mod

            output_names[rootname][i_model] += [rootname + suffix + ".nc"]

-            ds = xr_dss[i_data][1].copy()
+            ds = xr_dss[i_data][ds_index][1].copy()

            if cfg.normalize_data and cname != "mask":
                data_dict[cname][:, j, :, :] = renormalize(data_dict[cname][:, j, :, :],
                                                           data_stats["mean"][i_data], data_stats["std"][i_data])

            ds[data_type] = xr.DataArray(data_dict[cname].to(torch.device('cpu')).detach().numpy()[:, j, :, :],
-                                         dims=xr_dss[i_data][2], coords=xr_dss[i_data][3])
-            ds["time"] = xr_dss[i_data][0]["time"].values[index]
+                                         dims=xr_dss[i_data][ds_index][2], coords=xr_dss[i_data][ds_index][3])
+            ds["time"] = xr_dss[i_data][ds_index][0]["time"].values[index]

-            ds = reformat_dataset(xr_dss[i_data][0], ds, data_type)
+            ds = reformat_dataset(xr_dss[i_data][ds_index][0], ds, data_type)

-            for var in xr_dss[i_data][0].keys():
-                if "time" in xr_dss[i_data][0][var].dims:
-                    ds[var] = xr_dss[i_data][0][var].isel(time=index)
+            for var in xr_dss[i_data][ds_index][0].keys():
+                if "time" in xr_dss[i_data][ds_index][0][var].dims:
+                    ds[var] = xr_dss[i_data][ds_index][0][var].isel(time=index)
                else:
-                    ds[var] = xr_dss[i_data][0][var]
+                    ds[var] = xr_dss[i_data][ds_index][0][var]

            ds.attrs["history"] = "Infilled using CRAI (Climate Reconstruction AI: " \
                                  "https://github.com/FREVA-CLINT/climatereconstructionAI)\n" + ds.attrs["history"]
@@ -193,8 +193,8 @@ def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_mod
    for time_step in cfg.plot_results:
        if time_step in index:
            output_name = '{}_{}{}_{}.png'.format(eval_path[j], "combined", m_label, time_step)
-            plot_data(xr_dss[i_data][1].coords,
+            plot_data(xr_dss[i_data][ds_index][1].coords,
                      [data_dict[p][time_step - index[0], j, :, :].squeeze() for p in pnames],
                      ["Original", "Reconstructed"], output_name, data_type,
-                      str(xr_dss[i_data][0]["time"][time_step].values),
+                      str(xr_dss[i_data][ds_index][0]["time"][time_step].values),
                      *cfg.dataset_format["scale"])
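Throughout this file the bookkeeping structure gains one level of nesting: xr_dss[i_data] is now a sequence with one entry per file listed in the JSON input, selected by the new ds_index argument. A hedged sketch of the assumed per-file layout, inferred from the indexing above rather than stated in the commit:

import xarray as xr

# hypothetical entry mirroring the indexing used above:
#   [0] original Dataset (source of "time" and extra variables)
#   [1] template dataset whose coords drive plotting and output
#   [2] dimension names for xr.DataArray(..., dims=...)
#   [3] coordinates for xr.DataArray(..., coords=...)
ds = xr.Dataset({"tas": ("time", [0.0, 1.0])}, coords={"time": [0, 1]})
entry = (ds, ds, ("time",), {"time": ds["time"]})
xr_dss = {0: [entry]}                       # accessed as xr_dss[i_data][ds_index]
full_ds, template_ds, dims, coords = xr_dss[0][0]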
74 changes: 54 additions & 20 deletions climatereconstructionai/utils/netcdfloader.py
@@ -1,3 +1,4 @@
+import json
 import os
 import random

@@ -94,18 +95,32 @@ def load_netcdf(path, data_names, data_types, keep_dss=False):
     if data_names is None:
         return None, None, None
     else:
-        ndata = len(data_names)
-        assert ndata == len(data_types)
+        assert len(data_names) == len(data_types)
+        if all('.nc' in data_name or '.h5' in data_name for data_name in data_names):
+            data_paths = [['{}{}'.format(path, data_names[i])] for i in range(len(data_names))]
+        elif all('.json' in data_name for data_name in data_names):
+            data_paths = []
+            for data_name in data_names:
+                with open('{}/{}'.format(path, data_name)) as json_file:
+                    data_paths.append(json.load(json_file))
+        else:
+            raise ValueError('Unsupported filetype. All data names must uniformly contain ".nc", ".h5" or ".json".')

-        dss, data, lengths, sizes = zip(*[nc_loadchecker('{}{}'.format(path, data_names[i]),
-                                          data_types[i]) for i in range(ndata)])
+        dss, data, lengths, sizes = [], [], [], []
+        for i in range(len(data_paths)):
+            dss_list, data_list, length, size = zip(*[nc_loadchecker(data_paths[i][j], data_types[i])
+                                                      for j in range(len(data_paths[i]))])
+            dss.append(dss_list)
+            data.append(data_list)
+            lengths.append(length)
+            sizes.append(size[0])

-        assert len(set(lengths)) == 1
+        assert len(set(lengths)) == 1

-        if keep_dss:
-            return dss, data, lengths[0], sizes
-        else:
-            return data, lengths[0], sizes
+        if keep_dss:
+            return dss, data, lengths[0], sizes
+        else:
+            return data, lengths[0], sizes


class NetCDFLoader(Dataset):
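A sketch of the dispatch above (mirroring the commit's logic, with the same substring checks but without the loading itself): plain netCDF/HDF5 names are wrapped into one-element lists, while each JSON file expands into its own list of paths, so data_paths always holds one list per data type:

import json

def build_data_paths(path, data_names):
    # sketch of load_netcdf's input handling only
    if all('.nc' in name or '.h5' in name for name in data_names):
        return [['{}{}'.format(path, name)] for name in data_names]
    if all('.json' in name for name in data_names):
        paths = []
        for name in data_names:
            with open('{}/{}'.format(path, name)) as json_file:
                paths.append(json.load(json_file))   # listed paths are used as-is
        return paths
    raise ValueError('All data names must uniformly contain ".nc", ".h5" or ".json".')

print(build_data_paths('data/', ['train.nc']))  # [['data/train.nc']]
# with the JSON files added below: [['data/train/20cr-1ens.nc', 'data/train/20cr-1ens.nc']]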
@@ -146,15 +161,15 @@ def __init__(self, data_root, img_names, mask_root, mask_names, split, data_type

         self.bounds = bnd_normalization(self.img_mean, self.img_std)

-    def load_data(self, ind_data, img_indices, mask_indices):
+    def load_data(self, ind_data, img_indices, ds_index, mask_indices, mask_ds_index):

         if self.mask_data is None:
             # Get masks from images
-            image = np.array(self.img_data[ind_data][mask_indices])
+            image = np.array(self.img_data[ind_data][mask_ds_index][mask_indices])
             mask = torch.from_numpy((1 - (np.isnan(image))).astype(image.dtype))
         else:
-            mask = torch.from_numpy(np.array(self.mask_data[ind_data][mask_indices]))
-            image = np.array(self.img_data[ind_data][img_indices])
+            mask = torch.from_numpy(np.array(self.mask_data[ind_data][mask_ds_index][mask_indices]))
+            image = np.array(self.img_data[ind_data][ds_index][img_indices])
         image = torch.from_numpy(np.nan_to_num(image))

         if cfg.normalize_data:
@@ -163,19 +178,38 @@ def load_data(self, ind_data, img_indices, mask_indices):
         return image, mask

     def get_single_item(self, ind_data, index, shuffle_masks):
+        # get index of dataset
+        ds_index = 0
+        current_index = 0
+        for l in range(len(self.img_length)):
+            if index > current_index + self.img_length[l]:
+                current_index += self.img_length[l]
+                ds_index += 1
+        index -= current_index
+
         # define range of lstm or prev-next steps -> adjust, if out of boundaries
         img_indices = np.array(list(range(index - self.time_steps[0], index + self.time_steps[1] + 1)))
         img_indices[img_indices < 0] = 0
-        img_indices[img_indices > self.img_length - 1] = self.img_length - 1
+        img_indices[img_indices > self.img_length[ds_index] - 1] = self.img_length[ds_index] - 1
         if shuffle_masks:
-            mask_indices = []
-            for j in range(self.n_time_steps):
-                mask_indices.append(self.random.randint(0, self.mask_length - 1))
-            mask_indices = sorted(mask_indices)
+            mask_index = self.random.randint(0, sum(self.mask_length) - 1)
+            mask_ds_index = 0
+            current_index = 0
+            for l in range(len(self.mask_length)):
+                if mask_index > current_index + self.mask_length[l]:
+                    current_index += self.mask_length[l]
+                    mask_ds_index += 1
+            mask_index -= current_index
+
+            # define range of lstm or prev-next steps -> adjust, if out of boundaries
+            mask_indices = np.array(list(range(mask_index - self.time_steps[0], mask_index + self.time_steps[1] + 1)))
+            mask_indices[mask_indices < 0] = 0
+            mask_indices[mask_indices > self.mask_length[mask_ds_index] - 1] = self.mask_length[mask_ds_index] - 1
         else:
             mask_indices = img_indices
+            mask_ds_index = ds_index
         # load data from ranges
-        images, masks = self.load_data(ind_data, img_indices, mask_indices)
+        images, masks = self.load_data(ind_data, img_indices, ds_index, mask_indices, mask_ds_index)

         # stack to correct dimensions
         images = torch.stack([images], dim=1)

# stack to correct dimensions
images = torch.stack([images], dim=1)
@@ -209,4 +243,4 @@ def __getitem__(self, index):
         return torch.cat(masked, dim=1), torch.cat(masks, dim=1), torch.cat(images, dim=1), index

     def __len__(self):
-        return self.img_length
+        return sum(self.img_length)
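Because img_length is now a list of per-file lengths, a global sample index has to be mapped to a (file, local index) pair, as done at the top of get_single_item, and __len__ sums the per-file lengths. A compact standalone sketch of that mapping (illustrative only: np.searchsorted replaces the explicit loop, and boundary samples are assigned to the next file):

import numpy as np

img_length = [5, 3, 4]                    # per-file sample counts
offsets = np.cumsum([0] + img_length)     # [0, 5, 8, 12]

def locate(index):
    ds_index = int(np.searchsorted(offsets, index, side='right')) - 1
    return ds_index, index - int(offsets[ds_index])

assert locate(0) == (0, 0)
assert locate(6) == (1, 1)
assert locate(11) == (2, 3)
assert sum(img_length) == 12              # matches the new __len__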
4 changes: 4 additions & 0 deletions data/train/20cr.json
@@ -0,0 +1,4 @@
+[
+    "data/train/20cr-1ens.nc",
+    "data/train/20cr-1ens.nc"
+]
4 changes: 4 additions & 0 deletions data/val/20cr.json
@@ -0,0 +1,4 @@
+[
+    "data/val/20cr-1ens.nc",
+    "data/val/20cr-1ens.nc"
+]
20 changes: 20 additions & 0 deletions tests/in/training/json-input.inp
@@ -0,0 +1,20 @@
+--device cpu --batch-size 2
+--n-threads 4
+--data-root-dir data/
+--mask-dir data/
+--log-dir tests/out/training/logs/
+--snapshot-dir tests/out/training/
+--data-names 20cr.json,20cr.json
+--data-types tas,tas
+--encoding-layers 2,2
+--pooling-layers 1,1
+--mask-names hadcrut_missmask_1.nc,hadcrut_missmask_1.nc
+--max-iter 10
+--loss-criterion 0
+--log-interval 1
+--weights kaiming
+--loop-random-seed 3
+--cuda-random-seed 3
+--deterministic
+--shuffle-masks
+--channel-steps 2,2
Binary file added tests/ref/json-input.inp.pth
2 changes: 1 addition & 1 deletion tests/test_training.py
@@ -6,7 +6,7 @@


 @pytest.mark.training
-@pytest.mark.parametrize("file", os.listdir(testdir + "in/training/"))
+@pytest.mark.parametrize("file", sorted(os.listdir(testdir + "in/training/")))
 def test_training_run(file):
     from climatereconstructionai import train
     train('{}in/training/{}'.format(testdir, file))
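Because the new json-input.inp lands in tests/in/training/, the parametrized test picks it up automatically, and sorting os.listdir keeps the parameter order stable across filesystems. One way to run just this group from Python (a sketch; equivalent to pytest -m training on the command line):

import pytest

# runs only the tests marked with @pytest.mark.training
pytest.main(["-m", "training", "tests/test_training.py"])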
