Skip to content

Commit

Permalink
Merge pull request #26 from FREVA-CLINT/fixes
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
faxmitte authored Dec 21, 2023
2 parents ef9e9dc + 109af91 commit f37081c
Show file tree
Hide file tree
Showing 22 changed files with 147 additions and 134 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
fail-fast: false
steps:
- uses: actions/checkout@v2
- name: Setup conda with Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion climatereconstructionai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def set_common_args():
arg_parser.add_argument('--n-filters', type=int, default=None, help="Number of filters for the first/last layer")
arg_parser.add_argument('--out-channels', type=int, default=1, help="Number of channels for the output data")
arg_parser.add_argument('--dataset-name', type=str, default=None, help="Name of the dataset for format checking")
arg_parser.add_argument('--min-bounds', type=float_list, default="inf",
arg_parser.add_argument('--min-bounds', type=float_list, default="-inf",
help="Comma separated list of values defining the permitted lower-bound of output values")
arg_parser.add_argument('--max-bounds', type=float_list, default="inf",
help="Comma separated list of values defining the permitted upper-bound of output values")
Expand Down
27 changes: 19 additions & 8 deletions climatereconstructionai/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

def store_encoding(ds):
global encoding
ds = ds.assign_coords({"member": 0})
encoding = ds['time'].encoding
return ds

Expand Down Expand Up @@ -44,6 +43,9 @@ def evaluate(arg_file=None, prog_func=None):
if data_stats is None:
if cfg.normalize_data:
print("* Warning! Using mean and std from current data.")
if cfg.n_target_data != 0:
print("* Warning! Mean and std from target data will be used to renormalize output."
" Mean and std from training data can be used with use_train_stats option.")
data_stats = {"mean": dataset_val.img_mean, "std": dataset_val.img_std}

image_sizes = dataset_val.img_sizes
Expand Down Expand Up @@ -80,20 +82,29 @@ def evaluate(arg_file=None, prog_func=None):
infill(model, iterator_val, eval_path, output_names, data_stats, dataset_val.xr_dss, count)

for name in output_names:
if len(output_names[name]) == 1:
os.rename(output_names[name][0], name + ".nc")
if len(output_names[name]) == 1 and len(output_names[name][1]) == 1:
os.rename(output_names[name][1][0], name + ".nc")
else:
if not cfg.split_outputs:
ds = xr.open_mfdataset(output_names[name], preprocess=store_encoding, autoclose=True, combine='nested',
data_vars='minimal', concat_dim="member", chunks={})
dss = []
for i_model in output_names[name]:
dss.append(xr.open_mfdataset(output_names[name][i_model], preprocess=store_encoding, autoclose=True,
combine='nested', data_vars='minimal', concat_dim="time", chunks={}))
dss[-1] = dss[-1].assign_coords({"member": i_model})

if len(dss) == 1:
ds = dss[-1].drop("member")
else:
ds = xr.concat(dss, dim="member")

ds["member"] = range(1, len(output_names[name]) + 1)
ds['time'].encoding = encoding
ds['time'].encoding['original_shape'] = len(ds["time"])
ds = ds.transpose("time", ...).reset_coords(drop=True)
ds.to_netcdf(name + ".nc")
for output_name in output_names[name]:
os.remove(output_name)

for i_model in output_names[name]:
for output_name in output_names[name][i_model]:
os.remove(output_name)


if __name__ == "__main__":
Expand Down
9 changes: 5 additions & 4 deletions climatereconstructionai/loss/feature_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ def __init__(self, extractor):
self.l1 = nn.L1Loss()
self.extractor = extractor

def forward(self, mask, output, gt):
def forward(self, data_dict):
loss_dict = {
'prc': 0.0,
'style': 0.0,
}

# create output_comp
output_comp = mask * gt + (1 - mask) * output
output_comp = data_dict['comp']
output = data_dict['output']
gt = data_dict['gt']

# calculate loss for all channels
for channel in range(output.shape[1]):
Expand All @@ -38,4 +39,4 @@ def forward(self, mask, output, gt):
loss_dict['style'] += self.l1(gram_matrix(feat_output_comp[i]),
gram_matrix(feat_gt[i]))

return loss_dict
return loss_dict
103 changes: 51 additions & 52 deletions climatereconstructionai/loss/get_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,87 +8,86 @@
from ..utils.featurizer import VGG16FeatureExtractor


class ModularizedFunction(torch.nn.Module):
def __init__(self, forward_op):
super().__init__()
self.forward_op = forward_op
def prepare_data_dict(img_mask, loss_mask, output, gt, tensor_keys):
data_dict = dict(zip(list(tensor_keys),[None]*len(tensor_keys)))

def forward(self, *args, **kwargs):
return self.forward_op(*args, **kwargs)
mask = img_mask
loss_mask = img_mask
if loss_mask is not None:
mask += loss_mask
mask[mask < 0] = 0
mask[mask > 1] = 1
assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!"

output = output[:, cfg.recurrent_steps, :, :, :]
mask = mask[:, cfg.recurrent_steps, :, :, :]
gt = gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :]

