Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
1Pravi authored Aug 11, 2023
2 parents 13b0cc0 + 616b37b commit 4426030
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,14 @@ def __percentile_clip(input_tensor, reference_tensor=None, p_min=0.5, p_max=99.5
overall_stats_dict[current_subject_id][
"psnr_01"
] = peak_signal_noise_ratio(
gt_image_infill, output_infill, data_range=1.0
gt_image_infill, output_infill, data_range=(0,1)
).item()

# same as above but with epsilon for robustness
overall_stats_dict[current_subject_id][
"psnr_01_eps"
] = peak_signal_noise_ratio(
gt_image_infill, output_infill, epsilon=sys.float_info.epsilon
gt_image_infill, output_infill, data_range=(0,1), epsilon=sys.float_info.epsilon
).item()

pprint(overall_stats_dict)
Expand Down
9 changes: 5 additions & 4 deletions GANDLF/metrics/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,21 @@ def peak_signal_noise_ratio(target, prediction, data_range=None, epsilon=None) -
Args:
target (torch.Tensor): The target tensor.
prediction (torch.Tensor): The prediction tensor.
data_range (float, optional): If not None, this data range is used as enumerator instead of computing it from the given data. Defaults to None.
data_range (tuple, optional): If not None, this data range (min, max) is used as enumerator instead of computing it from the given data. Defaults to None.
epsilon (float, optional): If not None, this epsilon is added to the denominator of the fraction to avoid infinity as output. Defaults to None.
"""

if epsilon == None:
psnr = PeakSignalNoiseRatio(data_range=data_range)
psnr = PeakSignalNoiseRatio() if data_range == None else PeakSignalNoiseRatio(data_range=data_range[1]-data_range[0])
return psnr(preds=prediction, target=target)
else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0'
mse = mean_squared_error(target, prediction)
if data_range == None: #compute data_range like torchmetrics if not given
min_v = 0 if torch.min(target) > 0 else torch.min(target) #look at this line
max_v = torch.max(target)
data_range = max_v - min_v
return 10.0 * torch.log10((data_range ** 2) / (mse + epsilon))
else:
min_v, max_v = data_range
return 10.0 * torch.log10(((max_v-min_v) ** 2) / (mse + epsilon))


def mean_squared_log_error(target, prediction) -> torch.Tensor:
Expand Down

0 comments on commit 4426030

Please sign in to comment.