Skip to content

Commit

Permalink
move function up since I need it multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryaaa committed Sep 24, 2024
1 parent 60dbaab commit f76a106
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 20 deletions.
10 changes: 1 addition & 9 deletions CPC/training_rs_masked_only_BF.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import train
import torchvision.transforms as trans
from torchvision.transforms import v2
from embed_time.transforms import CustomToTensor, CropAndReshapeTL
from embed_time.transforms import CustomToTensor, CropAndReshapeTL, SelectChannel
from embed_time.dataloader_rs import LiveTLSDatasetPairedOutput
from datetime import datetime
from pathlib import Path
Expand All @@ -24,14 +24,6 @@
lr = 4e-4


class SelectChannel(object):
def __init__(self,channel,channel_dim = 0):
self.channel = channel
self.channel_dim = channel_dim

def __call__(self, sample):
return np.take(sample,[self.channel],axis=self.channel_dim)

model_training_settings = dict(
GPU = GPU,
patience = patience,
Expand Down
11 changes: 1 addition & 10 deletions CPC/training_rs_masked_only_BRA.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import train
import torchvision.transforms as trans
from torchvision.transforms import v2
from embed_time.transforms import CustomToTensor, CropAndReshapeTL
from embed_time.transforms import CustomToTensor, CropAndReshapeTL, SelectChannel
from embed_time.dataloader_rs import LiveTLSDatasetPairedOutput
from datetime import datetime
from pathlib import Path
Expand All @@ -23,15 +23,6 @@
in_channels =1
lr = 4e-4


class SelectChannel(object):
def __init__(self,channel,channel_dim = 0):
self.channel = channel
self.channel_dim = channel_dim

def __call__(self, sample):
return np.take(sample,[self.channel],axis=self.channel_dim)

model_training_settings = dict(
GPU = GPU,
patience = patience,
Expand Down
9 changes: 8 additions & 1 deletion src/embed_time/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,20 @@ def __call__(self, sample):
]
return sample[slice_objects]

class SelectChannel(object):
def __init__(self,channel,channel_dim = 0):
self.channel = channel
self.channel_dim = channel_dim

def __call__(self, sample):
return np.take(sample,[self.channel],axis=self.channel_dim)

class SelectRandomTPNumpy(object):
"""select a random timepoint form the time series
time_dimension: int
dimension index of time
"""

def __init__(self, time_dimension):
self.td = time_dimension

Expand Down

0 comments on commit f76a106

Please sign in to comment.