From f272a9b3e15067951886dbffb903057ceafc8d54 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Thu, 11 May 2023 23:50:14 +0100 Subject: [PATCH] b --- multiviewdata/test/test_outputs.py | 13 +++++++++++++ multiviewdata/torchdatasets/mnist.py | 6 +++--- multiviewdata/torchdatasets/xrmb.py | 15 ++++++++++----- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/multiviewdata/test/test_outputs.py b/multiviewdata/test/test_outputs.py index 0e7f425..c321440 100644 --- a/multiviewdata/test/test_outputs.py +++ b/multiviewdata/test/test_outputs.py @@ -21,6 +21,15 @@ def test_splitmnist(): from multiviewdata.torchdatasets import SplitMNIST a = SplitMNIST(os.getcwd(), download=True)[0] + b=a[0] + assert "index" in a + assert "views" in a + +def test_xrmb(): + from multiviewdata.torchdatasets import XRMB + + a = XRMB(os.getcwd(), download=True) + b=a[0] assert "index" in a assert "views" in a @@ -31,3 +40,7 @@ def test_mfeat(): a = MFeat(os.getcwd(), download=True)[0] assert "index" in a assert "views" in a + +if __name__ == '__main__': + test_splitmnist() + test_xrmb() \ No newline at end of file diff --git a/multiviewdata/torchdatasets/mnist.py b/multiviewdata/torchdatasets/mnist.py index 9e7513b..47642ca 100644 --- a/multiviewdata/torchdatasets/mnist.py +++ b/multiviewdata/torchdatasets/mnist.py @@ -5,7 +5,7 @@ from torch.utils.data import Dataset from torchvision import datasets from torchvision import transforms - +from PIL import Image class SplitMNIST(Dataset): """ @@ -35,8 +35,8 @@ def __len__(self): def __getitem__(self, idx): x_a, label = self.dataset[idx] - x_b = x_a[:, :, 14:] - x_a = x_a[:, :, :14] + x_b = x_a[:, :, 14:]/255. + x_a = x_a[:, :, :14]/255. if self.flatten: x_a = torch.flatten(x_a) x_b = torch.flatten(x_b) diff --git a/multiviewdata/torchdatasets/xrmb.py b/multiviewdata/torchdatasets/xrmb.py index 4dd3e7c..70ed09d 100644 --- a/multiviewdata/torchdatasets/xrmb.py +++ b/multiviewdata/torchdatasets/xrmb.py @@ -4,7 +4,7 @@ from torch.utils.data import Dataset from torchvision.datasets.utils import download_url import numpy as np - +from sklearn.preprocessing import StandardScaler class XRMB(Dataset): """ @@ -63,12 +63,17 @@ def __init__( loadmat(view_1_file), loadmat(view_2_file), ) + + scaler_1= StandardScaler().fit(view_1["X1"]) + scaler_2= StandardScaler().fit(view_2["X2"]) + if train: - view_1=view_1["X1"] - view_2=view_2["X2"] + view_1=scaler_1.transform(view_1["X1"]) + view_2=scaler_2.transform(view_2["X2"]) else: - view_1 = view_1["XTe1"] - view_2 = view_2["XTe2"] + view_1 = scaler_1.transform(view_1["XTe1"]) + view_2 = scaler_2.transform(view_2["XTe2"]) + self.dataset = dict(view_1=view_1, view_2=view_2) @property