Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add load model and fetch model #762

Merged
merged 22 commits into from
Sep 26, 2024
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 51 additions & 28 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from deepforest import dataset, visualize, get_data, utilities, predict
from deepforest import evaluate as evaluate_iou

from huggingface_hub import PyTorchModelHubMixin

class deepforest(pl.LightningModule):

class deepforest(pl.LightningModule, PyTorchModelHubMixin):
"""Class for training and predicting tree crowns in RGB images."""

def __init__(self,
Expand Down Expand Up @@ -117,6 +119,47 @@ def __init__(self,

self.save_hyperparameters()

def load_model(self, model_name="deepforest-tree", version='main'):
"""Load DeepForest models from Hugging Face using from_pretrained().
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want the docstring to focus on what the function does, not how it does it, and be understandable to a novice user. Something like:

"""Load pretrained DeepForest model

Loads a model that has already been pretrained for a specific task,
like tree crown detection.

Models (technically model weights) are distributed via Hugging Face
and designated the Hugging Face repository ID (model_name), which
is in the form: 'organization/repository'. For a list of models distributed
by the DeepForest team (and the associated model names) see the
documentation:
https://deepforest.readthedocs.io/en/latest/installation_and_setup/prebuilt.html
...

"""


Args:
model_name (str): The name of the model to load ('deepforest-tree', 'bird', 'livestock', 'nest', 'deadtrees').
version (str): The model version ('main', 'v1.0.0', etc.).

Returns:
model (object): A trained PyTorch model.
"""
# Map model names to Hugging Face Hub repository IDs
model_repo_dict = {
"deepforest-tree": "weecology/deepforest-tree",
"bird": "weecology/deepforest-bird",
"livestock": "weecology/deepforest-livestock",
"nest": "weecology/everglades-nest-detection",
"deadtrees": "weecology/cropmodel-deadtrees"
}

# Validate model name
if model_name not in model_repo_dict:
raise ValueError(
"Invalid model_name specified. Choose from 'deepforest-tree', 'bird', 'livestock', 'nest', or 'deadtrees'."
)

# Retrieve the repository ID for the model
repo_id = model_repo_dict[model_name]

# Load the model using from_pretrained
model = deepforest.from_pretrained(repo_id)

# Set bird-specific settings if loading the bird model
if model_name == "bird":
model.config["score_thresh"] = 0.3
model.label_dict = {"Bird": 0}
model.numeric_to_label_dict = {v: k for k, v in model.label_dict.items()}

print(f"Loading model: {model_name} from version: {version}")

return model

def use_release(self, check_release=True):
"""Use the latest DeepForest model release from github and load model.
Optionally download if release doesn't exist.
Expand All @@ -126,20 +169,10 @@ def use_release(self, check_release=True):
Returns:
model (object): A trained PyTorch model
"""
# Download latest model from github release
release_tag, self.release_state_dict = utilities.use_release(
check_release=check_release)
if self.config["architecture"] != "retinanet":
warnings.warn(
"The config file specifies architecture {}, but the release model is torchvision retinanet. Reloading main.deepforest with a retinanet model"
.format(self.config["architecture"]))
self.config["architecture"] = "retinanet"
self.create_model()
self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True))

# load saved model and tag release
self.__release_version__ = release_tag
print("Loading pre-built model: {}".format(release_tag))
warnings.warn("use_release will be deprecated in 2.0. use load_model() instead",
DeprecationWarning)
self.load_model()

def use_bird_release(self, check_release=True):
"""Use the latest DeepForest bird model release from github and load
Expand All @@ -150,21 +183,11 @@ def use_bird_release(self, check_release=True):
Returns:
model (object): A trained pytorch model
"""
# Download latest model from github release
release_tag, self.release_state_dict = utilities.use_bird_release(
check_release=check_release)
self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True))

# load saved model and tag release
self.__release_version__ = release_tag
print("Loading pre-built model: {}".format(release_tag))

print("Setting default score threshold to 0.3")
self.config["score_thresh"] = 0.3

# Set label dictionary to Bird
self.label_dict = {"Bird": 0}
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}
warnings.warn(
"use_bird_release will be deprecated in 2.0. use load_model('bird') instead",
DeprecationWarning)
self.load_model('bird')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be 'weecology/deepforest-bird'.


def create_model(self):
"""Define a deepforest architecture. This can be done in two ways.
Expand Down
Loading