From 51771c459b5ea6c4a76473c7a8f45da133aed17f Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 26 Jun 2024 16:30:45 +0800 Subject: [PATCH] fix: fill missing values in empirical_mean with 0, make grud more robust; --- pypots/classification/grud/data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pypots/classification/grud/data.py b/pypots/classification/grud/data.py index fc23132e..3287a6f6 100644 --- a/pypots/classification/grud/data.py +++ b/pypots/classification/grud/data.py @@ -63,6 +63,8 @@ def __init__( self.empirical_mean = torch.sum( self.missing_mask * self.X, dim=[0, 1] ) / torch.sum(self.missing_mask, dim=[0, 1]) + # fill nan with 0, in case some features have no observations + self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0) def _fetch_data_from_array(self, idx: int) -> Iterable: """Fetch data according to index.