diff --git a/docs/examples/Datasets.ipynb b/docs/examples/Datasets.ipynb index 12e9318..f9619d2 100644 --- a/docs/examples/Datasets.ipynb +++ b/docs/examples/Datasets.ipynb @@ -63,7 +63,7 @@ "text": [ "/blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/albumentations/__init__.py:13: UserWarning: A new version of Albumentations is available: 1.4.20 (you have 1.4.15). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", + "/blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/albumentations/__init__.py:13: UserWarning: A new version of Albumentations is available: 1.4.21 (you have 1.4.15). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n", " check_for_updates()\n" ] } @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -93,8 +93,8 @@ "output_type": "stream", "text": [ "Metadata length: 2\n", - "Image shape: (81, 2), Image type: \n", - "Label shape: torch.Size([3, 448, 448]), Label type: \n" + "Image shape: torch.Size([3, 448, 448]), Image type: \n", + "Targets keys: dict_keys(['y', 'labels']), Label type: \n" ] } ], @@ -104,10 +104,10 @@ "train_dataset = dataset.get_subset(\"train\")\n", "\n", "# View the first image in the dataset\n", - "image, label, metadata = train_dataset[0]\n", + "metadata, image, targets = train_dataset[0]\n", "print(f\"Metadata length: {len(metadata)}\")\n", "print(f\"Image shape: {image.shape}, Image type: {type(image)}\")\n", - "print(f\"Label shape: {label.shape}, Label type: {type(label)}\")" + "print(f\"Targets keys: {targets.keys()}, Label type: {type(targets)}\")" ] }, { @@ -119,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -127,9 +127,9 @@ "output_type": "stream", "text": [ "Warning: You are loading the entire dataset. Consider using dataset.get_subset('train') for a portion of the dataset if intended.\n", - "Targets is a list of dictionaries with the following keys: dict_keys(['boxes', 'labels'])\n", + "Targets is a list of dictionaries with the following keys: dict_keys(['y', 'labels'])\n", "Image shape: torch.Size([2, 3, 448, 448]), Image type: \n", - "Annotation shape of the first image: torch.Size([4, 4])\n" + "Annotation shape of the first image: torch.Size([17, 4])\n" ] } ], @@ -140,20 +140,49 @@ "for metadata, image, targets in train_loader:\n", " print(\"Targets is a list of dictionaries with the following keys: \", targets[0].keys())\n", " print(f\"Image shape: {image.shape}, Image type: {type(image)}\")\n", - " print(f\"Annotation shape of the first image: {targets[0]['boxes'].shape}\")\n", + " print(f\"Annotation shape of the first image: {targets[0]['y'].shape}\")\n", " break # Just show the first batch" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reformat\n", + "\n", + "DeepForest expects the target tensor have the key 'boxes', not 'y', we can ask MillionTrees for help by passing the geometry name to the dataset class. We do not need to redownload the dataset, we can all TreeBoxesDataset directly." + ] + }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: You are loading the entire dataset. Consider using dataset.get_subset('train') for a portion of the dataset if intended.\n" + ] + } + ], + "source": [ + "from milliontrees.datasets.TreeBoxes import TreeBoxesDataset\n", + "dataset = TreeBoxesDataset(download=False, root_dir=\"/orange/ewhite/DeepForest/MillionTrees/\", geometry_name=\"boxes\") \n", + "train_dataset = dataset.get_subset(\"train\")\n", + "train_loader = get_train_loader(\"standard\", train_dataset, batch_size=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reading config file: /blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/deepforest/data/deepforest_config.yml\n" + "Reading config file: deepforest_config.yml\n" ] }, { @@ -170,7 +199,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Reading config file: /blue/ewhite/b.weinstein/miniconda3/envs/MillionTrees/lib/python3.10/site-packages/deepforest/data/deepforest_config.yml\n" + "Reading config file: deepforest_config.yml\n" ] }, { @@ -207,34 +236,6 @@ "text": [ "Epoch 0: 0%| | 0/1 [00:00