Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
b
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed May 11, 2023
1 parent 077c99b commit f272a9b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
13 changes: 13 additions & 0 deletions multiviewdata/test/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ def test_splitmnist():
from multiviewdata.torchdatasets import SplitMNIST

a = SplitMNIST(os.getcwd(), download=True)[0]
b=a[0]
assert "index" in a
assert "views" in a

def test_xrmb():
from multiviewdata.torchdatasets import XRMB

a = XRMB(os.getcwd(), download=True)
b=a[0]
assert "index" in a
assert "views" in a

Expand All @@ -31,3 +40,7 @@ def test_mfeat():
a = MFeat(os.getcwd(), download=True)[0]
assert "index" in a
assert "views" in a

if __name__ == '__main__':
test_splitmnist()
test_xrmb()
6 changes: 3 additions & 3 deletions multiviewdata/torchdatasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms

from PIL import Image

class SplitMNIST(Dataset):
"""
Expand Down Expand Up @@ -35,8 +35,8 @@ def __len__(self):

def __getitem__(self, idx):
x_a, label = self.dataset[idx]
x_b = x_a[:, :, 14:]
x_a = x_a[:, :, :14]
x_b = x_a[:, :, 14:]/255.
x_a = x_a[:, :, :14]/255.
if self.flatten:
x_a = torch.flatten(x_a)
x_b = torch.flatten(x_b)
Expand Down
15 changes: 10 additions & 5 deletions multiviewdata/torchdatasets/xrmb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
import numpy as np

from sklearn.preprocessing import StandardScaler

class XRMB(Dataset):
"""
Expand Down Expand Up @@ -63,12 +63,17 @@ def __init__(
loadmat(view_1_file),
loadmat(view_2_file),
)

scaler_1= StandardScaler().fit(view_1["X1"])
scaler_2= StandardScaler().fit(view_2["X2"])

if train:
view_1=view_1["X1"]
view_2=view_2["X2"]
view_1=scaler_1.transform(view_1["X1"])
view_2=scaler_2.transform(view_2["X2"])
else:
view_1 = view_1["XTe1"]
view_2 = view_2["XTe2"]
view_1 = scaler_1.transform(view_1["XTe1"])
view_2 = scaler_2.transform(view_2["XTe2"])

self.dataset = dict(view_1=view_1, view_2=view_2)

@property
Expand Down

0 comments on commit f272a9b

Please sign in to comment.