Skip to content

Commit

Permalink
allow a geometry argument to match dataloader schema for DeepForest
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 2, 2024
1 parent 478157d commit 07d1d78
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 55 deletions.
83 changes: 42 additions & 41 deletions docs/examples/Datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"text": [
"/blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/albumentations/__init__.py:13: UserWarning: A new version of Albumentations is available: 1.4.20 (you have 1.4.15). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
"/blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/albumentations/__init__.py:13: UserWarning: A new version of Albumentations is available: 1.4.21 (you have 1.4.15). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
" check_for_updates()\n"
]
}
Expand All @@ -85,16 +85,16 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metadata length: 2\n",
"Image shape: (81, 2), Image type: <class 'numpy.ndarray'>\n",
"Label shape: torch.Size([3, 448, 448]), Label type: <class 'torch.Tensor'>\n"
"Image shape: torch.Size([3, 448, 448]), Image type: <class 'torch.Tensor'>\n",
"Targets keys: dict_keys(['y', 'labels']), Label type: <class 'dict'>\n"
]
}
],
Expand All @@ -104,10 +104,10 @@
"train_dataset = dataset.get_subset(\"train\")\n",
"\n",
"# View the first image in the dataset\n",
"image, label, metadata = train_dataset[0]\n",
"metadata, image, targets = train_dataset[0]\n",
"print(f\"Metadata length: {len(metadata)}\")\n",
"print(f\"Image shape: {image.shape}, Image type: {type(image)}\")\n",
"print(f\"Label shape: {label.shape}, Label type: {type(label)}\")"
"print(f\"Targets keys: {targets.keys()}, Label type: {type(targets)}\")"
]
},
{
Expand All @@ -119,17 +119,17 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: You are loading the entire dataset. Consider using dataset.get_subset('train') for a portion of the dataset if intended.\n",
"Targets is a list of dictionaries with the following keys: dict_keys(['boxes', 'labels'])\n",
"Targets is a list of dictionaries with the following keys: dict_keys(['y', 'labels'])\n",
"Image shape: torch.Size([2, 3, 448, 448]), Image type: <class 'torch.Tensor'>\n",
"Annotation shape of the first image: torch.Size([4, 4])\n"
"Annotation shape of the first image: torch.Size([17, 4])\n"
]
}
],
Expand All @@ -140,20 +140,49 @@
"for metadata, image, targets in train_loader:\n",
" print(\"Targets is a list of dictionaries with the following keys: \", targets[0].keys())\n",
" print(f\"Image shape: {image.shape}, Image type: {type(image)}\")\n",
" print(f\"Annotation shape of the first image: {targets[0]['boxes'].shape}\")\n",
" print(f\"Annotation shape of the first image: {targets[0]['y'].shape}\")\n",
" break # Just show the first batch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reformat\n",
"\n",
"DeepForest expects the target tensor have the key 'boxes', not 'y', we can ask MillionTrees for help by passing the geometry name to the dataset class. We do not need to redownload the dataset, we can all TreeBoxesDataset directly."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: You are loading the entire dataset. Consider using dataset.get_subset('train') for a portion of the dataset if intended.\n"
]
}
],
"source": [
"from milliontrees.datasets.TreeBoxes import TreeBoxesDataset\n",
"dataset = TreeBoxesDataset(download=False, root_dir=\"/orange/ewhite/DeepForest/MillionTrees/\", geometry_name=\"boxes\") \n",
"train_dataset = dataset.get_subset(\"train\")\n",
"train_loader = get_train_loader(\"standard\", train_dataset, batch_size=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading config file: /blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/deepforest/data/deepforest_config.yml\n"
"Reading config file: deepforest_config.yml\n"
]
},
{
Expand All @@ -170,7 +199,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Reading config file: /blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/deepforest/data/deepforest_config.yml\n"
"Reading config file: deepforest_config.yml\n"
]
},
{
Expand Down Expand Up @@ -207,34 +236,6 @@
"text": [
"Epoch 0: 0%| | 0/1 [00:00<?, ?it/s] "
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 1/1 [00:33<00:00, 0.03it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_steps=1` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 1/1 [00:33<00:00, 0.03it/s]\n"
]
}
],
"source": [
Expand Down
123 changes: 122 additions & 1 deletion docs/examples/baseline_boxes.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion milliontrees/datasets/TreeBoxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def __init__(self,
version=None,
root_dir='data',
download=False,
split_scheme='official'):
split_scheme='official',
geometry_name='y'):
self._version = version
self._split_scheme = split_scheme
self.geometry_name = geometry_name
if self._split_scheme not in ['official', 'random']:
raise ValueError(
f'Split scheme {self._split_scheme} not recognized')
Expand Down
5 changes: 4 additions & 1 deletion milliontrees/datasets/TreePoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def __init__(self,
version=None,
root_dir='data',
download=False,
split_scheme='official'):
split_scheme='official',
geometry_name='y'):
self._version = version
self._split_scheme = split_scheme
self.geometry_name = geometry_name

if self._split_scheme not in ['official', 'random']:
raise ValueError(
f'Split scheme {self._split_scheme} not recognized')
Expand Down
8 changes: 6 additions & 2 deletions milliontrees/datasets/TreePolygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from milliontrees.datasets.milliontrees_dataset import MillionTreesDataset
from milliontrees.common.grouper import CombinatorialGrouper
from milliontrees.common.metrics.all_metrics import Accuracy, Recall, F1
from albumentations import A, ToTensorV2
from torchvision.tv_tensors import BoundingBoxes, Mask
import torchvision.transforms as transforms
from torchvision.ops import masks_to_boxes
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch

class TreePolygonsDataset(MillionTreesDataset):
Expand Down Expand Up @@ -50,10 +51,13 @@ def __init__(self,
version=None,
root_dir='data',
download=False,
split_scheme='official'):
split_scheme='official',
geometry_name='y'):

self._version = version
self._split_scheme = split_scheme
self.geometry_name = geometry_name

if self._split_scheme != 'official':
raise ValueError(
f'Split scheme {self._split_scheme} not recognized')
Expand Down
16 changes: 7 additions & 9 deletions milliontrees/datasets/milliontrees_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import torch
import numpy as np

import albumentations as A
from albumentations.pytorch import ToTensorV2

class MillionTreesDataset:
"""Shared dataset class for all MillionTrees datasets.
Expand Down Expand Up @@ -39,7 +36,7 @@ def __getitem__(self, idx):
y_indices = self._input_lookup[self._input_array[idx]]
y = self.y_array[y_indices]
metadata = self.metadata_array[idx]
targets = {"y": y, "labels": np.zeros(len(y), dtype=int)}
targets = {self.geometry_name: y, "labels": np.zeros(len(y), dtype=int)}

return metadata, x, targets

Expand Down Expand Up @@ -88,7 +85,7 @@ def get_subset(self, split, frac=1.0, transform=None):
split_idx = np.sort(
np.random.permutation(split_idx)[:num_to_retain])

return MillionTreesSubset(self, split_idx, transform)
return MillionTreesSubset(self, split_idx, transform, self.geometry_name)

def check_init(self):
"""Convenience function to check that the WILDSDataset is properly
Expand Down Expand Up @@ -463,13 +460,14 @@ def standard_group_eval(metric,

class MillionTreesSubset(MillionTreesDataset):

def __init__(self, dataset, indices, transform=None):
def __init__(self, dataset, indices, transform=None, geometry_name="y"):
"""This acts like `torch.utils.data.Subset`, but on `milliontreesDatasets`. We
pass in `transform` (which is used for data augmentation) explicitly
because it can potentially vary on the training vs. test subsets.
"""
self.dataset = dataset
self.indices = indices
self.geometry_name = geometry_name
inherited_attrs = [
'_dataset_name', '_data_dir', '_collate', '_split_scheme',
'_split_dict', '_split_names', '_y_size', '_n_classes',
Expand All @@ -488,14 +486,14 @@ def __getitem__(self, idx):
if self._dataset_name == 'TreeBoxes':
augmented = self.transform(
image=x,
bboxes=targets["y"],
bboxes=targets[self.geometry_name],
labels=targets["labels"]
)
y = torch.from_numpy(augmented["bboxes"]).float()
elif self._dataset_name == 'TreePoints':
augmented = self.transform(
image=x,
keypoints=targets["y"],
keypoints=targets[self.geometry_name],
labels=targets["labels"]
)
y = torch.from_numpy(augmented["keypoints"]).float()
Expand All @@ -510,7 +508,7 @@ def __getitem__(self, idx):
elif self._dataset_name == 'TreePoints':
y = torch.zeros(0, 2)

targets = {"y": y, "labels": labels}
targets = {self.geometry_name: y, "labels": labels}

return metadata, x, targets

Expand Down
9 changes: 9 additions & 0 deletions tests/test_TreeBoxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def test_TreeBoxes_generic(dataset):
assert metadata.shape == (2,)
break

# confirm that we can change target name is needed
def test_get_dataset_with_geometry_name(dataset):
dataset = TreeBoxesDataset(download=False, root_dir=dataset, geometry_name="boxes")
train_dataset = dataset.get_subset("train")

for metadata, image, targets in train_dataset:
boxes, labels = targets["boxes"], targets["labels"]
break

@pytest.mark.parametrize("batch_size", [1, 2])
def test_get_train_dataloader(dataset, batch_size):
dataset = TreeBoxesDataset(download=False, root_dir=dataset)
Expand Down

0 comments on commit 07d1d78

Please sign in to comment.