-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add load model and fetch model #762
Merged
Merged
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
5b5a224
add load model and fetch model
Mu-Magdy ba96843
test
Mu-Magdy b42cb87
test
Mu-Magdy 9705ed5
test
Mu-Magdy c0227a5
test
Mu-Magdy 87db1fc
add from_pretrained and load_model methods
Mu-Magdy 9876b9f
add load_model
Mu-Magdy d33b3ff
Remove use_release and use_bird_release from utilities and add unit t…
Mu-Magdy 6b7a617
Add safetensors to requirements
Mu-Magdy d846f77
Merge branch 'main' into hf-hub-download
Mu-Magdy 1531aa5
modify setup'
Mu-Magdy bc2c7d3
modify setup'
Mu-Magdy b686db7
solve test failures
Mu-Magdy 0ffc838
test using ethanwhite/df-test
Mu-Magdy df5f8c8
test
Mu-Magdy 78ede81
replace utilities.use_release() with main.deepforest().use_release()
Mu-Magdy f8041ee
add the right repo_id
Mu-Magdy 84b59c8
remove test for checking NEON.pt path
Mu-Magdy e609754
remove test for checking NEON.pt path
Mu-Magdy 0cc67b2
add docs to the new function
Mu-Magdy 9bd9998
add docs to the new function
Mu-Magdy 143ab82
Merge branch 'main' into hf-hub-download
henrykironde File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,8 +18,10 @@ | |
from deepforest import dataset, visualize, get_data, utilities, predict | ||
from deepforest import evaluate as evaluate_iou | ||
|
||
from huggingface_hub import PyTorchModelHubMixin | ||
|
||
class deepforest(pl.LightningModule): | ||
|
||
class deepforest(pl.LightningModule, PyTorchModelHubMixin): | ||
"""Class for training and predicting tree crowns in RGB images.""" | ||
|
||
def __init__(self, | ||
|
@@ -117,6 +119,47 @@ def __init__(self, | |
|
||
self.save_hyperparameters() | ||
|
||
def load_model(self, model_name="deepforest-tree", version='main'): | ||
"""Load DeepForest models from Hugging Face using from_pretrained(). | ||
|
||
Args: | ||
model_name (str): The name of the model to load ('deepforest-tree', 'bird', 'livestock', 'nest', 'deadtrees'). | ||
version (str): The model version ('main', 'v1.0.0', etc.). | ||
|
||
Returns: | ||
model (object): A trained PyTorch model. | ||
""" | ||
# Map model names to Hugging Face Hub repository IDs | ||
model_repo_dict = { | ||
"deepforest-tree": "weecology/deepforest-tree", | ||
"bird": "weecology/deepforest-bird", | ||
"livestock": "weecology/deepforest-livestock", | ||
"nest": "weecology/everglades-nest-detection", | ||
"deadtrees": "weecology/cropmodel-deadtrees" | ||
} | ||
|
||
# Validate model name | ||
if model_name not in model_repo_dict: | ||
raise ValueError( | ||
"Invalid model_name specified. Choose from 'deepforest-tree', 'bird', 'livestock', 'nest', or 'deadtrees'." | ||
) | ||
|
||
# Retrieve the repository ID for the model | ||
repo_id = model_repo_dict[model_name] | ||
|
||
# Load the model using from_pretrained | ||
model = deepforest.from_pretrained(repo_id) | ||
|
||
# Set bird-specific settings if loading the bird model | ||
if model_name == "bird": | ||
model.config["score_thresh"] = 0.3 | ||
model.label_dict = {"Bird": 0} | ||
model.numeric_to_label_dict = {v: k for k, v in model.label_dict.items()} | ||
|
||
print(f"Loading model: {model_name} from version: {version}") | ||
|
||
return model | ||
|
||
def use_release(self, check_release=True): | ||
"""Use the latest DeepForest model release from github and load model. | ||
Optionally download if release doesn't exist. | ||
|
@@ -126,20 +169,10 @@ def use_release(self, check_release=True): | |
Returns: | ||
model (object): A trained PyTorch model | ||
""" | ||
# Download latest model from github release | ||
release_tag, self.release_state_dict = utilities.use_release( | ||
check_release=check_release) | ||
if self.config["architecture"] != "retinanet": | ||
warnings.warn( | ||
"The config file specifies architecture {}, but the release model is torchvision retinanet. Reloading main.deepforest with a retinanet model" | ||
.format(self.config["architecture"])) | ||
self.config["architecture"] = "retinanet" | ||
self.create_model() | ||
self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True)) | ||
|
||
# load saved model and tag release | ||
self.__release_version__ = release_tag | ||
print("Loading pre-built model: {}".format(release_tag)) | ||
warnings.warn("use_release will be deprecated in 2.0. use load_model() instead", | ||
DeprecationWarning) | ||
self.load_model() | ||
|
||
def use_bird_release(self, check_release=True): | ||
"""Use the latest DeepForest bird model release from github and load | ||
|
@@ -150,21 +183,11 @@ def use_bird_release(self, check_release=True): | |
Returns: | ||
model (object): A trained pytorch model | ||
""" | ||
# Download latest model from github release | ||
release_tag, self.release_state_dict = utilities.use_bird_release( | ||
check_release=check_release) | ||
self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True)) | ||
|
||
# load saved model and tag release | ||
self.__release_version__ = release_tag | ||
print("Loading pre-built model: {}".format(release_tag)) | ||
|
||
print("Setting default score threshold to 0.3") | ||
self.config["score_thresh"] = 0.3 | ||
|
||
# Set label dictionary to Bird | ||
self.label_dict = {"Bird": 0} | ||
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} | ||
warnings.warn( | ||
"use_bird_release will be deprecated in 2.0. use load_model('bird') instead", | ||
DeprecationWarning) | ||
self.load_model('bird') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be |
||
|
||
def create_model(self): | ||
"""Define a deepforest architecture. This can be done in two ways. | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We want the docstring to focus on what the function does, not how it does it, and be understandable to a novice user. Something like: