Skip to content

Commit

Permalink
add load_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Mu-Magdy committed Sep 13, 2024
1 parent 87db1fc commit 9876b9f
Showing 1 changed file with 21 additions and 34 deletions.
55 changes: 21 additions & 34 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from huggingface_hub import PyTorchModelHubMixin


class deepforest(
pl.LightningModule,): # PyTorchModelHubMixin):
class deepforest(pl.LightningModule, PyTorchModelHubMixin):
"""Class for training and predicting tree crowns in RGB images."""

def __init__(self,
Expand Down Expand Up @@ -120,42 +119,30 @@ def __init__(self,

self.save_hyperparameters()

def from_pretrained(self,
repo_id: str,
filename: str = "model.safetensors",
config_file: str = "config.json"):
"""Load a pre-trained deepforest model from the Hugging Face Hub.
def load_model(self, model_name="weecology/deepforest-tree", revision='main'):
"""Load DeepForest models from Hugging Face using from_pretrained().
Args:
repo_id (str): The Hugging Face repository where the model is stored.
filename (str): The model file to download.
config_file (str): The config file to download. Defaults to "config.json".
"""
# Download model weights
model_file = hf_hub_download(repo_id=repo_id, filename=filename)

# Download and load the config file
downloaded_config_file = hf_hub_download(repo_id=repo_id, filename=config_file)
self.config = utilities.read_config(downloaded_config_file)

# Initialize the model class with the downloaded config
self.create_model() # Initialize the model architecture based on the config

# Load model weights from the checkpoint
self.model.load_state_dict(
torch.load(model_file)) # Use the class method directly
model_name (str): A repository ID for huggingface in the form of organization/repository
version (str): The model version ('main', 'v1.0.0', etc.).
def load_model(self, repo_id: str, **kwargs):
"""Wrapper method to load both the model and config file.
Args:
repo_id (str): The Hugging Face repository ID where the model is stored.
kwargs: Additional arguments to pass to from_pretrained method.
Returns:
self (object):A trained PyTorch model with its config and weights.
"""
# Call from_pretrained to load the model and config
self.from_pretrained(repo_id, **kwargs)

print(f"Model and configuration loaded from {repo_id}")
# Load the model using from_pretrained
loaded_model = self.from_pretrained(model_name, revision=revision)
self.config = loaded_model.config
self.label_dict = loaded_model.label_dict
self.model = loaded_model.model
self.numeric_to_label_dict = loaded_model.numeric_to_label_dict

# Set bird-specific settings if loading the bird model
if model_name == "weecology/deepforest-bird":
self.config['retinanet']["score_thresh"] = 0.3
self.label_dict = {"Bird": 0}
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}

return self

def use_release(self, check_release=True):
"""Use the latest DeepForest model release from github and load model.
Expand Down

0 comments on commit 9876b9f

Please sign in to comment.