Skip to content

Commit

Permalink
collate
Browse files Browse the repository at this point in the history
  • Loading branch information
igunduz committed Jul 30, 2023
1 parent ae4c737 commit 919e34e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 812 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/

# C extensions
*.so
data/

# Distribution / packaging
irem.sbatch
Expand Down
25 changes: 25 additions & 0 deletions src/data_loaders/AffordanceDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,31 @@ def parse_object_labels(object_filename):
# Convert the data to a PyTorch tensor
return torch.tensor(object_labels)

def collate_fn(batch):
images = [item["image"] for item in batch]
affordances_labels = [item["affordances_labels"] for item in batch]
encoded_inputs = [item["encoded_input"] for item in batch]

# Pad images to the same size
max_width = max(image.shape[-1] for image in images)
max_height = max(image.shape[-2] for image in images)

padded_images = [
F.pad(image, (0, max_width - image.shape[-1], 0, max_height - image.shape[-2]))
for image in images
]
padded_images = torch.stack(padded_images)

# Stack the affordances labels
for k in encoded_inputs[0].keys():
encoded_inputs[k] = torch.stack([item[k] for item in encoded_inputs])

return {
"image": padded_images,
"affordances_labels": torch.stack(affordances_labels),
"encoded_input": encoded_inputs
}

class AffordanceDataset(Dataset):
def __init__(self, root_dir, split_file, feature_extractor=None, transform=None):
self.root_dir = root_dir
Expand Down
852 changes: 40 additions & 812 deletions src/notebooks/exercise.ipynb

Large diffs are not rendered by default.

0 comments on commit 919e34e

Please sign in to comment.