class CriterionParallel(torch.nn.Module):
def __init__(self, criterion):
super().__init__()
if not isinstance(criterion, torch.nn.Module):
criterion = ModularizedFunction(criterion)
self.criterion = torch.nn.DataParallel(criterion)
data_dict['mask'] = mask
data_dict['output'] = output
data_dict['gt'] = gt

def forward(self, *args, **kwargs):
multi_dict = self.criterion(*args, **kwargs)
for key in multi_dict.keys():
multi_dict[key] = multi_dict[key].mean()
return multi_dict
if 'comp' in tensor_keys:
data_dict['comp'] = mask * gt + (1 - mask) * output

return data_dict


class loss_criterion(torch.nn.Module):
def __init__(self):
def __init__(self, lambda_dict):
super().__init__()

self.criterions = torch.nn.ModuleDict()
self.criterions = []
self.tensors = ['output', 'gt', 'mask']
style_added = False

for loss, lambda_ in cfg.lambda_dict.items():
for loss, lambda_ in lambda_dict.items():
if lambda_ > 0:
if loss == 'style' or loss == 'prc':
criterion = FeatureLoss(VGG16FeatureExtractor()).to(cfg.device)
if (loss == 'style' or loss == 'prc') and not style_added:
self.criterions.append(FeatureLoss(VGG16FeatureExtractor()).to(cfg.device))
self.tensors.append('comp')
style_added = True

elif loss == 'valid':
criterion = ValidLoss().to(cfg.device)
self.criterions.append(ValidLoss().to(cfg.device))
self.tensors.append('valid')

elif loss == 'hole':
criterion = HoleLoss().to(cfg.device)
self.criterions.append(HoleLoss().to(cfg.device))
self.tensors.append('hole')

elif loss == 'tv':
criterion = TotalVariationLoss().to(cfg.device)
self.criterions.append(TotalVariationLoss().to(cfg.device))
if 'comp' not in self.tensors:
self.tensors.append('comp')


if criterion not in self.criterions.values():
self.criterions[loss] = criterion
def forward(self, img_mask, loss_mask, output, gt):

def forward(self, mask, output, gt):
data_dict = prepare_data_dict(img_mask, loss_mask, output, gt, self.tensors)

loss_dict = {}
for _, criterion in self.criterions.items():
loss_dict.update(criterion(mask, output, gt))
for criterion in self.criterions:
loss_dict.update(criterion(data_dict))

loss_dict["total"] = 0
for loss, lambda_value in cfg.lambda_dict.items():
if lambda_value > 0:
if lambda_value > 0 and loss in loss_dict.keys():
loss_w_lambda = loss_dict[loss] * lambda_value
loss_dict["total"] += loss_w_lambda
loss_dict[loss] = loss_w_lambda.item()

return loss_dict


class LossComputation():
def __init__(self):
class LossComputation(torch.nn.Module):
def __init__(self, lambda_dict):
super().__init__()
if cfg.multi_gpus:
self.criterion = CriterionParallel(loss_criterion())
self.criterion = torch.nn.DataParallel(loss_criterion(lambda_dict))
else:
self.criterion = loss_criterion()
self.criterion = loss_criterion(lambda_dict)

def get_loss(self, img_mask, loss_mask, output, gt):

mask = img_mask[:, cfg.recurrent_steps, cfg.gt_channels, :, :]

if cfg.n_target_data != 0:
mask = torch.ones_like(mask)

if loss_mask is not None:
mask += loss_mask
mask[mask < 0] = 0
mask[mask > 1] = 1
assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!"

loss_dict = self.criterion(mask, output[:, cfg.recurrent_steps, :, :, :],
gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :])

return loss_dict
def forward(self, img_mask, loss_mask, output, gt):
loss_dict = self.criterion(img_mask, loss_mask, output ,gt)
return loss_dict
8 changes: 6 additions & 2 deletions climatereconstructionai/loss/hole_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ def __init__(self):
super().__init__()
self.l1 = nn.L1Loss()

def forward(self, mask, output, gt):
def forward(self, data_dict):
loss_dict = {
'hole': 0.0
}

output = data_dict['output']
gt = data_dict['gt']
mask = data_dict['mask']

# calculate loss for all channels
for channel in range(output.shape[1]):
# only select first channel
Expand All @@ -21,4 +25,4 @@ def forward(self, mask, output, gt):

# define different loss functions from output and output_comp
loss_dict['hole'] += self.l1((1 - mask_ch) * output_ch, (1 - mask_ch) * gt_ch)
return loss_dict
return loss_dict
9 changes: 5 additions & 4 deletions climatereconstructionai/loss/total_variation_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ def __init__(self):
super().__init__()
self.l1 = nn.L1Loss()

def forward(self, mask, output, gt):
def forward(self, data_dict):
loss_dict = {
'tv': 0.0
}
output_comp = mask * gt + (1 - mask) * output

