diff --git a/mofax/core.py b/mofax/core.py index d33452d..787f067 100644 --- a/mofax/core.py +++ b/mofax/core.py @@ -1330,9 +1330,9 @@ def get_sample_r2( views: Optional[Union[str, int, List[str], List[int]]] = None, df: bool = True, ) -> pd.DataFrame: - findices, factors = self.__check_factors(factors, unique=True) - groups = self.__check_groups(groups) - views = self.__check_views(views) + findices, factors = self._check_factors(factors, unique=True) + groups = self._check_groups(groups) + views = self._check_views(views) r2s = [] for view in views: diff --git a/mofax/utils.py b/mofax/utils.py index 9587335..f38c619 100644 --- a/mofax/utils.py +++ b/mofax/utils.py @@ -17,6 +17,7 @@ def _load_samples_metadata(model): ], columns=["sample", "group"], ) + if "samples_metadata" in model.model: if len(list(model.model["samples_metadata"][model.groups[0]].keys())) > 0: _samples_metadata = pd.concat( @@ -41,23 +42,36 @@ def _load_samples_metadata(model): if "sample" in _samples_metadata.columns: del _samples_metadata["sample"] - samples_metadata = pd.concat( - [ - samples_metadata.reset_index(drop=True), - _samples_metadata.reset_index(drop=True), - ], - axis=1, - ) - # Decode objects as UTF-8 strings - for column in samples_metadata.columns: - if samples_metadata[column].dtype == "object": - try: - samples_metadata[column] = [ - i.decode() for i in samples_metadata[column].values - ] - except (UnicodeDecodeError, AttributeError): - pass + for df in [samples_metadata, _samples_metadata]: + for column in df.columns: + if df[column].dtype == "object": + try: + df[column] = [i.decode() for i in df[column].values] + except (UnicodeDecodeError, AttributeError): + pass + + if "sample" in _samples_metadata.columns: + + samples_metadata = pd.merge( + left=samples_metadata, + left_on="sample", + right=_samples_metadata, + right_on="index", + ) + + if "index" in samples_metadata.columns: + del samples_metadata["index"] + + else: + + samples_metadata = pd.concat( + [ + samples_metadata.reset_index(drop=True), + _samples_metadata.reset_index(drop=True), + ], + axis=1, + ) samples_metadata = samples_metadata.set_index("sample") return samples_metadata @@ -230,7 +244,7 @@ def _make_iterable(x): def calculate_r2(Z, W, Y): a = np.nansum((Y - Z.T.dot(W)) ** 2.0) - b = np.nansum(Y ** 2) + b = np.nansum(Y**2) r2 = (1.0 - a / b) * 100 return r2