diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index 4cf1dee6..0622b89a 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -190,7 +190,7 @@ def __init__( ) # Hyperparameter to easily access target attributes for attr in self.predict_losses: if attr in ClinicalAttribute.numerical_attrs(): - self.metrics[attr] = {"mae": functools.partial(mean_absolute_error)} + self.metrics[attr] = {"mae": mean_absolute_error} elif attr in ClinicalAttribute.binary_attrs(): self.metrics[attr] = {"acc": functools.partial(accuracy, task="binary")} else: # attr in ClinicalAttribute.categorical_attrs()