From f76a1065e0dd8da154af13590b2f70deeaff6bdb Mon Sep 17 00:00:00 2001 From: savill <65285466+Cryaaa@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:01:41 +0200 Subject: [PATCH] move function up since I need it multiple times --- CPC/training_rs_masked_only_BF.py | 10 +--------- CPC/training_rs_masked_only_BRA.py | 11 +---------- src/embed_time/transforms.py | 9 ++++++++- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/CPC/training_rs_masked_only_BF.py b/CPC/training_rs_masked_only_BF.py index 9f9ca49..975785a 100644 --- a/CPC/training_rs_masked_only_BF.py +++ b/CPC/training_rs_masked_only_BF.py @@ -5,7 +5,7 @@ import train import torchvision.transforms as trans from torchvision.transforms import v2 -from embed_time.transforms import CustomToTensor, CropAndReshapeTL +from embed_time.transforms import CustomToTensor, CropAndReshapeTL, SelectChannel from embed_time.dataloader_rs import LiveTLSDatasetPairedOutput from datetime import datetime from pathlib import Path @@ -24,14 +24,6 @@ lr = 4e-4 -class SelectChannel(object): - def __init__(self,channel,channel_dim = 0): - self.channel = channel - self.channel_dim = channel_dim - - def __call__(self, sample): - return np.take(sample,[self.channel],axis=self.channel_dim) - model_training_settings = dict( GPU = GPU, patience = patience, diff --git a/CPC/training_rs_masked_only_BRA.py b/CPC/training_rs_masked_only_BRA.py index 48a0650..5137284 100644 --- a/CPC/training_rs_masked_only_BRA.py +++ b/CPC/training_rs_masked_only_BRA.py @@ -5,7 +5,7 @@ import train import torchvision.transforms as trans from torchvision.transforms import v2 -from embed_time.transforms import CustomToTensor, CropAndReshapeTL +from embed_time.transforms import CustomToTensor, CropAndReshapeTL, SelectChannel from embed_time.dataloader_rs import LiveTLSDatasetPairedOutput from datetime import datetime from pathlib import Path @@ -23,15 +23,6 @@ in_channels =1 lr = 4e-4 - -class SelectChannel(object): - def __init__(self,channel,channel_dim = 0): - self.channel = channel - self.channel_dim = channel_dim - - def __call__(self, sample): - return np.take(sample,[self.channel],axis=self.channel_dim) - model_training_settings = dict( GPU = GPU, patience = patience, diff --git a/src/embed_time/transforms.py b/src/embed_time/transforms.py index 2a31bf9..ef528cf 100644 --- a/src/embed_time/transforms.py +++ b/src/embed_time/transforms.py @@ -109,13 +109,20 @@ def __call__(self, sample): ] return sample[slice_objects] +class SelectChannel(object): + def __init__(self,channel,channel_dim = 0): + self.channel = channel + self.channel_dim = channel_dim + + def __call__(self, sample): + return np.take(sample,[self.channel],axis=self.channel_dim) + class SelectRandomTPNumpy(object): """select a random timepoint form the time series time_dimension: int dimension index of time """ - def __init__(self, time_dimension): self.td = time_dimension