diff --git a/deepforest/main.py b/deepforest/main.py index b0d0847f..f4a38ba6 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -21,8 +21,7 @@ from huggingface_hub import PyTorchModelHubMixin -class deepforest( - pl.LightningModule,): # PyTorchModelHubMixin): +class deepforest(pl.LightningModule, PyTorchModelHubMixin): """Class for training and predicting tree crowns in RGB images.""" def __init__(self, @@ -120,42 +119,30 @@ def __init__(self, self.save_hyperparameters() - def from_pretrained(self, - repo_id: str, - filename: str = "model.safetensors", - config_file: str = "config.json"): - """Load a pre-trained deepforest model from the Hugging Face Hub. + def load_model(self, model_name="weecology/deepforest-tree", revision='main'): + """Load DeepForest models from Hugging Face using from_pretrained(). Args: - repo_id (str): The Hugging Face repository where the model is stored. - filename (str): The model file to download. - config_file (str): The config file to download. Defaults to "config.json". - """ - # Download model weights - model_file = hf_hub_download(repo_id=repo_id, filename=filename) - - # Download and load the config file - downloaded_config_file = hf_hub_download(repo_id=repo_id, filename=config_file) - self.config = utilities.read_config(downloaded_config_file) - - # Initialize the model class with the downloaded config - self.create_model() # Initialize the model architecture based on the config - - # Load model weights from the checkpoint - self.model.load_state_dict( - torch.load(model_file)) # Use the class method directly + model_name (str): A repository ID for huggingface in the form of organization/repository + version (str): The model version ('main', 'v1.0.0', etc.). - def load_model(self, repo_id: str, **kwargs): - """Wrapper method to load both the model and config file. - - Args: - repo_id (str): The Hugging Face repository ID where the model is stored. - kwargs: Additional arguments to pass to from_pretrained method. + Returns: + self (object):A trained PyTorch model with its config and weights. """ - # Call from_pretrained to load the model and config - self.from_pretrained(repo_id, **kwargs) - - print(f"Model and configuration loaded from {repo_id}") + # Load the model using from_pretrained + loaded_model = self.from_pretrained(model_name, revision=revision) + self.config = loaded_model.config + self.label_dict = loaded_model.label_dict + self.model = loaded_model.model + self.numeric_to_label_dict = loaded_model.numeric_to_label_dict + + # Set bird-specific settings if loading the bird model + if model_name == "weecology/deepforest-bird": + self.config['retinanet']["score_thresh"] = 0.3 + self.label_dict = {"Bird": 0} + self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + + return self def use_release(self, check_release=True): """Use the latest DeepForest model release from github and load model.