add dataset class with subset functionality
namsaraeva committed Jun 5, 2024
1 parent 00c5590 commit 4680b57
Showing 1 changed file with 123 additions and 0 deletions.
src/sparcscore/ml/datasets.py (123 additions, 0 deletions)
@@ -290,6 +290,129 @@ def stats(self):
    def __len__(self):
        return len(self.data_locator)  # return length of data locator

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()  # convert tensor to list

        data_item = self.data_locator[idx]  # data info for the current index: target, handle id, and row

        if self.select_channel is not None:  # select a specific channel
            cell_tensor = self.handle_list[data_item[1]][data_item[2], self.select_channel]
            t = torch.from_numpy(cell_tensor).float()  # convert to float tensor
            t = torch.unsqueeze(t, 0)  # add channel dimension to tensor
        else:
            cell_tensor = self.handle_list[data_item[1]][data_item[2]]
            t = torch.from_numpy(cell_tensor).float()  # convert to float tensor

        if self.transform:
            t = self.transform(t)  # apply transformation to the data

        target = torch.tensor(data_item[0], dtype=torch.float)  # get target value

        if self.return_id:
            ids = int(data_item[3])
            sample = (t, target, torch.tensor(ids))  # return data, target, and id
        elif self.return_fake_id:
            sample = (t, target, torch.tensor(0))  # return data, target, and fake id
        else:
            sample = (t, target)  # return data and target

        return sample


class HDF5SingleCellDatasetRegression_Subset(Dataset):
    """
    Class for handling SPARCSpy single cell datasets stored in HDF5 files for regression tasks.
    Supports selecting a subset of the data based on the given indices.

    Args:
        dir_list: list of paths to HDF5 files or directories, relative to root_dir.
        target_col: per-entry column index in the labelled single cell index that holds the regression target.
        index_list: indices of the cells to include in the subset.
        root_dir: root directory against which the entries of dir_list are resolved.
        hours: if True, convert targets from seconds to hours.
        max_level: maximum directory depth to scan recursively.
        transform: optional transform applied to each cell tensor.
        return_id: if True, also return the cell id with each sample.
        return_fake_id: if True, also return a constant fake id of 0.
        select_channel: optional channel index; if set, only this channel is returned.
    """
    HDF_FILETYPES = ["hdf", "hf", "h5", "hdf5"]  # supported hdf5 filetypes

    def __init__(self,
                 dir_list: list[str],
                 target_col: list[int],
                 index_list: list[int],
                 root_dir: str,
                 hours: bool = False,
                 max_level: int = 5,
                 transform=None,
                 return_id: bool = False,
                 return_fake_id: bool = False,
                 select_channel=None):

        self.dir_list = dir_list
        self.target_col = target_col
        self.index_list = index_list  # list of indices to select
        self.hours = hours
        self.root_dir = root_dir
        self.transform = transform
        self.select_channel = select_channel

        self.handle_list = []
        self.data_locator = []

        # scan all directories in dir_list
        for i, directory in enumerate(dir_list):
            path = os.path.join(self.root_dir, directory)  # get full path

            current_target_col = self.target_col[i]  # get the target column for the current directory

            filetype = directory.split(".")[-1]  # get filetype

            if filetype in self.HDF_FILETYPES:  # check if filetype is supported
                self.add_hdf_to_index(path, current_target_col)  # add hdf5 file to index
            else:
                self.scan_directory(path, current_target_col, max_level)  # recursively scan for files

        self.return_id = return_id  # return id
        self.return_fake_id = return_fake_id  # return fake id
        self.stats()  # print dataset stats at the end

    def add_hdf_to_index(self, path, target_col):
        try:
            input_hdf = h5py.File(path, 'r')  # open hdf5 file for reading

            index_handle = input_hdf.get('single_cell_index')[self.index_list]  # get the selected rows of the single cell index

            # decode the labelled index as strings before selecting the target column
            current_target_col = input_hdf.get('single_cell_index_labelled').asstr()[self.index_list][:, target_col]
            current_target_col[current_target_col == ''] = np.nan  # replace empty values with nan
            current_target_col = current_target_col.astype(float)  # convert to float for regression

            handle_id = len(self.handle_list)  # get handle id
            # keep the full data handle: the rows in index_handle address cells by their original indices
            self.handle_list.append(input_hdf.get('single_cell_data'))  # append data handle (i.e. extracted images)

            for current_target, row in zip(current_target_col, index_handle):  # iterate over all selected cells
                if self.hours:
                    current_target = current_target / 3600  # convert seconds to hours
                self.data_locator.append([current_target, handle_id] + list(row))  # append target, handle id, and row to data locator
        except Exception:  # skip files that cannot be read or lack the expected datasets
            return

    def scan_directory(self, path, target_col, levels_left):
        if levels_left > 0:  # only recurse while levels are left
            current_level_directories = [os.path.join(path, name) for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]  # get directories
            current_level_files = [name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))]  # get files

            for file in current_level_files:  # iterate over files from the current level
                filetype = file.split(".")[-1]  # get filetype

                if filetype in self.HDF_FILETYPES:
                    self.add_hdf_to_index(os.path.join(path, file), target_col)  # add supported hdf5 files to index

            for subdirectory in current_level_directories:  # recursively scan subdirectories
                self.scan_directory(subdirectory, target_col, levels_left - 1)

    def stats(self):
        targets = [info[0] for info in self.data_locator]  # get all targets from the data locator
        targets = np.array(targets, dtype=float)  # convert to numpy array

        print(f"Total samples: {len(targets)}")

    def __len__(self):
        return len(self.data_locator)  # return length of data locator

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()  # convert tensor to list

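A minimal usage sketch of the new class; the file name, target column, subset indices, and data directory below are hypothetical:

from torch.utils.data import DataLoader
from sparcscore.ml.datasets import HDF5SingleCellDatasetRegression_Subset

dataset = HDF5SingleCellDatasetRegression_Subset(
    dir_list=["cells.h5"],        # hypothetical HDF5 file under root_dir
    target_col=[2],               # target column per entry in dir_list
    index_list=[0, 1, 5, 42],     # subset of cell indices (increasing order for h5py)
    root_dir="/path/to/data",     # hypothetical data directory
    hours=True,                   # targets stored in seconds, returned in hours
    select_channel=0,             # load a single channel per cell
)

loader = DataLoader(dataset, batch_size=16, shuffle=True)
for images, targets in loader:    # images: (B, 1, H, W); targets: (B,)
    pass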