diff --git a/deepforest/main.py b/deepforest/main.py index e7f0862c..6668006e 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -14,12 +14,14 @@ from pytorch_lightning.callbacks import LearningRateMonitor from torch import optim from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision - +from huggingface_hub import PyTorchModelHubMixin from deepforest import dataset, visualize, get_data, utilities, predict from deepforest import evaluate as evaluate_iou +from huggingface_hub import PyTorchModelHubMixin + -class deepforest(pl.LightningModule): +class deepforest(pl.LightningModule, PyTorchModelHubMixin): """Class for training and predicting tree crowns in RGB images.""" def __init__(self, @@ -117,6 +119,38 @@ def __init__(self, self.save_hyperparameters() + def load_model(self, model_name="weecology/deepforest-tree", revision='main'): + """Loads a model that has already been pretrained for a specific task, + like tree crown detection. + + Models (technically model weights) are distributed via Hugging Face + and designated the Hugging Face repository ID (model_name), which + is in the form: 'organization/repository'. For a list of models distributed + by the DeepForest team (and the associated model names) see the + documentation: + https://deepforest.readthedocs.io/en/latest/installation_and_setup/prebuilt.html + + Args: + model_name (str): A repository ID for huggingface in the form of organization/repository + revision (str): The model version ('main', 'v1.0.0', etc.). + + Returns: + self (object):A trained PyTorch model with its config and weights. 
+ """ + # Load the model using from_pretrained + self.create_model() + loaded_model = self.from_pretrained(model_name, revision=revision) + self.label_dict = loaded_model.label_dict + self.model = loaded_model.model + self.numeric_to_label_dict = loaded_model.numeric_to_label_dict + # Set bird-specific settings if loading the bird model + if model_name == "weecology/deepforest-bird": + self.config['retinanet']["score_thresh"] = 0.3 + self.label_dict = {"Bird": 0} + self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + + return self + def use_release(self, check_release=True): """Use the latest DeepForest model release from github and load model. Optionally download if release doesn't exist. @@ -126,20 +160,10 @@ def use_release(self, check_release=True): Returns: model (object): A trained PyTorch model """ - # Download latest model from github release - release_tag, self.release_state_dict = utilities.use_release( - check_release=check_release) - if self.config["architecture"] != "retinanet": - warnings.warn( - "The config file specifies architecture {}, but the release model is torchvision retinanet. Reloading main.deepforest with a retinanet model" - .format(self.config["architecture"])) - self.config["architecture"] = "retinanet" - self.create_model() - self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True)) - # load saved model and tag release - self.__release_version__ = release_tag - print("Loading pre-built model: {}".format(release_tag)) + warnings.warn("use_release will be deprecated in 2.0. 
use load_model() instead", + DeprecationWarning) + self.load_model('weecology/deepforest-tree') def use_bird_release(self, check_release=True): """Use the latest DeepForest bird model release from github and load @@ -150,21 +174,11 @@ def use_bird_release(self, check_release=True): Returns: model (object): A trained pytorch model """ - # Download latest model from github release - release_tag, self.release_state_dict = utilities.use_bird_release( - check_release=check_release) - self.model.load_state_dict(torch.load(self.release_state_dict, weights_only=True)) - - # load saved model and tag release - self.__release_version__ = release_tag - print("Loading pre-built model: {}".format(release_tag)) - - print("Setting default score threshold to 0.3") - self.config["score_thresh"] = 0.3 - # Set label dictionary to Bird - self.label_dict = {"Bird": 0} - self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + warnings.warn( + "use_bird_release will be deprecated in 2.0. use load_model('bird') instead", + DeprecationWarning) + self.load_model('weecology/deepforest-bird') def create_model(self): """Define a deepforest architecture. This can be done in two ways. diff --git a/deepforest/utilities.py b/deepforest/utilities.py index ebdf492b..dbbf0e96 100644 --- a/deepforest/utilities.py +++ b/deepforest/utilities.py @@ -55,82 +55,6 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) -def fetch_model(save_dir, repo_id, model_filename, version="main"): - """Downloads a model from Hugging Face and saves it to a specified - directory. - - Parameters: - - save_dir (str): The directory where the model will be saved. - - repo_id (str): The ID of the Hugging Face repository (e.g., "weecology/deepforest-tree"). - - model_filename (str): The name of the model file in the repository (e.g., "NEON.pt"). - - version (str): The version or branch of the model to download (e.g., "main", "v1.0.0"). Default is "main". 
- - Returns: - - output_path (str): The path where the model is saved. - """ - # Ensure the save directory exists - os.makedirs(save_dir, exist_ok=True) - - # Define the output path - output_path = os.path.join(save_dir, model_filename) - - try: - # Download the model from Hugging Face - hf_hub_download( - repo_id=repo_id, - filename=model_filename, - local_dir=save_dir, - revision=version # Specify the version or branch of the model to download - ) - print(f"Model saved to: {output_path}") - except RevisionNotFoundError as e: - print(f"Error: {e}") - print( - f"Check that the file '{model_filename}' and revision '{version}' exist in the repository '{repo_id}'." - ) - except HfHubHTTPError as e: - print(f"HTTP Error: {e}") - print( - "There might be a problem with your internet connection or the file may not exist." - ) - except Exception as e: - print(f"An unexpected error occurred: {e}") - - return version, output_path - - -def use_bird_release( - save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="bird", check_release=True): - """ - Check the existence of, or download the latest model release from github - Args: - save_dir: Directory to save filepath, default to "data" in deepforest repo - prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. - check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. 
- Returns: release_tag, output_path (str): path to downloaded model - """ - return fetch_model(save_dir, - repo_id="weecology/deepforest-bird", - model_filename="bird.pt") - - -def use_release( - save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="NEON", check_release=True): - """ - Check the existence of, or download the latest model release from github - Args: - save_dir: Directory to save filepath, default to "data" in deepforest repo - prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. - check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. - - Returns: release_tag, output_path (str): path to downloaded model - - """ - return fetch_model(save_dir, - repo_id="weecology/deepforest-tree", - model_filename="NEON.pt") - - def read_pascal_voc(xml_path): """Load annotations from xml format (e.g. RectLabel editor) and convert them into retinanet annotations format. 
diff --git a/dev_requirements.txt b/dev_requirements.txt index dc07f611..95a3df03 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -29,6 +29,7 @@ pyyaml>=5.1.0 rasterio recommonmark rtree +safetensors slidingwindow sphinx sphinx_markdown_tables diff --git a/docs/getting_started/getting_started.md b/docs/getting_started/getting_started.md index 5491744e..0993d54d 100644 --- a/docs/getting_started/getting_started.md +++ b/docs/getting_started/getting_started.md @@ -12,8 +12,7 @@ from deepforest.visualize import plot_results import os model = main.deepforest() -model.use_release() - +model.load_model(model_name="weecology/deepforest-tree", revision="main") sample_image_path = get_data("OSBS_029.png") results = model.predict_image(path=sample_image_path) plot_results(results) diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst index 620b7428..f1ab6f05 100644 --- a/docs/getting_started/index.rst +++ b/docs/getting_started/index.rst @@ -6,5 +6,6 @@ Getting Started :caption: Contents: getting_started + model_loader Reading_and_Writing sample_test diff --git a/docs/getting_started/model_loader.md b/docs/getting_started/model_loader.md new file mode 100644 index 00000000..3c2119d1 --- /dev/null +++ b/docs/getting_started/model_loader.md @@ -0,0 +1,45 @@ +# DeepForest Model Loader + +The `load_model` method loads pretrained DeepForest models from Hugging Face, with support for different model revisions. Additionally, you can save a model's configuration and weights using `save_pretrained` and reload them later with `from_pretrained`. + +## `load_model` + +### Description + +The `load_model` function loads a pretrained model from Hugging Face using the repository name (`model_name`) and the desired model version (`revision`). This is useful for tasks such as tree crown detection, but it can also load bird detection models with custom configurations. 
+ +### Arguments + +- `model_name` (str): A repository ID for Hugging Face in the form `organization/repository`. Default is `"weecology/deepforest-tree"`. + + You can choose from: + - weecology/deepforest-tree + - weecology/deepforest-bird + - weecology/deepforest-livestock + - weecology/everglades-nest-detection + - weecology/cropmodel-deadtrees +- `revision` (str): The model version (e.g., 'main', 'v1.0.0', etc.). Default is `'main'`. + +### Returns + +- `self` (object): A trained PyTorch model with its configuration and weights. + +### Example Usage + +#### Load a Model + +```python +from deepforest import main +from deepforest import get_data +import matplotlib.pyplot as plt + +# Initialize the model class +model = main.deepforest() + +# Load a pretrained tree detection model from Hugging Face +model.load_model(model_name="weecology/deepforest-tree", revision="main") + +sample_image_path = get_data("OSBS_029.png") +img = model.predict_image(path=sample_image_path, return_plot=True) + +``` diff --git a/docs/getting_started/sample_test.ipynb b/docs/getting_started/sample_test.ipynb index cdf7179e..12f58d58 100644 --- a/docs/getting_started/sample_test.ipynb +++ b/docs/getting_started/sample_test.ipynb @@ -26,8 +26,8 @@ "from deepforest import get_data\n", "import os\n", "\n", - "#model = main.deepforest()\n", - "#model.use_release()\n", + "# model = main.deepforest()\n", + "# model.load_model(model_name=\"weecology/deepforest-tree\", revision=\"main\")\n", "\n", "#image_path = get_data('OSBS_029.tif')\n", "#root_dir = os.path.dirname(image_path)\n", @@ -35,6 +35,24 @@ "#results = model.predict_image(path=image_path)\n", "#plot_results(results, root_dir=root_dir)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To save a model you can use \n", + "# model.save_pretrained(\"Path/to/model/folder\")\n", + "# This will give config.json for the configuration and model.safetensors for the weights" + ] + }, + 
{ + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/environment.yml b/environment.yml index ceef82a8..9b8965d1 100644 --- a/environment.yml +++ b/environment.yml @@ -49,4 +49,5 @@ dependencies: - Pygments - docformatter - opencv-python-headless + - safetensors diff --git a/setup.py b/setup.py index 48717117..5fc0865d 100644 --- a/setup.py +++ b/setup.py @@ -63,11 +63,8 @@ license=LICENCE, packages=find_packages(), include_package_data=True, - install_requires=[ - "albumentations>=1.0.0", "aiolimiter", "aiohttp", "docformatter", "huggingface_hub", "geopandas", "matplotlib", "nbqa", "numpy", - "opencv-python-headless>=4.5.4", "pandas", "Pillow>6.2.0", "progressbar2", "pycocotools", "pydata-sphinx-theme", "Pygments", - "pytorch-lightning>=1.5.8", "rasterio", "recommonmark", "rtree", "scipy>1.5", - "six", "slidingwindow", "sphinx", "supervision", "torch", "torchvision>=0.13", "tqdm", - "xmltodict","geopandas" - ], - zip_safe=False) + install_requires=['albumentations>=1.0.0', 'aiolimiter', 'aiohttp', 'docformatter', 'huggingface_hub', + 'geopandas', 'matplotlib', 'nbqa', 'numpy', 'opencv-python-headless>=4.5.4', 'pandas', 'Pillow>6.2.0', + 'progressbar2', 'pycocotools', "pydata-sphinx-theme", 'Pygments', 'pytorch-lightning>=1.5.8', 'rasterio', + 'recommonmark', 'rtree', 'safetensors', 'scipy>1.5', 'six', 'slidingwindow', 'sphinx', 'supervision', 'torch', + 'torchvision>=0.13', 'tqdm', 'xmltodict', 'geopandas'],zip_safe=False) diff --git a/tests/conftest.py b/tests/conftest.py index 61e6c1e0..ae69cc42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,11 +23,10 @@ def config(): def download_release(): print("running fixtures") try: - utilities.use_release() + main.deepforest().load_model() except urllib.error.URLError: # Add a edge case in case no internet access. 
pass - assert os.path.exists(get_data("NEON.pt")) @pytest.fixture(scope="session") diff --git a/tests/test_main.py b/tests/test_main.py index c04b872e..a475fc2e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -62,8 +62,7 @@ def m(download_release): m.config["train"]["epochs"] = 2 m.create_trainer() - m.use_release(check_release=False) - + m.load_model() return m @@ -140,6 +139,12 @@ def test_use_bird_release(m): boxes = m.predict_image(path=imgpath) assert not boxes.empty +def test_load_model(m): + imgpath = get_data("OSBS_029.png") + m.load_model('ethanwhite/df-test') + boxes = m.predict_image(path=imgpath) + assert not boxes.empty + def test_train_empty(m, tmpdir): empty_csv = pd.DataFrame({ diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 5a97e11c..de16f599 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -30,18 +30,7 @@ def test_read_pascal_voc(): annotations = utilities.read_pascal_voc(xml_path=get_data("OSBS_029.xml")) print(annotations.shape) assert annotations.shape[0] == 61 - - -def test_use_release(download_release): - # Download latest model from github release - release_tag, state_dict = utilities.use_release(check_release=False) - - -def test_use_bird_release(download_release): - # Download latest model from github release - release_tag, state_dict = utilities.use_bird_release() - assert os.path.exists(get_data("bird.pt")) - + def test_float_warning(config): """Users should get a rounding warning when adding annotations with floats"""