forked from huggingface/candle
Add DINOv2Reg4 + PlantCLEF2024 (huggingface#2293)
* Add: DINOv2Reg4 with PlantCLEF2024 weights and example (see https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263)
* Remove extra files + update README to download them + remove extra lines
* minor fix (README remove extra spaces)
* minor fix (README: Fix image url)
* Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README (source image changed and so the predictions)
* Fix: Improve code readability with '$ cargo clippy' and '$ cargo fmt'
* Another clippy fix.

---------

Co-authored-by: x-VEspit <[email protected]>
Co-authored-by: laurent <[email protected]>
1 parent a3dd87f, commit e27aac0
Showing 5 changed files with 395 additions and 0 deletions.
@@ -0,0 +1,25 @@
# candle-dinov2-reg4

[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the latest version of DINOv2 with registers.
In this example, it is used as a plant species classifier: the model returns the
probability that the image belongs to each of the 7806 PlantCLEF2024 categories.
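
The example's `main` obtains these probabilities by applying `candle_nn::ops::softmax` to the model's logits. As a minimal, dependency-free sketch of that step, the hypothetical `softmax` helper below (an illustration only, not part of this commit) turns a slice of logits into probabilities that sum to 1:

```rust
/// Numerically stable softmax over a slice of logits.
/// Hypothetical helper for illustration; the example itself uses
/// `candle_nn::ops::softmax` on a tensor.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtracting the maximum keeps `exp` from overflowing on large logits.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Each output is the share of the total exponentiated mass.
    exps.iter().map(|&e| e / sum).collect()
}
```

Calling `softmax(&[1.0, 2.0, 3.0])` yields three values in increasing order whose sum is 1 up to floating-point rounding. In the classifier, the input would hold one logit per PlantCLEF2024 category.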

## Running an example

```bash
# Download the class names and a plant picture to identify
curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg

# Perform inference
cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg

> Orchis simia Lam. : 45.55%
> Orchis × bergonii Nanteuil: 9.80%
> Orchis italica Poir. : 9.66%
> Orchis × angusticruris Franch.: 2.76%
> Orchis × bivonae Tod. : 2.54%
```
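
The five lines of output are the highest-probability classes, sorted in descending order. The sketch below shows one way to produce such a report from a plain probability vector. The `top_k` and `format_line` helpers are hypothetical names introduced here for illustration. Only the `{:24}: {:.2}%` format string and the use of `total_cmp` are taken from the example source.

```rust
/// Returns the `k` most probable entries as `(class index, probability)`,
/// sorted from most to least probable. Hypothetical helper.
fn top_k(probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // `total_cmp` gives a total order on floats, so the sort never panics on NaN.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    indexed.truncate(k);
    indexed
}

/// Formats one report line with the same format string the example uses:
/// the label padded to 24 columns, then the probability as a percentage.
fn format_line(label: &str, prob: f32) -> String {
    format!("{:24}: {:.2}%", label, 100.0 * prob)
}
```

For instance, `top_k(&[0.1, 0.6, 0.3], 2)` keeps indices 1 and 2 in that order, and each pair can then be passed to `format_line` together with the class name looked up from the mapping file.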

![Orchis Simia](https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c)
@@ -0,0 +1,70 @@
//! DINOv2 reg4 finetuned on PlantCLEF 2024
//! https://arxiv.org/abs/2309.16588
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
//! https://zenodo.org/records/10848263

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2reg4;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
    let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
        .expect("missing classes file")
        .split('\n')
        .map(|s| s.to_string())
        .collect();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api =
                api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
            api.get(
                "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
            )?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = dinov2reg4::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
    }
    Ok(())
}
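
The class list above is built with `split('\n')` and indexed directly with `classes[category_idx]`. This treats the mapping file as one label per line, where the line position is the class index. A possible variant of that loading step is sketched below using only the standard library. `parse_class_names` and `class_name` are hypothetical helpers, not part of the commit, and the one-label-per-line layout is an assumption inferred from how the example indexes the file.

```rust
/// Splits the mapping file contents into one label per line.
/// Unlike `split('\n')`, `lines()` strips a trailing `\r` and does not
/// produce an extra empty entry after a final newline.
fn parse_class_names(contents: &str) -> Vec<String> {
    contents.lines().map(String::from).collect()
}

/// Looks up a label without panicking when the index is out of range,
/// for example when the mapping file is shorter than the model's output.
fn class_name(classes: &[String], idx: usize) -> &str {
    classes
        .get(idx)
        .map(String::as_str)
        .unwrap_or("<unknown class>")
}
```

With these helpers, the print loop could call `class_name(&classes, category_idx)` instead of indexing, so a mismatched mapping file would show up as a placeholder label instead of a panic.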