
Commit

Remove warning from docs
henrykironde committed Oct 17, 2023
1 parent a9391e1 commit 8128d38
Showing 2 changed files with 30 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@
.DS_Store
.RHistory
build/
CONTRIBUTING.md
current_bird_release.csv
current_release.csv
lightning_logs/*
60 changes: 29 additions & 31 deletions deepforest/main.py
@@ -1,28 +1,27 @@
# entry point for deepforest model
import importlib
import os
import pandas as pd
from PIL import Image
import torch
import typing
import warnings

import cv2
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import rasterio as rio
import torch
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor
from torch import optim
import numpy as np
from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision

from deepforest import dataset, visualize, get_data, utilities, model, predict
from deepforest import dataset, visualize, get_data, utilities, predict
from deepforest import evaluate as evaluate_iou
from deepforest.callbacks import iou_callback
from pytorch_lightning.callbacks import LearningRateMonitor
import rasterio as rio
import cv2
import warnings
import importlib


class deepforest(pl.LightningModule):
"""Class for training and predicting tree crowns in RGB images
"""
"""Class for training and predicting tree crowns in RGB images"""

def __init__(self,
num_classes: int = 1,
@@ -31,13 +30,13 @@ def __init__(self,
config_file: str = 'deepforest_config.yml',
config_args=None,
model=None):
"""
Args:
"""Args:
num_classes (int): number of classes in the model
config_file (str): path to deepforest config file
model (model.Model()): a deepforest model object, see model.Model().
config_args (dict): a dictionary of key->value to update config file at run time. e.g. {"batch_size":10}
- This is useful for iterating over arguments during model testing.
config_args (dict): a dictionary of key->value to update
config file at run time. e.g. {"batch_size":10}
This is useful for iterating over arguments during model testing.
Returns:
self: a deepforest pytorch lightning module
"""
@@ -76,7 +75,7 @@ def __init__(self,
class_metrics=True, iou_threshold=self.config["validation"]["iou_threshold"])
self.mAP_metric = MeanAveragePrecision()

#Create a default trainer.
# Create a default trainer.
self.create_trainer()

# Label encoder and decoder
@@ -146,24 +145,23 @@ def use_bird_release(self, check_release=True):
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}

def create_model(self):
"""Define a deepforest architecture. This can be done in two ways.
"""Define a deepforest architecture. This can be done in two ways.
Passed as the model argument to deepforest __init__(),
or as a named architecture in config["architecture"],
or as a named architecture in config["architecture"],
which corresponds to a file in models/, and is a subclass of model.Model().
The config args in the .yaml are specified
retinanet:
nms_thresh: 0.1
score_thresh: 0.2
RCNN:
nms_thresh: 0.1
etc.
The config args in the .yaml are specified
>>> # retinanet:
>>> # nms_thresh: 0.1
>>> # score_thresh: 0.2
>>> # RCNN:
>>> # nms_thresh: 0.1
>>> # etc.
"""
if self.model is None:
model_name = importlib.import_module("deepforest.models.{}".format(
self.config["architecture"]))
self.model = model_name.Model(config=self.config).create_model()
else:
pass
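
A minimal sketch of the named-architecture path described in the create_model docstring, under the assumption that "architecture" is a top-level key in deepforest_config.yml; create_model() then resolves the name with importlib exactly as in the body above:

    from deepforest import main

    # Select the architecture by name; create_model() imports
    # deepforest.models.retinanet and calls Model(config=...).create_model().
    m = main.deepforest(config_args={"architecture": "retinanet"})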

def create_trainer(self, logger=None, callbacks=[], **kwargs):
"""Create a pytorch lightning training by reading config files
@@ -399,7 +397,7 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
out = self.predict_step(batch, i)
batched_results.append(out)

#Flatten list from batched prediction
# Flatten list from batched prediction
prediction_list = []
for batch in batched_results:
for boxes in batch:
@@ -491,7 +489,7 @@ def predict_tile(self,
patch_size=patch_size)
batched_results = self.trainer.predict(self, self.predict_dataloader(ds))

#Flatten list from batched prediction
# Flatten list from batched prediction
results = []
for batch in batched_results:
for boxes in batch:
@@ -556,7 +554,7 @@ def validation_step(self, batch, batch_idx):
print("Empty batch encountered, skipping")
return None

#Get loss from "train" mode, but don't allow optimization
# Get loss from "train" mode, but don't allow optimization
self.model.train()
with torch.no_grad():
loss_dict = self.model.forward(images, targets)
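
For context, a self-contained sketch (not DeepForest code) of the pattern above: torchvision detection models only return a loss dict in train() mode, so the validation step switches to train() but wraps the forward pass in no_grad() to keep optimization off. The model choice and target shapes here are illustrative assumptions:

    import torch
    import torchvision

    # Build an untrained detector so nothing is downloaded.
    detector = torchvision.models.detection.retinanet_resnet50_fpn(
        weights=None, weights_backbone=None, num_classes=2)
    images = [torch.rand(3, 256, 256)]
    targets = [{"boxes": torch.tensor([[10., 10., 50., 50.]]),
                "labels": torch.tensor([1])}]

    # Losses are computed in train() mode; no_grad() prevents gradient tracking.
    detector.train()
    with torch.no_grad():
        loss_dict = detector(images, targets)  # e.g. classification / bbox_regression losses
    total_loss = sum(loss_dict.values())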
