From 5b5a2249ab6661275b16785c44378dd200bb968a Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Wed, 28 Aug 2024 22:11:02 +0300 Subject: [PATCH 01/20] add load model and fetch model --- deepforest/main.py | 88 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index 68a714fd..61253164 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -117,6 +117,60 @@ def __init__(self, self.save_hyperparameters() + def load_model(self, model_name="deepforest-tree", version='main'): + """Load DeepForest models from Hugging Face. + + Args: + model_name (str): The name of the model to load ('deepforest' or 'bird'). + version (str): The model version ('main', 'v1.0.0'). + + Returns: + model (object): A trained PyTorch model. + """ + if model_name == "deepforest-tree": + # Use DeepForest model release + release_tag, self.release_state_dict = utilities.fetch_model( + repo_id="weecology/deepforest-tree", model_filename="NEON.pt") + self.create_model() + + elif model_name == "bird": + # Use DeepForest bird model release + release_tag, self.release_state_dict = utilities.fetch_model( + repo_id="weecology/deepforest-bird", model_filename="bird.pt") + + # Set bird-specific settings + self.config["score_thresh"] = 0.3 + self.label_dict = {"Bird": 0} + self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + + elif model_name == "livestock": + # Use DeepForest bird model release + release_tag, self.release_state_dict = utilities.fetch_model( + repo_id="weecology/deepforest-livestock", model_filename="livestock.pt") + + elif model_name == "nest": + # Use DeepForest bird model release + release_tag, self.release_state_dict = utilities.fetch_model( + repo_id="weecology/everglades-nest-detection", model_filename="nest.pt") + + elif model_name == "deadtrees": + # Use DeepForest bird model release + release_tag, self.release_state_dict = utilities.fetch_model( + repo_id="weecology/cropmodel-deadtrees", + model_filename="cropmodel-deadtrees.pl") + + else: + raise ValueError( + "Invalid model_name specified. Choose from 'tree', 'bird', 'livestock', 'nest', or 'deadtrees'." + ) + + # Load the model state dict + self.model.load_state_dict(torch.load(self.release_state_dict)) + + print(f"Loading model: {release_tag}") + + return self.model + def use_release(self, check_release=True): """Use the latest DeepForest model release from github and load model. Optionally download if release doesn't exist. @@ -126,20 +180,10 @@ def use_release(self, check_release=True): Returns: model (object): A trained PyTorch model """ - # Download latest model from github release - release_tag, self.release_state_dict = utilities.use_release( - check_release=check_release) - if self.config["architecture"] != "retinanet": - warnings.warn( - "The config file specifies architecture {}, but the release model is torchvision retinanet. Reloading main.deepforest with a retinanet model" - .format(self.config["architecture"])) - self.config["architecture"] = "retinanet" - self.create_model() - self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True)) - # load saved model and tag release - self.__release_version__ = release_tag - print("Loading pre-built model: {}".format(release_tag)) + warnings.warn("use_release will be deprecated in 2.0. use load_model() instead", + DeprecationWarning) + self.load_model() def use_bird_release(self, check_release=True): """Use the latest DeepForest bird model release from github and load @@ -150,21 +194,11 @@ def use_bird_release(self, check_release=True): Returns: model (object): A trained pytorch model """ - # Download latest model from github release - release_tag, self.release_state_dict = utilities.use_bird_release( - check_release=check_release) - self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True)) - - # load saved model and tag release - self.__release_version__ = release_tag - print("Loading pre-built model: {}".format(release_tag)) - - print("Setting default score threshold to 0.3") - self.config["score_thresh"] = 0.3 - # Set label dictionary to Bird - self.label_dict = {"Bird": 0} - self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + warnings.warn( + "use_bird_release will be deprecated in 2.0. use load_model('bird') instead", + DeprecationWarning) + self.load_model('bird') def create_model(self): """Define a deepforest architecture. This can be done in two ways. From ba968436a8d9a5654eed8c7120410dd037ed346d Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 30 Aug 2024 04:49:29 +0300 Subject: [PATCH 02/20] test --- deepforest/main.py | 71 ++++++++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index 61253164..bd1df9fe 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -18,8 +18,10 @@ from deepforest import dataset, visualize, get_data, utilities, predict from deepforest import evaluate as evaluate_iou +from huggingface_hub import PyTorchModelHubMixin -class deepforest(pl.LightningModule): + +class deepforest(pl.LightningModule, PyTorchModelHubMixin): """Class for training and predicting tree crowns in RGB images.""" def __init__(self, @@ -118,58 +120,45 @@ def __init__(self, self.save_hyperparameters() def load_model(self, model_name="deepforest-tree", version='main'): - """Load DeepForest models from Hugging Face. + """Load DeepForest models from Hugging Face using from_pretrained(). Args: - model_name (str): The name of the model to load ('deepforest' or 'bird'). - version (str): The model version ('main', 'v1.0.0'). + model_name (str): The name of the model to load ('deepforest-tree', 'bird', 'livestock', 'nest', 'deadtrees'). + version (str): The model version ('main', 'v1.0.0', etc.). Returns: model (object): A trained PyTorch model. """ - if model_name == "deepforest-tree": - # Use DeepForest model release - release_tag, self.release_state_dict = utilities.fetch_model( - repo_id="weecology/deepforest-tree", model_filename="NEON.pt") - self.create_model() - - elif model_name == "bird": - # Use DeepForest bird model release - release_tag, self.release_state_dict = utilities.fetch_model( - repo_id="weecology/deepforest-bird", model_filename="bird.pt") - - # Set bird-specific settings - self.config["score_thresh"] = 0.3 - self.label_dict = {"Bird": 0} - self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} - - elif model_name == "livestock": - # Use DeepForest bird model release - release_tag, self.release_state_dict = utilities.fetch_model( - repo_id="weecology/deepforest-livestock", model_filename="livestock.pt") - - elif model_name == "nest": - # Use DeepForest bird model release - release_tag, self.release_state_dict = utilities.fetch_model( - repo_id="weecology/everglades-nest-detection", model_filename="nest.pt") - - elif model_name == "deadtrees": - # Use DeepForest bird model release - release_tag, self.release_state_dict = utilities.fetch_model( - repo_id="weecology/cropmodel-deadtrees", - model_filename="cropmodel-deadtrees.pl") + # Map model names to Hugging Face Hub repository IDs + model_repo_dict = { + "deepforest-tree": "weecology/deepforest-tree", + "bird": "weecology/deepforest-bird", + "livestock": "weecology/deepforest-livestock", + "nest": "weecology/everglades-nest-detection", + "deadtrees": "weecology/cropmodel-deadtrees" + } - else: + # Validate model name + if model_name not in model_repo_dict: raise ValueError( - "Invalid model_name specified. Choose from 'tree', 'bird', 'livestock', 'nest', or 'deadtrees'." + "Invalid model_name specified. Choose from 'deepforest-tree', 'bird', 'livestock', 'nest', or 'deadtrees'." ) - # Load the model state dict - self.model.load_state_dict(torch.load(self.release_state_dict)) + # Retrieve the repository ID for the model + repo_id = model_repo_dict[model_name] + + # Load the model using from_pretrained + model = deepforest.from_pretrained('weecology/deepforest-bird', 'bird.pt') + + # Set bird-specific settings if loading the bird model + if model_name == "bird": + model.config["score_thresh"] = 0.3 + model.label_dict = {"Bird": 0} + model.numeric_to_label_dict = {v: k for k, v in model.label_dict.items()} - print(f"Loading model: {release_tag}") + print(f"Loading model: {model_name} from version: {version}") - return self.model + return model def use_release(self, check_release=True): """Use the latest DeepForest model release from github and load model. From b42cb87f82368d6aa4b583bb3ba9d826e3481b92 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 30 Aug 2024 05:09:55 +0300 Subject: [PATCH 03/20] test --- deepforest/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepforest/main.py b/deepforest/main.py index bd1df9fe..1b26fe2f 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -148,7 +148,7 @@ def load_model(self, model_name="deepforest-tree", version='main'): repo_id = model_repo_dict[model_name] # Load the model using from_pretrained - model = deepforest.from_pretrained('weecology/deepforest-bird', 'bird.pt') + model = deepforest.from_pretrained('weecology/deepforest-bird') # Set bird-specific settings if loading the bird model if model_name == "bird": From 9705ed5ffb85a0f9e20fa575292958b2875ca5e1 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 30 Aug 2024 05:13:21 +0300 Subject: [PATCH 04/20] test --- deepforest/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepforest/main.py b/deepforest/main.py index 1b26fe2f..10047897 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -148,7 +148,7 @@ def load_model(self, model_name="deepforest-tree", version='main'): repo_id = model_repo_dict[model_name] # Load the model using from_pretrained - model = deepforest.from_pretrained('weecology/deepforest-bird') + model = deepforest.from_pretrained('MuMagdy/test_model') # Set bird-specific settings if loading the bird model if model_name == "bird": From c0227a545b7ffd07be57885f0e22e0391add5dd4 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 30 Aug 2024 05:50:06 +0300 Subject: [PATCH 05/20] test --- deepforest/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepforest/main.py b/deepforest/main.py index 10047897..766a672e 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -148,7 +148,7 @@ def load_model(self, model_name="deepforest-tree", version='main'): repo_id = model_repo_dict[model_name] # Load the model using from_pretrained - model = deepforest.from_pretrained('MuMagdy/test_model') + model = deepforest.from_pretrained(repo_id) # Set bird-specific settings if loading the bird model if model_name == "bird": From 87db1fc657f1eaf2f05ce34098e2c54c1bf212b7 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 13 Sep 2024 00:10:46 +0300 Subject: [PATCH 06/20] add from_pretrained and load_model methods --- deepforest/main.py | 63 ++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index 766a672e..b0d0847f 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -14,14 +14,15 @@ from pytorch_lightning.callbacks import LearningRateMonitor from torch import optim from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision - +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download, snapshot_download from deepforest import dataset, visualize, get_data, utilities, predict from deepforest import evaluate as evaluate_iou from huggingface_hub import PyTorchModelHubMixin -class deepforest(pl.LightningModule, PyTorchModelHubMixin): +class deepforest( + pl.LightningModule,): # PyTorchModelHubMixin): """Class for training and predicting tree crowns in RGB images.""" def __init__(self, @@ -119,46 +120,42 @@ def __init__(self, self.save_hyperparameters() - def load_model(self, model_name="deepforest-tree", version='main'): - """Load DeepForest models from Hugging Face using from_pretrained(). + def from_pretrained(self, + repo_id: str, + filename: str = "model.safetensors", + config_file: str = "config.json"): + """Load a pre-trained deepforest model from the Hugging Face Hub. Args: - model_name (str): The name of the model to load ('deepforest-tree', 'bird', 'livestock', 'nest', 'deadtrees'). - version (str): The model version ('main', 'v1.0.0', etc.). - - Returns: - model (object): A trained PyTorch model. + repo_id (str): The Hugging Face repository where the model is stored. + filename (str): The model file to download. + config_file (str): The config file to download. Defaults to "config.json". """ - # Map model names to Hugging Face Hub repository IDs - model_repo_dict = { - "deepforest-tree": "weecology/deepforest-tree", - "bird": "weecology/deepforest-bird", - "livestock": "weecology/deepforest-livestock", - "nest": "weecology/everglades-nest-detection", - "deadtrees": "weecology/cropmodel-deadtrees" - } + # Download model weights + model_file = hf_hub_download(repo_id=repo_id, filename=filename) - # Validate model name - if model_name not in model_repo_dict: - raise ValueError( - "Invalid model_name specified. Choose from 'deepforest-tree', 'bird', 'livestock', 'nest', or 'deadtrees'." - ) + # Download and load the config file + downloaded_config_file = hf_hub_download(repo_id=repo_id, filename=config_file) + self.config = utilities.read_config(downloaded_config_file) - # Retrieve the repository ID for the model - repo_id = model_repo_dict[model_name] + # Initialize the model class with the downloaded config + self.create_model() # Initialize the model architecture based on the config - # Load the model using from_pretrained - model = deepforest.from_pretrained(repo_id) + # Load model weights from the checkpoint + self.model.load_state_dict( + torch.load(model_file)) # Use the class method directly - # Set bird-specific settings if loading the bird model - if model_name == "bird": - model.config["score_thresh"] = 0.3 - model.label_dict = {"Bird": 0} - model.numeric_to_label_dict = {v: k for k, v in model.label_dict.items()} + def load_model(self, repo_id: str, **kwargs): + """Wrapper method to load both the model and config file. - print(f"Loading model: {model_name} from version: {version}") + Args: + repo_id (str): The Hugging Face repository ID where the model is stored. + kwargs: Additional arguments to pass to from_pretrained method. + """ + # Call from_pretrained to load the model and config + self.from_pretrained(repo_id, **kwargs) - return model + print(f"Model and configuration loaded from {repo_id}") def use_release(self, check_release=True): """Use the latest DeepForest model release from github and load model. From 9876b9f576fd3e9e9056c45b1caf0c68a460c46c Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Sat, 14 Sep 2024 00:54:53 +0300 Subject: [PATCH 07/20] add load_model --- deepforest/main.py | 55 ++++++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index b0d0847f..f4a38ba6 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -21,8 +21,7 @@ from huggingface_hub import PyTorchModelHubMixin -class deepforest( - pl.LightningModule,): # PyTorchModelHubMixin): +class deepforest(pl.LightningModule, PyTorchModelHubMixin): """Class for training and predicting tree crowns in RGB images.""" def __init__(self, @@ -120,42 +119,30 @@ def __init__(self, self.save_hyperparameters() - def from_pretrained(self, - repo_id: str, - filename: str = "model.safetensors", - config_file: str = "config.json"): - """Load a pre-trained deepforest model from the Hugging Face Hub. + def load_model(self, model_name="weecology/deepforest-tree", revision='main'): + """Load DeepForest models from Hugging Face using from_pretrained(). Args: - repo_id (str): The Hugging Face repository where the model is stored. - filename (str): The model file to download. - config_file (str): The config file to download. Defaults to "config.json". - """ - # Download model weights - model_file = hf_hub_download(repo_id=repo_id, filename=filename) - - # Download and load the config file - downloaded_config_file = hf_hub_download(repo_id=repo_id, filename=config_file) - self.config = utilities.read_config(downloaded_config_file) - - # Initialize the model class with the downloaded config - self.create_model() # Initialize the model architecture based on the config - - # Load model weights from the checkpoint - self.model.load_state_dict( - torch.load(model_file)) # Use the class method directly + model_name (str): A repository ID for huggingface in the form of organization/repository + version (str): The model version ('main', 'v1.0.0', etc.). - def load_model(self, repo_id: str, **kwargs): - """Wrapper method to load both the model and config file. - - Args: - repo_id (str): The Hugging Face repository ID where the model is stored. - kwargs: Additional arguments to pass to from_pretrained method. + Returns: + self (object):A trained PyTorch model with its config and weights. """ - # Call from_pretrained to load the model and config - self.from_pretrained(repo_id, **kwargs) - - print(f"Model and configuration loaded from {repo_id}") + # Load the model using from_pretrained + loaded_model = self.from_pretrained(model_name, revision=revision) + self.config = loaded_model.config + self.label_dict = loaded_model.label_dict + self.model = loaded_model.model + self.numeric_to_label_dict = loaded_model.numeric_to_label_dict + + # Set bird-specific settings if loading the bird model + if model_name == "weecology/deepforest-bird": + self.config['retinanet']["score_thresh"] = 0.3 + self.label_dict = {"Bird": 0} + self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + + return self def use_release(self, check_release=True): """Use the latest DeepForest model release from github and load model. From d33b3ff1ededc75840ea11964979db66ef1eab52 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Mon, 16 Sep 2024 19:20:05 +0300 Subject: [PATCH 08/20] Remove use_release and use_bird_release from utilities and add unit test for load_model --- deepforest/main.py | 18 +++++++++++++----- tests/test_main.py | 9 +++++++-- tests/test_utilities.py | 13 +------------ 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index f4a38ba6..3fbaeb8d 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -14,7 +14,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor from torch import optim from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision -from huggingface_hub import PyTorchModelHubMixin, hf_hub_download, snapshot_download +from huggingface_hub import PyTorchModelHubMixin from deepforest import dataset, visualize, get_data, utilities, predict from deepforest import evaluate as evaluate_iou @@ -120,11 +120,19 @@ def __init__(self, self.save_hyperparameters() def load_model(self, model_name="weecology/deepforest-tree", revision='main'): - """Load DeepForest models from Hugging Face using from_pretrained(). + """Loads a model that has already been pretrained for a specific task, + like tree crown detection. + + Models (technically model weights) are distributed via Hugging Face + and designated the Hugging Face repository ID (model_name), which + is in the form: 'organization/repository'. For a list of models distributed + by the DeepForest team (and the associated model names) see the + documentation: + https://deepforest.readthedocs.io/en/latest/installation_and_setup/prebuilt.html Args: model_name (str): A repository ID for huggingface in the form of organization/repository - version (str): The model version ('main', 'v1.0.0', etc.). + revision (str): The model version ('main', 'v1.0.0', etc.). Returns: self (object):A trained PyTorch model with its config and weights. @@ -156,7 +164,7 @@ def use_release(self, check_release=True): warnings.warn("use_release will be deprecated in 2.0. use load_model() instead", DeprecationWarning) - self.load_model() + self.load_model('weecology/deepforest-tree') def use_bird_release(self, check_release=True): """Use the latest DeepForest bird model release from github and load @@ -171,7 +179,7 @@ def use_bird_release(self, check_release=True): warnings.warn( "use_bird_release will be deprecated in 2.0. use load_model('bird') instead", DeprecationWarning) - self.load_model('bird') + self.load_model('weecology/deepforest-bird') def create_model(self): """Define a deepforest architecture. This can be done in two ways. diff --git a/tests/test_main.py b/tests/test_main.py index 6bcba552..b1e98c6f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -62,8 +62,7 @@ def m(download_release): m.config["train"]["epochs"] = 2 m.create_trainer() - m.use_release(check_release=False) - + m.load_model() return m @@ -140,6 +139,12 @@ def test_use_bird_release(m): boxes = m.predict_image(path=imgpath) assert not boxes.empty +def test_load_model(m): + imgpath = get_data("SOAP_031.png") + m.load_model('weecology/deepforest-tree') + boxes = m.predict_image(path=imgpath) + assert not boxes.empty + def test_train_empty(m, tmpdir): empty_csv = pd.DataFrame({ diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 5a97e11c..de16f599 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -30,18 +30,7 @@ def test_read_pascal_voc(): annotations = utilities.read_pascal_voc(xml_path=get_data("OSBS_029.xml")) print(annotations.shape) assert annotations.shape[0] == 61 - - -def test_use_release(download_release): - # Download latest model from github release - release_tag, state_dict = utilities.use_release(check_release=False) - - -def test_use_bird_release(download_release): - # Download latest model from github release - release_tag, state_dict = utilities.use_bird_release() - assert os.path.exists(get_data("bird.pt")) - + def test_float_warning(config): """Users should get a rounding warning when adding annotations with floats""" From 6b7a617bed61ce722714b565c7fd1f3ba6cd321a Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Mon, 16 Sep 2024 20:02:41 +0300 Subject: [PATCH 09/20] Add safetensors to requirements --- dev_requirements.txt | 1 + environment.yml | 1 + setup.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index af42fd82..8b3400b8 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -29,6 +29,7 @@ pyyaml>=5.1.0 rasterio recommonmark rtree +safetensors slidingwindow sphinx sphinx_markdown_tables diff --git a/environment.yml b/environment.yml index 1cbea5f1..699e066c 100644 --- a/environment.yml +++ b/environment.yml @@ -49,4 +49,5 @@ dependencies: - Pygments - docformatter - opencv-python-headless + - safetensors diff --git a/setup.py b/setup.py index 31412cba..6b609044 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,6 @@ install_requires=['albumentations>=1.0.0', 'aiolimiter', 'aiohttp', 'docformatter', 'huggingface_hub', 'geopandas', 'matplotlib', 'nbqa', 'numpy', 'opencv-python-headless>=4.5.4', 'pandas', 'Pillow>6.2.0', 'progressbar2', 'pycocotools', 'Pygments', 'pytorch-lightning>=1.5.8', 'rasterio', - 'recommonmark', 'rtree', 'scipy>1.5', 'six', 'slidingwindow', 'sphinx', 'supervision', 'torch', + 'recommonmark', 'rtree', 'safetensors', 'scipy>1.5', 'six', 'slidingwindow', 'sphinx', 'supervision', 'torch', 'torchvision>=0.13', 'tqdm', 'xmltodict', 'geopandas'], zip_safe=False) From 1531aa55ea4dd1e1ac2195146601e09c6fc79848 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Mon, 16 Sep 2024 21:31:03 +0300 Subject: [PATCH 10/20] modify setup' C --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5beccc05..35269f20 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,6 @@ hf-hub-download install_requires=['albumentations>=1.0.0', 'aiolimiter', 'aiohttp', 'docformatter', 'huggingface_hub', 'geopandas', 'matplotlib', 'nbqa', 'numpy', 'opencv-python-headless>=4.5.4', 'pandas', 'Pillow>6.2.0', - 'progressbar2', 'pycocotools', 'Pygments', 'pytorch-lightning>=1.5.8', 'rasterio', + 'progressbar2', 'pycocotools', "pydata-sphinx-theme", 'Pygments', 'pytorch-lightning>=1.5.8', 'rasterio', 'recommonmark', 'rtree', 'safetensors', 'scipy>1.5', 'six', 'slidingwindow', 'sphinx', 'supervision', 'torch', 'torchvision>=0.13', 'tqdm', 'xmltodict', 'geopandas'],zip_safe=False) From bc2c7d3d07f63da9bee9e141e1abe8abb52ecb48 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Mon, 16 Sep 2024 21:33:25 +0300 Subject: [PATCH 11/20] modify setup' C --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 35269f20..5fc0865d 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,6 @@ license=LICENCE, packages=find_packages(), include_package_data=True, - hf-hub-download install_requires=['albumentations>=1.0.0', 'aiolimiter', 'aiohttp', 'docformatter', 'huggingface_hub', 'geopandas', 'matplotlib', 'nbqa', 'numpy', 'opencv-python-headless>=4.5.4', 'pandas', 'Pillow>6.2.0', 'progressbar2', 'pycocotools', "pydata-sphinx-theme", 'Pygments', 'pytorch-lightning>=1.5.8', 'rasterio', From b686db7f8bd4f4fc2610309a361066167b191640 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Mon, 16 Sep 2024 22:06:58 +0300 Subject: [PATCH 12/20] solve test failures --- deepforest/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index 3fbaeb8d..787b1d74 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -138,12 +138,12 @@ def load_model(self, model_name="weecology/deepforest-tree", revision='main'): self (object):A trained PyTorch model with its config and weights. """ # Load the model using from_pretrained + self.create_model() loaded_model = self.from_pretrained(model_name, revision=revision) - self.config = loaded_model.config self.label_dict = loaded_model.label_dict self.model = loaded_model.model self.numeric_to_label_dict = loaded_model.numeric_to_label_dict - + print(loaded_model.config) # Set bird-specific settings if loading the bird model if model_name == "weecology/deepforest-bird": self.config['retinanet']["score_thresh"] = 0.3 From 0ffc8383ccbca888aa9cab230d89d9436d08ceaa Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 20 Sep 2024 14:33:00 +0300 Subject: [PATCH 13/20] test using ethanwhite/df-test --- deepforest/main.py | 5 ++- deepforest/utilities.py | 76 ----------------------------------------- tests/test_main.py | 4 +-- 3 files changed, 4 insertions(+), 81 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index 787b1d74..4fa719a1 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -119,7 +119,7 @@ def __init__(self, self.save_hyperparameters() - def load_model(self, model_name="weecology/deepforest-tree", revision='main'): + def load_model(self, model_name="ethanwhite/df-test", revision='main'): """Loads a model that has already been pretrained for a specific task, like tree crown detection. @@ -143,7 +143,6 @@ def load_model(self, model_name="weecology/deepforest-tree", revision='main'): self.label_dict = loaded_model.label_dict self.model = loaded_model.model self.numeric_to_label_dict = loaded_model.numeric_to_label_dict - print(loaded_model.config) # Set bird-specific settings if loading the bird model if model_name == "weecology/deepforest-bird": self.config['retinanet']["score_thresh"] = 0.3 @@ -164,7 +163,7 @@ def use_release(self, check_release=True): warnings.warn("use_release will be deprecated in 2.0. use load_model() instead", DeprecationWarning) - self.load_model('weecology/deepforest-tree') + self.load_model('ethanwhite/df-test') def use_bird_release(self, check_release=True): """Use the latest DeepForest bird model release from github and load diff --git a/deepforest/utilities.py b/deepforest/utilities.py index 917795b3..3b434310 100644 --- a/deepforest/utilities.py +++ b/deepforest/utilities.py @@ -56,82 +56,6 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) -def fetch_model(save_dir, repo_id, model_filename, version="main"): - """Downloads a model from Hugging Face and saves it to a specified - directory. - - Parameters: - - save_dir (str): The directory where the model will be saved. - - repo_id (str): The ID of the Hugging Face repository (e.g., "weecology/deepforest-tree"). - - model_filename (str): The name of the model file in the repository (e.g., "NEON.pt"). - - version (str): The version or branch of the model to download (e.g., "main", "v1.0.0"). Default is "main". - - Returns: - - output_path (str): The path where the model is saved. - """ - # Ensure the save directory exists - os.makedirs(save_dir, exist_ok=True) - - # Define the output path - output_path = os.path.join(save_dir, model_filename) - - try: - # Download the model from Hugging Face - hf_hub_download( - repo_id=repo_id, - filename=model_filename, - local_dir=save_dir, - revision=version # Specify the version or branch of the model to download - ) - print(f"Model saved to: {output_path}") - except RevisionNotFoundError as e: - print(f"Error: {e}") - print( - f"Check that the file '{model_filename}' and revision '{version}' exist in the repository '{repo_id}'." - ) - except HfHubHTTPError as e: - print(f"HTTP Error: {e}") - print( - "There might be a problem with your internet connection or the file may not exist." - ) - except Exception as e: - print(f"An unexpected error occurred: {e}") - - return version, output_path - - -def use_bird_release( - save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="bird", check_release=True): - """ - Check the existence of, or download the latest model release from github - Args: - save_dir: Directory to save filepath, default to "data" in deepforest repo - prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. - check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. - Returns: release_tag, output_path (str): path to downloaded model - """ - return fetch_model(save_dir, - repo_id="weecology/deepforest-bird", - model_filename="bird.pt") - - -def use_release( - save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="NEON", check_release=True): - """ - Check the existence of, or download the latest model release from github - Args: - save_dir: Directory to save filepath, default to "data" in deepforest repo - prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. - check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. - - Returns: release_tag, output_path (str): path to downloaded model - - """ - return fetch_model(save_dir, - repo_id="weecology/deepforest-tree", - model_filename="NEON.pt") - - def read_pascal_voc(xml_path): """Load annotations from xml format (e.g. RectLabel editor) and convert them into retinanet annotations format. diff --git a/tests/test_main.py b/tests/test_main.py index b1e98c6f..cb64757a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -140,8 +140,8 @@ def test_use_bird_release(m): assert not boxes.empty def test_load_model(m): - imgpath = get_data("SOAP_031.png") - m.load_model('weecology/deepforest-tree') + imgpath = get_data("OSBS_029.png") + m.load_model('ethanwhite/df-test') boxes = m.predict_image(path=imgpath) assert not boxes.empty From df5f8c8ee03ec5da14ec80e23a2c16cab5f581ba Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 20 Sep 2024 15:05:29 +0300 Subject: [PATCH 14/20] test --- deepforest/utilities.py | 76 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/deepforest/utilities.py b/deepforest/utilities.py index 3b434310..917795b3 100644 --- a/deepforest/utilities.py +++ b/deepforest/utilities.py @@ -56,6 +56,82 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) +def fetch_model(save_dir, repo_id, model_filename, version="main"): + """Downloads a model from Hugging Face and saves it to a specified + directory. + + Parameters: + - save_dir (str): The directory where the model will be saved. + - repo_id (str): The ID of the Hugging Face repository (e.g., "weecology/deepforest-tree"). + - model_filename (str): The name of the model file in the repository (e.g., "NEON.pt"). + - version (str): The version or branch of the model to download (e.g., "main", "v1.0.0"). Default is "main". + + Returns: + - output_path (str): The path where the model is saved. + """ + # Ensure the save directory exists + os.makedirs(save_dir, exist_ok=True) + + # Define the output path + output_path = os.path.join(save_dir, model_filename) + + try: + # Download the model from Hugging Face + hf_hub_download( + repo_id=repo_id, + filename=model_filename, + local_dir=save_dir, + revision=version # Specify the version or branch of the model to download + ) + print(f"Model saved to: {output_path}") + except RevisionNotFoundError as e: + print(f"Error: {e}") + print( + f"Check that the file '{model_filename}' and revision '{version}' exist in the repository '{repo_id}'." + ) + except HfHubHTTPError as e: + print(f"HTTP Error: {e}") + print( + "There might be a problem with your internet connection or the file may not exist." + ) + except Exception as e: + print(f"An unexpected error occurred: {e}") + + return version, output_path + + +def use_bird_release( + save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="bird", check_release=True): + """ + Check the existence of, or download the latest model release from github + Args: + save_dir: Directory to save filepath, default to "data" in deepforest repo + prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. + check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. + Returns: release_tag, output_path (str): path to downloaded model + """ + return fetch_model(save_dir, + repo_id="weecology/deepforest-bird", + model_filename="bird.pt") + + +def use_release( + save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="NEON", check_release=True): + """ + Check the existence of, or download the latest model release from github + Args: + save_dir: Directory to save filepath, default to "data" in deepforest repo + prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. + check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. + + Returns: release_tag, output_path (str): path to downloaded model + + """ + return fetch_model(save_dir, + repo_id="weecology/deepforest-tree", + model_filename="NEON.pt") + + def read_pascal_voc(xml_path): """Load annotations from xml format (e.g. RectLabel editor) and convert them into retinanet annotations format. From 78ede8175b2d8ade07388ae0ef6c1e2964cde89a Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 20 Sep 2024 17:08:41 +0300 Subject: [PATCH 15/20] replace utilities.use_release() with main.deepforest().use_release() --- deepforest/utilities.py | 76 ----------------------------------------- tests/conftest.py | 3 +- 2 files changed, 2 insertions(+), 77 deletions(-) diff --git a/deepforest/utilities.py b/deepforest/utilities.py index 917795b3..3b434310 100644 --- a/deepforest/utilities.py +++ b/deepforest/utilities.py @@ -56,82 +56,6 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) -def fetch_model(save_dir, repo_id, model_filename, version="main"): - """Downloads a model from Hugging Face and saves it to a specified - directory. - - Parameters: - - save_dir (str): The directory where the model will be saved. - - repo_id (str): The ID of the Hugging Face repository (e.g., "weecology/deepforest-tree"). - - model_filename (str): The name of the model file in the repository (e.g., "NEON.pt"). - - version (str): The version or branch of the model to download (e.g., "main", "v1.0.0"). Default is "main". - - Returns: - - output_path (str): The path where the model is saved. - """ - # Ensure the save directory exists - os.makedirs(save_dir, exist_ok=True) - - # Define the output path - output_path = os.path.join(save_dir, model_filename) - - try: - # Download the model from Hugging Face - hf_hub_download( - repo_id=repo_id, - filename=model_filename, - local_dir=save_dir, - revision=version # Specify the version or branch of the model to download - ) - print(f"Model saved to: {output_path}") - except RevisionNotFoundError as e: - print(f"Error: {e}") - print( - f"Check that the file '{model_filename}' and revision '{version}' exist in the repository '{repo_id}'." - ) - except HfHubHTTPError as e: - print(f"HTTP Error: {e}") - print( - "There might be a problem with your internet connection or the file may not exist." - ) - except Exception as e: - print(f"An unexpected error occurred: {e}") - - return version, output_path - - -def use_bird_release( - save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="bird", check_release=True): - """ - Check the existence of, or download the latest model release from github - Args: - save_dir: Directory to save filepath, default to "data" in deepforest repo - prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. - check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. - Returns: release_tag, output_path (str): path to downloaded model - """ - return fetch_model(save_dir, - repo_id="weecology/deepforest-bird", - model_filename="bird.pt") - - -def use_release( - save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="NEON", check_release=True): - """ - Check the existence of, or download the latest model release from github - Args: - save_dir: Directory to save filepath, default to "data" in deepforest repo - prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. - check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. - - Returns: release_tag, output_path (str): path to downloaded model - - """ - return fetch_model(save_dir, - repo_id="weecology/deepforest-tree", - model_filename="NEON.pt") - - def read_pascal_voc(xml_path): """Load annotations from xml format (e.g. RectLabel editor) and convert them into retinanet annotations format. diff --git a/tests/conftest.py b/tests/conftest.py index 61e6c1e0..a0570050 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,8 @@ def config(): def download_release(): print("running fixtures") try: - utilities.use_release() + # utilities.use_release() + main.deepforest().use_release() except urllib.error.URLError: # Add a edge case in case no internet access. pass From f8041eed1ee00f640bdfb2675f0baf3617b00b7f Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 20 Sep 2024 17:21:44 +0300 Subject: [PATCH 16/20] add the right repo_id --- deepforest/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepforest/main.py b/deepforest/main.py index 4fa719a1..b14a283b 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -119,7 +119,7 @@ def __init__(self, self.save_hyperparameters() - def load_model(self, model_name="ethanwhite/df-test", revision='main'): + def load_model(self, model_name="weecology/deepforest-tree", revision='main'): """Loads a model that has already been pretrained for a specific task, like tree crown detection. @@ -163,7 +163,7 @@ def use_release(self, check_release=True): warnings.warn("use_release will be deprecated in 2.0. use load_model() instead", DeprecationWarning) - self.load_model('ethanwhite/df-test') + self.load_model('weecology/deepforest-tree') def use_bird_release(self, check_release=True): """Use the latest DeepForest bird model release from github and load From 84b59c81bea12fabe10d3aa1c5e19c1cb9a3165f Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 20 Sep 2024 17:38:28 +0300 Subject: [PATCH 17/20] remove test for checking NEON.pt path --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index a0570050..c8c0fe96 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,6 @@ def download_release(): except urllib.error.URLError: # Add a edge case in case no internet access. pass - assert os.path.exists(get_data("NEON.pt")) @pytest.fixture(scope="session") From e609754751a16438289fac5a2934d0010386312e Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Fri, 20 Sep 2024 17:41:04 +0300 Subject: [PATCH 18/20] remove test for checking NEON.pt path --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index c8c0fe96..090994e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ def download_release(): print("running fixtures") try: # utilities.use_release() - main.deepforest().use_release() + main.deepforest().load_model() except urllib.error.URLError: # Add a edge case in case no internet access. pass From 0cc67b23a155b11f4205bd14b6e1eb87a2a53c01 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Tue, 24 Sep 2024 15:21:48 +0300 Subject: [PATCH 19/20] add docs to the new function --- docs/getting_started/getting_started.md | 3 +- docs/getting_started/model_loader.md | 45 +++++++++++++++++++++++++ docs/getting_started/sample_test.ipynb | 20 ++++++++++- tests/conftest.py | 1 - 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 docs/getting_started/model_loader.md diff --git a/docs/getting_started/getting_started.md b/docs/getting_started/getting_started.md index 63260707..ab182465 100644 --- a/docs/getting_started/getting_started.md +++ b/docs/getting_started/getting_started.md @@ -12,8 +12,7 @@ from deepforest import get_data import matplotlib.pyplot as plt model = main.deepforest() -model.use_release() - +model.load_model(model_name="weecology/deepforest-tree", revision="main") sample_image_path = get_data("OSBS_029.png") img = model.predict_image(path=sample_image_path, return_plot=True) diff --git a/docs/getting_started/model_loader.md b/docs/getting_started/model_loader.md new file mode 100644 index 00000000..3c2119d1 --- /dev/null +++ b/docs/getting_started/model_loader.md @@ -0,0 +1,45 @@ +# DeepForest Model Loader + +This function loads pretrained DeepForest models from Hugging Face, with support for different model revisions. Additionally, you can save the model configuration and weights using `save_pretrained` and reload it later with `from_pretrained`. + +## `load_model` + +### Description + +The `load_model` function loads a pretrained model from Hugging Face using the repository name (`model_name`) and the desired model version (`revision`). This is useful for tasks such as tree crown detection, but it can also load bird detection models with custom configurations. + +### Arguments + +- `model_name` (str): A repository ID for Hugging Face in the form `organization/repository`. Default is `"weecology/deepforest-tree"`. + + you can choose from: + - weecology/deepforest-tree + - weecology/deepforest-bird + - weecology/deepforest-livestock + - weecology/everglades-nest-detection + - weecology/cropmodel-deadtrees +- `revision` (str): The model version (e.g., 'main', 'v1.0.0', etc.). Default is `'main'`. + +### Returns + +- `` (object): A trained PyTorch model with its configuration and weights. + +### Example Usage + +#### Load a Model + +```python +from deepforest import main +from deepforest import get_data +import matplotlib.pyplot as plt + +# Initialize the model class +model = main.deepforest() + +# Load a pretrained tree detection model from Hugging Face +model.load_model(model_name="weecology/deepforest-tree", revision="main") + +sample_image_path = get_data("OSBS_029.png") +img = model.predict_image(path=sample_image_path, return_plot=True) + +``` diff --git a/docs/getting_started/sample_test.ipynb b/docs/getting_started/sample_test.ipynb index 2ba7ae43..5dd073d3 100644 --- a/docs/getting_started/sample_test.ipynb +++ b/docs/getting_started/sample_test.ipynb @@ -28,12 +28,30 @@ "import matplotlib.pyplot as plt\n", "\n", "# model = main.deepforest()\n", - "# model.use_release()\n", + "# model.load_model(model_name=\"weecology/deepforest-tree\", revision=\"main\")\n", "\n", "# img = model.predict_image(path=\"../tests/data/OSBS_029_0.png\",return_plot=True)\n", "# #predict_image returns plot in BlueGreenRed (opencv style), but matplotlib likes RedGreenBlue, switch the channel order.\n", "# plt.imshow(img[:,:,::-1])" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To save a model you can use \n", + "# model.save_pretrained(\"Path/to/model/folder\")\n", + "# This will give config.json for the configuration and model.safetensors for the weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tests/conftest.py b/tests/conftest.py index 090994e1..ae69cc42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,6 @@ def config(): def download_release(): print("running fixtures") try: - # utilities.use_release() main.deepforest().load_model() except urllib.error.URLError: # Add a edge case in case no internet access. From 9bd999811da1dbfd6ccaaa9580a793d5a85463ba Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Wed, 25 Sep 2024 01:38:14 +0300 Subject: [PATCH 20/20] add docs to the new function --- docs/getting_started/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst index 620b7428..f1ab6f05 100644 --- a/docs/getting_started/index.rst +++ b/docs/getting_started/index.rst @@ -6,5 +6,6 @@ Getting Started :caption: Contents: getting_started + model_loader Reading_and_Writing sample_test