Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Mu-Magdy committed Sep 20, 2024
1 parent 0ffc838 commit df5f8c8
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions deepforest/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,82 @@ def update_to(self, b=1, bsize=1, tsize=None):
self.update(b * bsize - self.n)


def fetch_model(save_dir, repo_id, model_filename, version="main"):
"""Downloads a model from Hugging Face and saves it to a specified
directory.
Parameters:
- save_dir (str): The directory where the model will be saved.
- repo_id (str): The ID of the Hugging Face repository (e.g., "weecology/deepforest-tree").
- model_filename (str): The name of the model file in the repository (e.g., "NEON.pt").
- version (str): The version or branch of the model to download (e.g., "main", "v1.0.0"). Default is "main".
Returns:
- output_path (str): The path where the model is saved.
"""
# Ensure the save directory exists
os.makedirs(save_dir, exist_ok=True)

# Define the output path
output_path = os.path.join(save_dir, model_filename)

try:
# Download the model from Hugging Face
hf_hub_download(
repo_id=repo_id,
filename=model_filename,
local_dir=save_dir,
revision=version # Specify the version or branch of the model to download
)
print(f"Model saved to: {output_path}")
except RevisionNotFoundError as e:
print(f"Error: {e}")
print(
f"Check that the file '{model_filename}' and revision '{version}' exist in the repository '{repo_id}'."
)
except HfHubHTTPError as e:
print(f"HTTP Error: {e}")
print(
"There might be a problem with your internet connection or the file may not exist."
)
except Exception as e:
print(f"An unexpected error occurred: {e}")

return version, output_path


def use_bird_release(
save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="bird", check_release=True):
"""
Check the existence of, or download the latest model release from github
Args:
save_dir: Directory to save filepath, default to "data" in deepforest repo
prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk.
check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise.
Returns: release_tag, output_path (str): path to downloaded model
"""
return fetch_model(save_dir,
repo_id="weecology/deepforest-bird",
model_filename="bird.pt")


def use_release(
save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="NEON", check_release=True):
"""
Check the existence of, or download the latest model release from github
Args:
save_dir: Directory to save filepath, default to "data" in deepforest repo
prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk.
check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise.
Returns: release_tag, output_path (str): path to downloaded model
"""
return fetch_model(save_dir,
repo_id="weecology/deepforest-tree",
model_filename="NEON.pt")


def read_pascal_voc(xml_path):
"""Load annotations from xml format (e.g. RectLabel editor) and convert
them into retinanet annotations format.
Expand Down

0 comments on commit df5f8c8

Please sign in to comment.