output_comp = data_dict['comp']

# calculate loss for all channels
for channel in range(output.shape[1]):
for channel in range(output_comp.shape[1]):
output_comp_ch = torch.unsqueeze(output_comp[:, channel, :, :], dim=1)
loss_dict['tv'] += total_variation_loss(output_comp_ch)
return loss_dict
return loss_dict
7 changes: 5 additions & 2 deletions climatereconstructionai/loss/valid_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ def __init__(self):
super().__init__()
self.l1 = nn.L1Loss()

def forward(self, mask, output, gt):
def forward(self, data_dict):
loss_dict = {
'valid': 0.0
}
output = data_dict['output']
gt = data_dict['gt']
mask = data_dict['mask']

# calculate loss for all channels
for channel in range(output.shape[1]):
Expand All @@ -21,4 +24,4 @@ def forward(self, mask, output, gt):

# define different loss functions from output and output_comp
loss_dict['valid'] += self.l1(mask_ch * output_ch, mask_ch * gt_ch)
return loss_dict
return loss_dict
58 changes: 19 additions & 39 deletions climatereconstructionai/metrics/get_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
import torchmetrics.image as t_metrics

from .. import config as cfg
from ..loss.feature_loss import FeatureLoss
from ..loss.hole_loss import HoleLoss
from ..loss.total_variation_loss import TotalVariationLoss
from ..loss.valid_loss import ValidLoss
from ..utils.featurizer import VGG16FeatureExtractor

from ..loss import get_loss

@torch.no_grad()
def get_metrics(img_mask, loss_mask, output, gt, setname):
Expand Down Expand Up @@ -38,53 +33,38 @@ def get_metrics(img_mask, loss_mask, output, gt, setname):
}
}

mask = img_mask[:, cfg.recurrent_steps, cfg.gt_channels, :, :]

if loss_mask is not None:
mask += loss_mask
mask[mask < 0] = 0
mask[mask > 1] = 1
assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!"

metric_dict = {}
if setname == 'train':
metrics = cfg.train_metrics
elif setname == 'val':
metrics = cfg.val_metrics
elif setname == 'test':
metrics = cfg.test_metrics
metrics = cfg.val_metrics

loss_metric_dict = dict(zip(metrics,[1]*len(metrics)))
if 'feature' in metrics:
loss_metric_dict.update(dict(zip(['style', 'prc'],[1,1])))

loss_comp = get_loss.LossComputation(loss_metric_dict)

loss_metrics = loss_comp(img_mask, loss_mask, output, gt)
loss_metrics['total'] = loss_metrics['total'].item()

for metric in metrics:
settings = metric_settings[metric]

if 'valid' in metric:
val_loss = ValidLoss().to(cfg.device)
metric_output = val_loss(mask, output[:, cfg.recurrent_steps, :, :, :],
gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :])
metric_dict[f'metric/{setname}/valid'] = metric_output['valid']
metric_dict[f'metric/{setname}/valid'] = loss_metrics['valid']

elif 'hole' in metric:
val_loss = HoleLoss().to(cfg.device)
metric_output = val_loss(mask, output[:, cfg.recurrent_steps, :, :, :],
gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :])
metric_dict[f'metric/{setname}/hole'] = metric_output['hole']
metric_dict[f'metric/{setname}/hole'] = loss_metrics['hole']

elif 'tv' in metric:
val_loss = TotalVariationLoss().to(cfg.device)
metric_output = val_loss(mask, output[:, cfg.recurrent_steps, :, :, :],
gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :])
metric_dict[f'metric/{setname}/tv'] = metric_output['tv']
metric_dict[f'metric/{setname}/tv'] = loss_metrics['tv']

elif 'feature' in metric:
feat_loss = FeatureLoss(VGG16FeatureExtractor()).to(cfg.device)
metric_output = feat_loss(mask, output[:, cfg.recurrent_steps, :, :, :],
gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :])
metric_dict[f'metric/{setname}/style'] = metric_output['style']
metric_dict[f'metric/{setname}/prc'] = metric_output['prc']
metric_dict[f'metric/{setname}/style'] = loss_metrics['style']
metric_dict[f'metric/{setname}/prc'] = loss_metrics['prc']

else:
metric_outputs = calculate_metric(metric, mask, output[:, cfg.recurrent_steps, :, :, :],
gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :],
data = get_loss.prepare_data_dict(img_mask, loss_mask, output, gt, ['mask','output','gt'])
metric_outputs = calculate_metric(metric, data['mask'], data['output'], data['gt'],
torchmetrics_settings=settings['torchmetric_settings'])

if len(metric_outputs) > 1:
Expand Down Expand Up @@ -150,4 +130,4 @@ def calculate_metric(name_expr, mask, output, gt, domain='valid', torchmetrics_s
result_out += result_ch
result_out = [result_out]

return result_out
return result_out
Loading

0 comments on commit f37081c

Please sign in to comment.