Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 17, 2024
1 parent 5b4e20f commit cd04f6c
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions examples/quickstart-monai/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

def _partition(files_list, labels_list, num_shards, index):
total_size = len(files_list)
assert total_size == len(labels_list), f"List of datapoints and labels must be of the same length"
assert total_size == len(
labels_list
), f"List of datapoints and labels must be of the same length"
shard_size = total_size // num_shards

# Calculate start and end indices for the shard
Expand All @@ -30,17 +32,18 @@ def _partition(files_list, labels_list, num_shards, index):
end_idx = start_idx + shard_size

# Create a subset for the shard
files = files_list[start_idx: end_idx]
labels = labels_list[start_idx: end_idx]
files = files_list[start_idx:end_idx]
labels = labels_list[start_idx:end_idx]
return files, labels


def load_data(num_shards, index):
image_file_list, image_label_list, num_total, num_class = _download_data()

# get partition given index
files_list, labels_list = _partition(image_file_list, image_label_list, num_shards, index)
image_file_list, image_label_list, _, num_class = _download_data()

# Get partition given index
files_list, labels_list = _partition(
image_file_list, image_label_list, num_shards, index
)

trainX, trainY, valX, valY, testX, testY = _split_data(
files_list, labels_list, len(files_list)
Expand All @@ -57,9 +60,6 @@ def load_data(num_shards, index):
test_loader = DataLoader(test_ds, batch_size=300)

return train_loader, val_loader, test_loader, num_class
test_loader = DataLoader(test_ds, batch_size=300)

return train_loader, val_loader, test_loader, num_class


class MedNISTDataset(Dataset):
Expand Down

0 comments on commit cd04f6c

Please sign in to comment.