diff --git a/.gitignore b/.gitignore
index efd6665..58f4a07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
build/
-upn.egg-info/
+chatrex.egg-info/
checkpoints/
\ No newline at end of file
diff --git a/README.md b/README.md
index 0e61d1e..3736963 100644
--- a/README.md
+++ b/README.md
@@ -1,60 +1,118 @@
-

+
- [](https://arxiv.org/pdf/2403.14610.pdf) [](https://hits.seeyoufarm.com)
+ [](https://arxiv.org/pdf/2403.14610.pdf) [](https://hits.seeyoufarm.com)
+
+ Introduction •
+ Universal Proposal Network •
+ ChatRex •
+ Rexverse-2M Dataset
+
+
+
+
+----
+
# 1. Introduction 📚
-**TL;DR: UPN is an object proposal model that can detect any object without any prompt input.**
+**TL;DR: ChatRex is a MLLM skilled in perception that can respond to questions while simultaneously grounding its answers to the referenced objects.**
-Universal Proposal Network (UPN) is a robust object proposal model designed as part of ChatRex to enable comprehensive and accurate object detection across diverse granularities and domains. Built upon T-Rex2, UPN is a DETR-based model with a dual-granularity prompt tuning strategy, combining fine-grained (e.g., part-level) and coarse-grained (e.g., instance-level) detection.
+ChatRex is a Multimodal Large Language Model (MLLM) designed to seamlessly integrate fine-grained object perception and robust language understanding. By adopting a decoupled architecture with a retrieval-based approach for object detection and leveraging high-resolution visual inputs, ChatRex addresses key challenges in perception tasks. It is powered by the Rexverse-2M dataset with diverse image-region-text annotations. ChatRex can be applied to various scenarios requiring fine-grained perception, such as object detection, grounded conversation, grounded image captioning and region
+understanding.
-

+
+----
+
# 2. Installation 🛠️
```bash
-conda install -n upn python=3.9
+conda install -n chatrex python=3.9
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
-pip install -v- e .
-# install deformable attention
-cd upn/ops
+pip install -v -e .
+# install deformable attention for universal proposal network
+cd chatrex/upn/ops
pip install -v -e .
```
-To verify the installation, run the following command:
+## 2.1 Download Pre-trained Models
+We provide model checkpoints for both the ***Universal Proposal Network (UPN)*** and the ***ChatRex model***. You can download the pre-trained models from the following links:
+- [UPN Checkpoint](https://drive.google)
+- [ChatRex-7B Checkpoint](https://huggingface.co/IDEA-Research/ChatRex-7B)
+
+Or you can also using the following command to download the pre-trained models:
```bash
-python tests/test_install.py
+mkdir checkpoints
+mkdir checkpoints/upn
+# download UPN checkpoint
+wget -O checkpoints/upn/upn_large.pth https://drive.google.com/file/d/
+# download ChatRex checkpoint from huggingface IDEA-Research/ChatRex-7B
+# Download ChatRex checkpoint from Hugging Face
+git lfs install
+git clone https://huggingface.co/IDEA-Research/ChatRex-7B checkpoints/chatrex
```
-If the installation is successful, you will get two visualizations of the model's output in `tests` folder.
+## 2.2 Verify Installation
+To verify the ***installation of the Universal Proposal Network (UPN)***, run the following command:
+```bash
+python tests/test_upn_install.py
+```
+If the installation is successful, you will get two visualization images of both fine-grained proposal and coarse-grained proposal in `tests` folder.
+
+To verify the ***installation of the ChatRex model***, run the following command:
+```bash
+python tests/test_chatrex_install.py
+```
+
+If the installation is successful, you will get an output like this:
+```text
+prediction: shows a brown dog lying on a bed. The dog is resting comfortably, possibly sleeping, and is positioned on the left side of the bed
+```
# 3. Usage 🚀
+## 3.1 Use UPN for Object Proposal Generation
+
+Universal Proposal Network (UPN) is a robust object proposal model designed as part of ChatRex to enable comprehensive and accurate object detection across diverse granularities and domains. Built upon T-Rex2, UPN is a DETR-based model with a dual-granularity prompt tuning strategy, combining fine-grained (e.g., part-level) and coarse-grained (e.g., instance-level) detection.
+
+
+

+
+
+----
+
+
+Example Code for UPN
+
```python
import torch
from PIL import Image
from tools.visualize import plot_boxes_to_image
-from upn import UPNWrapper
+from chatrex.upn import UPNWrapper
-ckpt_path = "checkpoints/upn_large.pth"
-test_image_path = "tests/test_image.jpeg"
+ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+test_image_path = "tests/images/test_upn.jpeg"
model = UPNWrapper(ckpt_path)
# fine-grained prompt
fine_grained_proposals = model.inference(
test_image_path, prompt_type="fine_grained_prompt"
)
+# filter by score (default: 0.3) and nms (default: 0.8)
fine_grained_filtered_proposals = model.filter(
fine_grained_proposals, min_score=0.3, nms_value=0.8
)
+## output is a dict with keys: "original_xyxy_boxes", "scores"
+## - "original_xyxy_boxes": list of boxes in xyxy format in shape (B, N, 4)
+## - "scores": list of scores for each box in shape (B, N)
# coarse-grained prompt
coarse_grained_proposals = model.inference(
@@ -63,8 +121,475 @@ coarse_grained_proposals = model.inference(
coarse_grained_filtered_proposals = model.filter(
coarse_grained_proposals, min_score=0.3, nms_value=0.8
)
+
+## output is a dict with keys: "original_xyxy_boxes", "scores"
+## - "original_xyxy_boxes": list of boxes in xyxy format in shape (B, N, 4)
+## - "scores": list of scores for each box in shape (B, N)
+```
+
+
+
+We also provide a visualization tool to visualize the object proposals generated by UPN. You can use the following code to visualize the object proposals:
+
+
+Example Code for UPN Visualization
+
+```python
+
+from chatrex.tools.visualize import plot_boxes_to_image
+image = Image.open(test_image_path)
+fine_grained_vis_image, _ = plot_boxes_to_image(
+ image.copy(),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ fine_grained_filtered_proposals["scores"][0],
+)
+fine_grained_vis_image.save("tests/test_image_fine_grained.jpeg")
+print(f"fine-grained proposal is saved at tests/test_image_fine_grained.jpeg")
+
+coarse_grained_vis_image, _ = plot_boxes_to_image(
+ image.copy(),
+ coarse_grained_filtered_proposals["original_xyxy_boxes"][0],
+ coarse_grained_filtered_proposals["scores"][0],
+)
+coarse_grained_vis_image.save("tests/test_image_coarse_grained.jpeg")
+print(f"coarse-grained proposal is saved at tests/test_image_coarse_grained.jpeg")
+
+```
+
+
+## 3.2 Usage of ChatRex
+
+ChatRex takes three inputs: image, text prompt, and box input. For the box input, you can either use the object proposals generated by UPN or provide your own box input (user drawn boxes). We have wrapped the ChatRex model to huggingface transformers format for easy usage. ChatRex can be used for various tasks and we provide example code for each task below.
+
+### 3.2.1 ChatRex for Object Detection & Grounding & Referring
+
+Example Prompt for detection, grounding, referring tasks:
+```text
+# Single Object Detection
+Please detect dog in this image. Answer the question with object indexes.
+Please detect the man in yellow shirt in this image. Answer the question with object indexes.
+
+# multiple object detection, use ; to separate the objects
+Please detect person; pigeon in this image. Answer the question with object indexes.
+Please detect person in the car; cat below the table in this image. Answer the question with object indexes.
+```
+
+
+Example Code
+
+- [Example Code in python file](tests/test_chatrex_detection.py)
+
+```python
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ # load upn model
+ print(f"loading upn model...")
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ model_upn = UPNWrapper(ckpt_path)
+ test_image_path = "tests/images/test_chatrex_detection.jpg"
+
+ # get upn predictions
+ fine_grained_proposals = model_upn.inference(
+ test_image_path, prompt_type="fine_grained_prompt"
+ )
+ fine_grained_filtered_proposals = model_upn.filter(
+ fine_grained_proposals, min_score=0.3, nms_value=0.8
+ )
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Please detect person; pigeon in this image. Answer the question with object indexes.",
+ bbox=fine_grained_filtered_proposals["original_xyxy_boxes"][
+ 0
+ ], # box in xyxy format
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ prediction,
+ font_size=15,
+ draw_width=5,
+ )
+ vis_image.save("tests/test_chatrex_detection.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_detection.jpeg")
+```
+
+The output from LLM is like:
+```text
+person
+pigeon
+```
+
+The visualization of the output is like:
+
+
+

+
+
+
+
+----
+
+### 3.2.2 ChatRex for Region Caption
+Example Prompt for Region Caption tasks:
+
+```text
+# Single Object Detection
+## caption in category name
+What is the category name of ? Answer the question with its category name in free format.
+
+## caption in short phrase
+Can you provide me with a short phrase to describe ? Answer the question with a short phrase.
+
+## caption in referring style
+Can you provide me with a brief description of ? Answer the question with brief description.
+
+## caption in one sentence
+Can you provide me with a one sentence of ? Answer the question with one sentence description.
+
+# multiple object detection, use ; to separate the objects
+```
+
+
+Example Code
+
+- [Example Code in python file](tests/test_chatrex_region_caption.py)
+
+```python
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ test_image_path = "tests/images/test_chatrex_install.jpg"
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Can you provide a one sentence description of in the image? Answer the question with a one sentence description.",
+ bbox=[[73.88417, 56.62228, 227.69223, 216.34338]],
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ [[73.88417, 56.62228, 227.69223, 216.34338]],
+ prediction,
+ font_size=15,
+ draw_width=5,
+ )
+ vis_image.save("tests/test_chatrex_region_caption.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_region_caption.jpeg")
+```
+
+The output from LLM is like:
+```text
+A brown dog is lying on a bed, appearing relaxed and comfortable
+```
+
+The visualization of the output is like:
+
+
+

+
+
+
+
+----
+
+### 3.2.3 ChatRex for Grounded Image Captioning
+Example Prompt for Region Caption tasks:
+
+```text
+# Brief Grounded Imager Caption
+Please breifly describe this image in one sentence and detect all the mentioned objects. Answer the question with grounded answer.
+
+# Detailed Grounded Image Caption
+Please provide a detailed description of the image and detect all the mentioned objects. Answer the question with grounded object indexes.
+```
+
+
+Example Code
+
+- [Example Code in python file](tests/test_chatrex_grounded_image_caption.py)
+
+```python
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ # load upn model
+ print(f"loading upn model...")
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ model_upn = UPNWrapper(ckpt_path)
+ test_image_path = "tests/images/test_chatrex_grounded_caption.jpg"
+
+ # get upn predictions
+ fine_grained_proposals = model_upn.inference(
+ test_image_path, prompt_type="fine_grained_prompt"
+ )
+ fine_grained_filtered_proposals = model_upn.filter(
+ fine_grained_proposals, min_score=0.3, nms_value=0.8
+ )
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Please breifly describe this image in one sentence and detect all the mentioned objects. Answer the question with grounded answer.",
+ bbox=fine_grained_filtered_proposals["original_xyxy_boxes"][
+ 0
+ ], # box in xyxy format
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ prediction,
+ font_size=15,
+ draw_width=5,
+ )
+ vis_image.save("tests/test_chatrex_grounded_image_caption.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_grounded_image_caption.jpeg")
```
+The output from LLM is like:
+```text
+The image depicts a cozy living room with a plaid couch, a wooden TV standholding a black television, a red armchair, and a whiteboardwith writing on the wall, accompanied by a framed posterof a couple.
+```
+
+The visualization of the output is like:
+
+
+

+
+
+
+
+----
+
+### 3.2.4 ChatRex for Grounded Conversation
+Example Prompt for Region Caption tasks:
+
+```text
+Answer the question in Grounded format. Question
+```
+
+
+Example Code
+
+- [Example Code in python file](tests/test_chatrex_grounded_conversation.py)
+
+```python
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ # load upn model
+ print(f"loading upn model...")
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ model_upn = UPNWrapper(ckpt_path)
+ test_image_path = "tests/images/test_grounded_conversation.jpg"
+
+ # get upn predictions
+ fine_grained_proposals = model_upn.inference(
+ test_image_path, prompt_type="coarse_grained_prompt"
+ )
+ fine_grained_filtered_proposals = model_upn.filter(
+ fine_grained_proposals, min_score=0.3, nms_value=0.8
+ )
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Answer the question in grounded format. This is a photo of my room, and can you tell me what kind of person I am? ",
+ bbox=fine_grained_filtered_proposals["original_xyxy_boxes"][
+ 0
+ ], # box in xyxy format
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ prediction,
+ font_size=30,
+ draw_width=10,
+ )
+ vis_image.save("tests/test_chatrex_grounded_conversation.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_grounded_conversation.jpeg")
+
+```
+
+The output from LLM is like:
+```text
+Based on the items in the image, it can be inferred that the person who owns this room has an interest in fitness and possibly enjoys reading. The presence of the dumbbell suggests a commitment to physical activity, while the book indicates a liking for literature or reading. The sneakers and the plush toy add a personal touch, suggesting that the person might also value comfort and perhaps has a playful or nostalgic side. However, without more context, it is not possible to accurately determine the individual's specific traits or personality.
+```
+
+The visualization of the output is like:
+
+
+

+
+
+
+
+----
+
+
+
# BibTeX 📚
```
@misc{jiang2024trex2,
diff --git a/assets/capability_overview.jpg b/assets/capability_overview.jpg
new file mode 100644
index 0000000..6cdae65
Binary files /dev/null and b/assets/capability_overview.jpg differ
diff --git a/assets/teaser.jpg b/assets/teaser.jpg
new file mode 100644
index 0000000..296b44a
Binary files /dev/null and b/assets/teaser.jpg differ
diff --git a/assets/test_chatrex_grounded_conversation.jpeg b/assets/test_chatrex_grounded_conversation.jpeg
new file mode 100644
index 0000000..c92cabf
Binary files /dev/null and b/assets/test_chatrex_grounded_conversation.jpeg differ
diff --git a/assets/upn_res.jpg b/assets/upn_res.jpg
new file mode 100644
index 0000000..011ff5a
Binary files /dev/null and b/assets/upn_res.jpg differ
diff --git a/assets/upn_vis.jpg b/assets/upn_vis.jpg
deleted file mode 100644
index 36d03fb..0000000
Binary files a/assets/upn_vis.jpg and /dev/null differ
diff --git a/assets/vis_output/test_chatrex_detection.jpeg b/assets/vis_output/test_chatrex_detection.jpeg
new file mode 100644
index 0000000..c2b132f
Binary files /dev/null and b/assets/vis_output/test_chatrex_detection.jpeg differ
diff --git a/assets/vis_output/test_chatrex_grounded_image_caption.jpeg b/assets/vis_output/test_chatrex_grounded_image_caption.jpeg
new file mode 100644
index 0000000..1b9092d
Binary files /dev/null and b/assets/vis_output/test_chatrex_grounded_image_caption.jpeg differ
diff --git a/assets/vis_output/test_chatrex_region_caption.jpeg b/assets/vis_output/test_chatrex_region_caption.jpeg
new file mode 100644
index 0000000..6b1f95f
Binary files /dev/null and b/assets/vis_output/test_chatrex_region_caption.jpeg differ
diff --git a/chatre b/chatre
new file mode 120000
index 0000000..1da3f52
--- /dev/null
+++ b/chatre
@@ -0,0 +1 @@
+/comp_robot/jiangqing/projects/2023/research/MG-LLaVA/checkpoints/chatre
\ No newline at end of file
diff --git a/chatrex/tools/Tahoma.ttf b/chatrex/tools/Tahoma.ttf
new file mode 100644
index 0000000..d204f95
Binary files /dev/null and b/chatrex/tools/Tahoma.ttf differ
diff --git a/chatrex/tools/prompt_templates.py b/chatrex/tools/prompt_templates.py
new file mode 100644
index 0000000..2dbc1a0
--- /dev/null
+++ b/chatrex/tools/prompt_templates.py
@@ -0,0 +1,174 @@
+REGION_CAPTION_SINGLE_REGION_FREE_FORMAT_CATEGORY_NAME_STAGE2 = [
+ "What is the category name of [OBJ]? Answer the question with its category name in free format.",
+ "What is the category name of [OBJ]? Answer with its category name.",
+ "Identify the category of [OBJ].",
+ "What would you call [OBJ]? Provide its category name.",
+ "Give the category name for [OBJ].",
+ "Classify [OBJ] and provide its category name.",
+ "What is the category of [OBJ]?",
+ "What is [OBJ]'s category name?",
+ "Describe [OBJ] with its category name.",
+ "What is the most fitting category name for [OBJ]?",
+ "Provide the category name for [OBJ].",
+ "Classify [OBJ] and name its category.",
+ "What is the category name for [OBJ]?",
+ "What is the proper name for [OBJ]?",
+ "Give the category for [OBJ].",
+ "What would you categorize [OBJ] as?",
+ "Name [OBJ] based on its category.",
+ "What is the category name for [OBJ]?",
+ "Give a category name for [OBJ].",
+ "Provide the category name for [OBJ].",
+ "What is the category of [OBJ]?",
+]
+
+REGION_CAPTION_SINGLE_REGION_SHORT_PHRASE_STAGE2 = [
+ "Can you provide me with a short phrase to describe [OBJ]? Answer the question with a short phrase.",
+ "Give a short phrase that describes [OBJ].",
+ "What short phrase can describe [OBJ]?",
+ "Provide a short phrase for [OBJ].",
+ "What is a short phrase that best represents [OBJ]?",
+ "Describe [OBJ] using only a short phrase.",
+ "Can you summarize [OBJ] with a short phrase?",
+ "What is the best short phrase for [OBJ]?",
+ "Offer a short phrase for [OBJ].",
+ "What short phrase would you use for [OBJ]?",
+ "Give a concise short phrase for [OBJ].",
+ "What short phrase summarizes [OBJ]?",
+ "Provide a suitable short phrase for [OBJ].",
+ "What is a fitting short phrase for [OBJ]?",
+ "What’s an appropriate short phrase for [OBJ]?",
+ "Give a quick short phrase to describe [OBJ].",
+ "Can you suggest a short phrase for [OBJ]?",
+ "What short phrase would you assign to [OBJ]?",
+ "Provide a quick short phrase for [OBJ].",
+ "What short phrase explains [OBJ]?",
+]
+
+
+REGION_CAPTION_SINGLE_REGION_BREIFLY_STAGE2 = [
+ "Can you provide me with a brief description of [OBJ]? Answer the question with brief description."
+ "Give a brief description of [OBJ].",
+ "What brief description can you provide for [OBJ]?",
+ "Provide a brief description for [OBJ].",
+ "What is a brief description of [OBJ]?",
+ "Describe [OBJ] in a few sentences.",
+ "Can you summarize [OBJ] with a brief description?",
+ "What is the best brief description for [OBJ]?",
+ "Offer a brief description for [OBJ].",
+ "How would you describe [OBJ] briefly?",
+ "What brief summary can describe [OBJ]?",
+ "Provide a suitable brief description for [OBJ].",
+ "What is a fitting brief description for [OBJ]?",
+ "What’s an appropriate brief description for [OBJ]?",
+ "Can you suggest a brief description for [OBJ]?",
+ "What brief description would you assign to [OBJ]?",
+ "What brief description explains [OBJ]?",
+]
+
+REGION_CAPTION_ONE_SENTENCE_STAGE2 = [
+ "Can you provide a one sentence description of [OBJ] in the image? Answer the question with a one sentence description.",
+ "Describe [OBJ] in one sentence.",
+ "What is a one sentence description of [OBJ]?",
+ "Provide a one sentence description for [OBJ].",
+ "Summarize [OBJ] in one sentence.",
+ "How would you describe [OBJ] in a single sentence?",
+ "Give a one sentence summary of [OBJ].",
+ "Describe [OBJ] using only one sentence.",
+ "What is the best one sentence description for [OBJ]?",
+ "Offer a one sentence description for [OBJ].",
+ "What is a fitting one sentence to describe [OBJ]?",
+ "What’s an appropriate one sentence description for [OBJ]?",
+ "Can you suggest a one sentence description for [OBJ]?",
+ "Provide a one sentence explanation for [OBJ].",
+ "What single sentence would describe [OBJ]?",
+ "Give a concise one sentence description of [OBJ].",
+ "What single sentence explains [OBJ]?",
+ "How would you summarize [OBJ] in one sentence?",
+]
+
+COUNTING_SINGLE_REGION_STAGE2 = [
+ "How many [OBJ] are there in this image? Answer the question with the number of objects and locate them with object indexes.",
+ "Can you count how many [OBJ] are present in the image? Provide the total count along with the corresponding object indexes.",
+ "Identify and count all the [OBJ] in this image. Please return the number of objects and their respective indexes.",
+ "How many instances of [OBJ] do you see in this image? Report the count and the indexes of each detected object.",
+ "Locate all [OBJ] in this image and count them. Provide the number along with their object indexes.",
+ "Please count the [OBJ] present in the image and provide their total number along with the index of each object.",
+ "Count the number of [OBJ] visible in the image and give the total along with the index of each object.",
+ "Determine how many [OBJ] are in this image and list their indexes along with the total count.",
+ "Find and count all [OBJ] in the image. Provide the object indexes along with the total number.",
+ "How many [OBJ] can you find in this image? Return the count and their respective object indexes.",
+]
+
+GROUNDING_SINGLE_REGION_STAGE2 = [
+ "Please detect [OBJ] in this image. Answer the question with object indexes.",
+ "Detect [OBJ] with object indexes.",
+ "Find [OBJ] in the image and provide the object indexes.",
+ "Detect [OBJ] in the image and return the object indexes.",
+ "What are the object indexes for [OBJ] in this image?",
+ "Locate [OBJ] in the image and give the object indexes.",
+ "Identify [OBJ] in the image and provide the corresponding indexes.",
+ "Detect and return the indexes of [OBJ] in this image.",
+ "Please locate [OBJ] and answer with its object indexes.",
+ "Find the object indexes for [OBJ] in this image.",
+ "Can you detect [OBJ] in the image? Provide the object indexes.",
+ "Provide the object indexes for [OBJ] detected in this image.",
+ "What are the indexes for [OBJ] detected in the image?",
+ "Find [OBJ] in this image and answer with its indexes.",
+ "Can you locate [OBJ] and provide the object indexes?",
+ "Detect [OBJ] and return the object indexes in this image.",
+ "Identify and give the object indexes for [OBJ] in the image.",
+ "Please find [OBJ] in the image and return the indexes.",
+ "Where is [OBJ] in this image? Provide the object indexes.",
+ "Can you locate [OBJ] in this image and give the indexes?",
+ "Please detect [OBJ] and provide the object indexes.",
+]
+
+BREIF_CAPTION_WITH_GROUDING_STAGE2 = [
+ "Please breifly describe this image in one sentence and detect all the mentioned objects. Answer the question with grounded answer."
+ "Please briefly describe this image and detect all the mentioned objects. Answer with grounded object indexes.",
+ "Breifly describe this image with grounded answer",
+ "Provide a brief description of this image and detect the objects mentioned in the description.",
+ "Describe this image briefly and ground all the mentioned objects by providing their indexes.",
+ "Give a concise description of the image and detect all the objects mentioned in the description.",
+ "What is a brief description of the image? Provide the object indexes for all mentioned objects.",
+ "Summarize this image briefly and ground the mentioned objects with their indexes.",
+ "Describe this image and detect all the mentioned objects, returning their indexes.",
+ "Provide a short description of the image and ground the objects that are mentioned.",
+ "What is a quick description of this image? Detect the objects mentioned and provide their indexes.",
+ "Give a brief description and detect all mentioned objects, grounding them with their object indexes.",
+ "Can you briefly describe this image and also detect the mentioned objects? Provide grounded object indexes.",
+ "What is a simple description of the image? Detect and provide indexes for the mentioned objects.",
+ "Provide a short description and ground the objects mentioned in the image.",
+ "Give a quick description of the image and provide the indexes of all mentioned objects.",
+ "Summarize this image and ground all the objects mentioned by providing their indexes.",
+ "Describe this image and detect all the mentioned objects. Provide grounded object indexes as the answer.",
+ "What is a brief description of this image? Detect the objects in the description and provide their indexes.",
+ "Offer a short summary of the image and ground all the mentioned objects by providing their object indexes.",
+ "Provide a short caption of the image and detect all mentioned objects with their indexes.",
+ "Give a concise summary of the image and detect all mentioned objects, returning their object indexes.",
+]
+
+DETAILED_CAPTION_WITH_GROUDING_STAGE2 = [
+ "Please provide a detailed description of the image and detect all the mentioned objects. Answer the question with grounded object indexes.",
+ "Give a thorough description of the image and detect all mentioned objects, providing their object indexes.",
+ "Describe the image in detail and ground all the mentioned objects by returning their object indexes.",
+ "What is a detailed description of the image? Detect and provide the indexes for the mentioned objects.",
+ "Provide a detailed explanation of this image and ground all the mentioned objects.",
+ "Can you describe the image in detail and detect the objects mentioned, returning their indexes?",
+ "Describe all the key details in the image and detect the mentioned objects with their indexes.",
+ "Provide a complete and detailed description of the image and ground the mentioned objects.",
+ "Give a comprehensive description of the image and provide the object indexes for all mentioned objects.",
+ "What is happening in this image? Provide a detailed description and detect the mentioned objects with their indexes.",
+ "Describe the content of this image in detail and ground the objects mentioned.",
+ "Offer a full description of the image and provide the object indexes for all detected objects.",
+ "What is a detailed narrative of the scene in this image? Detect and provide indexes for the mentioned objects.",
+ "Explain the image thoroughly and provide the object indexes for all mentioned objects.",
+ "Provide a detailed description of the image and ground the mentioned objects with their indexes.",
+ "Give a detailed overview of the image and detect all mentioned objects by returning their indexes.",
+ "Can you explain the details of the image and ground the mentioned objects, providing object indexes?",
+ "Provide a detailed narrative of this image and detect the objects mentioned, returning their indexes.",
+ "What is a full description of the image? Ground the mentioned objects and provide their indexes.",
+ "Describe all the important details in the image and ground the objects mentioned with their indexes.",
+ "Describe this image in detail with grounded answer.",
+]
diff --git a/chatrex/tools/visualize.py b/chatrex/tools/visualize.py
new file mode 100644
index 0000000..69cea3a
--- /dev/null
+++ b/chatrex/tools/visualize.py
@@ -0,0 +1,203 @@
+import re
+from typing import Dict, List
+
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+
+class ColorGenerator:
+
+ def __init__(self, color_type) -> None:
+ self.color_type = color_type
+
+ if color_type == "same":
+ self.color = tuple((np.random.randint(0, 127, size=3) + 128).tolist())
+ elif color_type == "text":
+ np.random.seed(3396)
+ self.num_colors = 300
+ self.colors = np.random.randint(0, 127, size=(self.num_colors, 3)) + 128
+ else:
+ raise ValueError
+
+ def get_color(self, text):
+ if self.color_type == "same":
+ return self.color
+
+ if self.color_type == "text":
+ text_hash = hash(text)
+ index = text_hash % self.num_colors
+ color = tuple(self.colors[index])
+ return color
+
+ raise ValueError
+
+
+def plot_boxes_to_image(
+ image_pil: Image,
+ boxes: List[List[float]],
+ scores: List[float],
+ return_point: bool = False,
+ point_width: float = 10.0,
+ return_score=True,
+) -> Image:
+ """Plot bounding boxes and labels on an image.
+
+ Args:
+ image_pil (PIL.Image): The input image as a PIL Image object.
+ boxes: A list of bounding boxes in shape (N, 4), in (x1, y1, x2, y2) format.
+ scores: A list of scores for each bounding box.
+ return_point (bool): Draw center point instead of bounding box. Defaults to False.
+
+ Returns:
+ Union[PIL.Image, PIL.Image]: A tuple containing the input image and ploted image.
+ """
+ # Create a PIL ImageDraw object to draw on the input image
+ draw = ImageDraw.Draw(image_pil)
+ # Create a new binary mask image with the same size as the input image
+ mask = Image.new("L", image_pil.size, 0)
+ # Create a PIL ImageDraw object to draw on the mask image
+ mask_draw = ImageDraw.Draw(mask)
+
+ # Draw boxes and masks for each box and label in the target dictionary
+ for box, score in zip(boxes, scores):
+ # Convert the box coordinates from 0..1 to 0..W, 0..H
+ score = score.item()
+ # Generate a random color for the box outline
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
+ # Extract the box coordinates
+ x0, y0, x1, y1 = box
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
+ if return_point:
+ ceter_x = int((x0 + x1) / 2)
+ ceter_y = int((y0 + y1) / 2)
+ # Draw the center point on the input image
+ draw.ellipse(
+ (
+ ceter_x - point_width,
+ ceter_y - point_width,
+ ceter_x + point_width,
+ ceter_y + point_width,
+ ),
+ fill=color,
+ width=point_width,
+ )
+ else:
+ # Draw the box outline on the input image
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=int(point_width))
+
+ # Draw the label text on the input image
+ if return_score:
+ text = f"{score:.2f}"
+ else:
+ text = f""
+ font = ImageFont.load_default()
+ if hasattr(font, "getbbox"):
+ bbox = draw.textbbox((x0, y0), text, font)
+ else:
+ w, h = draw.textsize(text, font)
+ bbox = (x0, y0, w + x0, y0 + h)
+ if not return_point:
+ draw.rectangle(bbox, fill=color)
+ draw.text((x0, y0), text, fill="white")
+
+ # Draw the box on the mask image
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
+ return image_pil, mask
+
+
+def convert_all_cate_prediction_to_ans(prediction: str) -> Dict[str, List[str]]:
+ # Define the pattern to extract ground truth labels and object tags
+ pattern = r"(.*?)<\/ground>(.*?)<\/objects>"
+
+ # Find all matches in the prediction string
+ matches = re.findall(pattern, prediction)
+
+ # Initialize the dictionary to store the result
+ ans = {}
+
+ for ground, objects in matches:
+ ground = ground.strip()
+ # Extract the object tags, e.g., , ,
+ object_tags = re.findall(r"", objects)
+ # Add the ground truth label and associated object tags to the dictionary
+ if ground not in ans:
+ ans[ground] = object_tags
+
+ return ans
+
+
+def visualize(image, boxes, labels, font_size: int = 12, draw_width: int = 6):
+ image = image.copy()
+ font_path = "chatrex/tools/Tahoma.ttf"
+ font = ImageFont.truetype(font_path, font_size)
+ color_generaor = ColorGenerator("text")
+ draw = ImageDraw.Draw(image)
+
+ for box, label in zip(boxes, labels):
+ x0, y0, x1, y1 = box
+ if isinstance(label, list):
+ label = label[0]
+ color = color_generaor.get_color(label)
+ text = label
+ try:
+ draw.rectangle(
+ [x0, y0, x1, y1],
+ outline=color,
+ width=draw_width,
+ )
+ except Exception as e:
+ print(f"error: {e}")
+ continue
+ bbox = draw.textbbox((x0, y0), text, font)
+ box_h = bbox[3] - bbox[1]
+ box_w = bbox[2] - bbox[0]
+
+ y0_text = y0 - box_h - (draw_width * 2)
+ y1_text = y0 + draw_width
+ if y0_text < 0:
+ y0_text = 0
+ y1_text = y0 + 2 * draw_width + box_h
+ draw.rectangle(
+ [x0, y0_text, bbox[2] + draw_width * 2, y1_text],
+ fill=color,
+ )
+ draw.text(
+ (x0 + draw_width, y0_text),
+ str(text),
+ fill="black",
+ font=font,
+ )
+
+ return image
+
+
+def visualize_chatrex_output(
+ image_pil: Image,
+ input_boxes: List[List[int]],
+ prediction_text: str,
+ font_size=15,
+ draw_width: int = 6,
+) -> Image:
+ """Plot bounding boxes and labels on an image.
+
+ Args:
+ image_pil (PIL.Image): The input image as a PIL Image object.
+ input_boxes: A list of bounding boxes in shape (N, 4), in (x1, y1, x2, y2) format.
+ prediction_text: The prediction text from the model.
+ font_size: The font size for the text. Defaults to 15.
+ draw_width: The width of the bounding box outline. Defaults to 6.
+
+ Returns:
+ PIL.Image: A tuple containing the input image and ploted image.
+ """
+ prediction_dict = convert_all_cate_prediction_to_ans(prediction_text)
+ pred_boxes = []
+ pred_labels = []
+ for k, v in prediction_dict.items():
+ for box in v:
+ obj_idx = int(box[4:-1])
+ if obj_idx < len(input_boxes):
+ pred_labels.append(k)
+ pred_boxes.append(input_boxes[obj_idx])
+ image_pred = visualize(image_pil, pred_boxes, pred_labels, font_size, draw_width)
+ return image_pred
diff --git a/upn/_C.cpython-39-x86_64-linux-gnu.so b/chatrex/upn/_C.cpython-39-x86_64-linux-gnu.so
similarity index 87%
rename from upn/_C.cpython-39-x86_64-linux-gnu.so
rename to chatrex/upn/_C.cpython-39-x86_64-linux-gnu.so
index eb66ace..53e8265 100755
Binary files a/upn/_C.cpython-39-x86_64-linux-gnu.so and b/chatrex/upn/_C.cpython-39-x86_64-linux-gnu.so differ
diff --git a/upn/__init__.py b/chatrex/upn/__init__.py
similarity index 100%
rename from upn/__init__.py
rename to chatrex/upn/__init__.py
diff --git a/upn/builder.py b/chatrex/upn/builder.py
similarity index 100%
rename from upn/builder.py
rename to chatrex/upn/builder.py
diff --git a/configs/upn_large.py b/chatrex/upn/configs/upn_large.py
similarity index 100%
rename from configs/upn_large.py
rename to chatrex/upn/configs/upn_large.py
diff --git a/upn/inference_wrapper.py b/chatrex/upn/inference_wrapper.py
similarity index 97%
rename from upn/inference_wrapper.py
rename to chatrex/upn/inference_wrapper.py
index 28bf142..8de2211 100644
--- a/upn/inference_wrapper.py
+++ b/chatrex/upn/inference_wrapper.py
@@ -4,20 +4,20 @@
import numpy as np
import torch
+import chatrex.upn.transforms.transform as T
from mmengine import Config
from PIL import Image
from torchvision.ops import nms
-import upn.transforms.transform as T
-from upn import build_architecture
-from upn.models.module import nested_tensor_from_tensor_list
+from chatrex.upn import build_architecture
+from chatrex.upn.models.module import nested_tensor_from_tensor_list
def build_model(
ckpt_path: str,
):
current_path = os.path.dirname(os.path.abspath(__file__))
- config_path = f"configs/upn_large.py"
+ config_path = f"chatrex/upn/configs/upn_large.py"
model_cfg = Config.fromfile(config_path).model
model = build_architecture(model_cfg)
checkpoint = torch.load(ckpt_path, map_location="cpu")
diff --git a/upn/models/architecture/__init__.py b/chatrex/upn/models/architecture/__init__.py
similarity index 100%
rename from upn/models/architecture/__init__.py
rename to chatrex/upn/models/architecture/__init__.py
diff --git a/upn/models/architecture/deformable_transformer.py b/chatrex/upn/models/architecture/deformable_transformer.py
similarity index 98%
rename from upn/models/architecture/deformable_transformer.py
rename to chatrex/upn/models/architecture/deformable_transformer.py
index 9ad5565..77d3b37 100644
--- a/upn/models/architecture/deformable_transformer.py
+++ b/chatrex/upn/models/architecture/deformable_transformer.py
@@ -4,9 +4,10 @@
import torch
import torch.nn as nn
-from upn import ARCHITECTURES, build_decoder, build_encoder
-from upn.models.utils import gen_encoder_output_proposals, inverse_sigmoid
-from upn.ops.modules import MSDeformAttn
+from chatrex.upn import ARCHITECTURES, build_decoder, build_encoder
+from chatrex.upn.models.utils import (gen_encoder_output_proposals,
+ inverse_sigmoid)
+from chatrex.upn.ops.modules import MSDeformAttn
@ARCHITECTURES.register_module()
diff --git a/upn/models/architecture/upn_model.py b/chatrex/upn/models/architecture/upn_model.py
similarity index 97%
rename from upn/models/architecture/upn_model.py
rename to chatrex/upn/models/architecture/upn_model.py
index 414f848..9fb222c 100644
--- a/upn/models/architecture/upn_model.py
+++ b/chatrex/upn/models/architecture/upn_model.py
@@ -5,14 +5,10 @@
import torch.nn as nn
import torch.nn.functional as F
-from upn import ARCHITECTURES, build_architecture, build_backbone
-from upn.models.module import (
- MLP,
- ContrastiveAssign,
- NestedTensor,
- nested_tensor_from_tensor_list,
-)
-from upn.models.utils import inverse_sigmoid
+from chatrex.upn import ARCHITECTURES, build_architecture, build_backbone
+from chatrex.upn.models.module import (MLP, ContrastiveAssign, NestedTensor,
+ nested_tensor_from_tensor_list)
+from chatrex.upn.models.utils import inverse_sigmoid
class LayerNorm2d(nn.Module):
diff --git a/upn/models/backbone/__init__.py b/chatrex/upn/models/backbone/__init__.py
similarity index 100%
rename from upn/models/backbone/__init__.py
rename to chatrex/upn/models/backbone/__init__.py
diff --git a/upn/models/backbone/swin.py b/chatrex/upn/models/backbone/swin.py
similarity index 99%
rename from upn/models/backbone/swin.py
rename to chatrex/upn/models/backbone/swin.py
index 6cb45e6..8aa27b2 100644
--- a/upn/models/backbone/swin.py
+++ b/chatrex/upn/models/backbone/swin.py
@@ -1,13 +1,14 @@
+from typing import Dict, List
+
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
-import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from upn import BACKBONES
-from upn.models.module import NestedTensor
-from typing import List, Dict
+from chatrex.upn import BACKBONES
+from chatrex.upn.models.module import NestedTensor
class Mlp(nn.Module):
diff --git a/upn/models/backbone/wrapper.py b/chatrex/upn/models/backbone/wrapper.py
similarity index 98%
rename from upn/models/backbone/wrapper.py
rename to chatrex/upn/models/backbone/wrapper.py
index 643ea84..410bd45 100644
--- a/upn/models/backbone/wrapper.py
+++ b/chatrex/upn/models/backbone/wrapper.py
@@ -3,9 +3,9 @@
import torch
import torch.nn as nn
-from upn import BACKBONES, build_backbone, build_position_embedding
-from upn.models.module import NestedTensor
-from upn.models.utils import clean_state_dict
+from chatrex.upn import BACKBONES, build_backbone, build_position_embedding
+from chatrex.upn.models.module import NestedTensor
+from chatrex.upn.models.utils import clean_state_dict
class FrozenBatchNorm2d(torch.nn.Module):
diff --git a/upn/models/decoder/__init__.py b/chatrex/upn/models/decoder/__init__.py
similarity index 100%
rename from upn/models/decoder/__init__.py
rename to chatrex/upn/models/decoder/__init__.py
diff --git a/upn/models/decoder/upn_decoder.py b/chatrex/upn/models/decoder/upn_decoder.py
similarity index 97%
rename from upn/models/decoder/upn_decoder.py
rename to chatrex/upn/models/decoder/upn_decoder.py
index 76fc38c..d1f9698 100644
--- a/upn/models/decoder/upn_decoder.py
+++ b/chatrex/upn/models/decoder/upn_decoder.py
@@ -3,15 +3,12 @@
import torch
import torch.nn as nn
-from upn import DECODERS, build_decoder
-from upn.models.module import MLP
-from upn.models.utils import (
- gen_sineembed_for_position,
- get_activation_fn,
- get_clones,
- inverse_sigmoid,
-)
-from upn.ops.modules import MSDeformAttn
+from chatrex.upn import DECODERS, build_decoder
+from chatrex.upn.models.module import MLP
+from chatrex.upn.models.utils import (gen_sineembed_for_position,
+ get_activation_fn, get_clones,
+ inverse_sigmoid)
+from chatrex.upn.ops.modules import MSDeformAttn
@DECODERS.register_module()
diff --git a/upn/models/encoder/__init__.py b/chatrex/upn/models/encoder/__init__.py
similarity index 100%
rename from upn/models/encoder/__init__.py
rename to chatrex/upn/models/encoder/__init__.py
diff --git a/upn/models/encoder/upn_encoder.py b/chatrex/upn/models/encoder/upn_encoder.py
similarity index 98%
rename from upn/models/encoder/upn_encoder.py
rename to chatrex/upn/models/encoder/upn_encoder.py
index e1c87d9..2c56d5d 100644
--- a/upn/models/encoder/upn_encoder.py
+++ b/chatrex/upn/models/encoder/upn_encoder.py
@@ -4,9 +4,9 @@
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
-from upn import ENCODERS, build_encoder
-from upn.models.utils import get_activation_fn, get_clones
-from upn.ops.modules import MSDeformAttn
+from chatrex.upn import ENCODERS, build_encoder
+from chatrex.upn.models.utils import get_activation_fn, get_clones
+from chatrex.upn.ops.modules import MSDeformAttn
@ENCODERS.register_module()
diff --git a/upn/models/module/__init__.py b/chatrex/upn/models/module/__init__.py
similarity index 100%
rename from upn/models/module/__init__.py
rename to chatrex/upn/models/module/__init__.py
diff --git a/upn/models/module/contrastive.py b/chatrex/upn/models/module/contrastive.py
similarity index 100%
rename from upn/models/module/contrastive.py
rename to chatrex/upn/models/module/contrastive.py
diff --git a/upn/models/module/mlp.py b/chatrex/upn/models/module/mlp.py
similarity index 100%
rename from upn/models/module/mlp.py
rename to chatrex/upn/models/module/mlp.py
diff --git a/upn/models/module/nested_tensor.py b/chatrex/upn/models/module/nested_tensor.py
similarity index 100%
rename from upn/models/module/nested_tensor.py
rename to chatrex/upn/models/module/nested_tensor.py
diff --git a/upn/models/utils/__init__.py b/chatrex/upn/models/utils/__init__.py
similarity index 100%
rename from upn/models/utils/__init__.py
rename to chatrex/upn/models/utils/__init__.py
diff --git a/upn/models/utils/detr_utils.py b/chatrex/upn/models/utils/detr_utils.py
similarity index 99%
rename from upn/models/utils/detr_utils.py
rename to chatrex/upn/models/utils/detr_utils.py
index 91e49bb..3ec3596 100644
--- a/upn/models/utils/detr_utils.py
+++ b/chatrex/upn/models/utils/detr_utils.py
@@ -7,8 +7,8 @@
import torch.nn.functional as F
from torch import nn
-from upn import POS_EMBEDDINGS
-from upn.models.module import NestedTensor
+from chatrex.upn import POS_EMBEDDINGS
+from chatrex.upn.models.module import NestedTensor
@POS_EMBEDDINGS.register_module()
diff --git a/upn/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so b/chatrex/upn/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so
similarity index 87%
rename from upn/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so
rename to chatrex/upn/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so
index edb67c3..72f9ffc 100755
Binary files a/upn/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so and b/chatrex/upn/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so differ
diff --git a/upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg b/chatrex/upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg
similarity index 100%
rename from upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg
rename to chatrex/upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg
diff --git a/upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg b/chatrex/upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg
similarity index 100%
rename from upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg
rename to chatrex/upn/ops/dist/MultiScaleDeformableAttention-1.0-py3.9-linux-x86_64.egg
diff --git a/upn/ops/functions/__init__.py b/chatrex/upn/ops/functions/__init__.py
similarity index 100%
rename from upn/ops/functions/__init__.py
rename to chatrex/upn/ops/functions/__init__.py
diff --git a/upn/ops/functions/__pycache__/__init__.cpython-38.pyc b/chatrex/upn/ops/functions/__pycache__/__init__.cpython-38.pyc
similarity index 100%
rename from upn/ops/functions/__pycache__/__init__.cpython-38.pyc
rename to chatrex/upn/ops/functions/__pycache__/__init__.cpython-38.pyc
diff --git a/upn/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc b/chatrex/upn/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc
similarity index 100%
rename from upn/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc
rename to chatrex/upn/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc
diff --git a/upn/ops/functions/ms_deform_attn_func.py b/chatrex/upn/ops/functions/ms_deform_attn_func.py
similarity index 100%
rename from upn/ops/functions/ms_deform_attn_func.py
rename to chatrex/upn/ops/functions/ms_deform_attn_func.py
diff --git a/upn/ops/modules/__init__.py b/chatrex/upn/ops/modules/__init__.py
similarity index 100%
rename from upn/ops/modules/__init__.py
rename to chatrex/upn/ops/modules/__init__.py
diff --git a/upn/ops/modules/__pycache__/__init__.cpython-38.pyc b/chatrex/upn/ops/modules/__pycache__/__init__.cpython-38.pyc
similarity index 100%
rename from upn/ops/modules/__pycache__/__init__.cpython-38.pyc
rename to chatrex/upn/ops/modules/__pycache__/__init__.cpython-38.pyc
diff --git a/upn/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc b/chatrex/upn/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc
similarity index 100%
rename from upn/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc
rename to chatrex/upn/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc
diff --git a/upn/ops/modules/ms_deform_attn.py b/chatrex/upn/ops/modules/ms_deform_attn.py
similarity index 100%
rename from upn/ops/modules/ms_deform_attn.py
rename to chatrex/upn/ops/modules/ms_deform_attn.py
diff --git a/upn/ops/modules/ms_deform_attn_key_aware.py b/chatrex/upn/ops/modules/ms_deform_attn_key_aware.py
similarity index 100%
rename from upn/ops/modules/ms_deform_attn_key_aware.py
rename to chatrex/upn/ops/modules/ms_deform_attn_key_aware.py
diff --git a/upn/ops/setup.py b/chatrex/upn/ops/setup.py
similarity index 100%
rename from upn/ops/setup.py
rename to chatrex/upn/ops/setup.py
diff --git a/upn/ops/src/cpu/ms_deform_attn_cpu.cpp b/chatrex/upn/ops/src/cpu/ms_deform_attn_cpu.cpp
similarity index 100%
rename from upn/ops/src/cpu/ms_deform_attn_cpu.cpp
rename to chatrex/upn/ops/src/cpu/ms_deform_attn_cpu.cpp
diff --git a/upn/ops/src/cpu/ms_deform_attn_cpu.h b/chatrex/upn/ops/src/cpu/ms_deform_attn_cpu.h
similarity index 100%
rename from upn/ops/src/cpu/ms_deform_attn_cpu.h
rename to chatrex/upn/ops/src/cpu/ms_deform_attn_cpu.h
diff --git a/upn/ops/src/cuda/ms_deform_attn_cuda.cu b/chatrex/upn/ops/src/cuda/ms_deform_attn_cuda.cu
similarity index 100%
rename from upn/ops/src/cuda/ms_deform_attn_cuda.cu
rename to chatrex/upn/ops/src/cuda/ms_deform_attn_cuda.cu
diff --git a/upn/ops/src/cuda/ms_deform_attn_cuda.h b/chatrex/upn/ops/src/cuda/ms_deform_attn_cuda.h
similarity index 100%
rename from upn/ops/src/cuda/ms_deform_attn_cuda.h
rename to chatrex/upn/ops/src/cuda/ms_deform_attn_cuda.h
diff --git a/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh b/chatrex/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh
similarity index 100%
rename from upn/ops/src/cuda/ms_deform_im2col_cuda.cuh
rename to chatrex/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh
diff --git a/upn/ops/src/ms_deform_attn.h b/chatrex/upn/ops/src/ms_deform_attn.h
similarity index 100%
rename from upn/ops/src/ms_deform_attn.h
rename to chatrex/upn/ops/src/ms_deform_attn.h
diff --git a/upn/ops/src/vision.cpp b/chatrex/upn/ops/src/vision.cpp
similarity index 100%
rename from upn/ops/src/vision.cpp
rename to chatrex/upn/ops/src/vision.cpp
diff --git a/upn/ops/test.py b/chatrex/upn/ops/test.py
similarity index 100%
rename from upn/ops/test.py
rename to chatrex/upn/ops/test.py
diff --git a/upn/transforms/transform.py b/chatrex/upn/transforms/transform.py
similarity index 100%
rename from upn/transforms/transform.py
rename to chatrex/upn/transforms/transform.py
diff --git a/upn/version.py b/chatrex/version.py
similarity index 100%
rename from upn/version.py
rename to chatrex/version.py
diff --git a/configs/TSVDatasets b/configs/TSVDatasets
deleted file mode 120000
index 8d4cf8e..0000000
--- a/configs/TSVDatasets
+++ /dev/null
@@ -1 +0,0 @@
-TSVDatasets/
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index dfbc023..7e3f257 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
mmengine==0.8.2
numpy==1.22.0
Pillow==10.1.0
-safetensors==0.3.1
+safetensors==0.4.1
scipy==1.9.1
setuptools==59.5.0
termcolor==2.4.0
-timm==0.4.12
-transformers==4.25.1
+timm==1.0.7
+transformers==4.44.2
pycocotools==2.0.8
\ No newline at end of file
diff --git a/setup.py b/setup.py
index d8bceb9..c75a204 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
version = "v1.0"
-package_name = "upn"
+package_name = "chatrex"
cwd = os.path.dirname(os.path.abspath(__file__))
sha = "Unknown"
@@ -22,7 +22,7 @@
def write_version_file():
- version_path = os.path.join(cwd, "upn", "version.py")
+ version_path = os.path.join(cwd, "chatrex", "version.py")
with open(version_path, "w") as f:
f.write(f"__version__ = '{version}'\n")
# f.write(f"git_version = {repr(sha)}\n")
@@ -35,7 +35,7 @@ def write_version_file():
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
- extensions_dir = os.path.join(this_dir, "upn", "ops", "src")
+ extensions_dir = os.path.join(this_dir,"chatrex/upn", "ops", "src")
main_source = os.path.join(extensions_dir, "vision.cpp")
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
@@ -74,7 +74,7 @@ def get_extensions():
ext_modules = [
extension(
- "upn._C",
+ "chatrex.upn._C",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
@@ -172,10 +172,10 @@ def gen_packages_items():
write_version_file()
setup(
- name="upn",
+ name="chatrex",
version="v1.0",
author="International Digital Economy Academy, Qing Jiang",
- url="https://github.com/IDEA-Research/Universal-Proposal-Network",
+ url="https://github.com/IDEA-Research/ChatRex",
description="Universal Proposal Network.",
license=license,
install_requires=parse_requirements("requirements.txt"),
diff --git a/tests/images/test_chatrex_detection.jpg b/tests/images/test_chatrex_detection.jpg
new file mode 100644
index 0000000..5be167e
Binary files /dev/null and b/tests/images/test_chatrex_detection.jpg differ
diff --git a/tests/images/test_chatrex_grounded_caption.jpg b/tests/images/test_chatrex_grounded_caption.jpg
new file mode 100644
index 0000000..a3058b9
Binary files /dev/null and b/tests/images/test_chatrex_grounded_caption.jpg differ
diff --git a/tests/images/test_chatrex_install.jpg b/tests/images/test_chatrex_install.jpg
new file mode 100644
index 0000000..3ba5352
Binary files /dev/null and b/tests/images/test_chatrex_install.jpg differ
diff --git a/tests/images/test_grounded_conversation.jpg b/tests/images/test_grounded_conversation.jpg
new file mode 100644
index 0000000..23c7f53
Binary files /dev/null and b/tests/images/test_grounded_conversation.jpg differ
diff --git a/tests/test_image.jpeg b/tests/images/test_upn.jpeg
similarity index 100%
rename from tests/test_image.jpeg
rename to tests/images/test_upn.jpeg
diff --git a/tests/test_chatrex_detection.py b/tests/test_chatrex_detection.py
new file mode 100644
index 0000000..d181dd2
--- /dev/null
+++ b/tests/test_chatrex_detection.py
@@ -0,0 +1,74 @@
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ # load upn model
+ print(f"loading upn model...")
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ model_upn = UPNWrapper(ckpt_path)
+ test_image_path = "tests/images/test_chatrex_detection.jpg"
+
+ # get upn predictions
+ fine_grained_proposals = model_upn.inference(
+ test_image_path, prompt_type="fine_grained_prompt"
+ )
+ fine_grained_filtered_proposals = model_upn.filter(
+ fine_grained_proposals, min_score=0.3, nms_value=0.8
+ )
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Please detect person; pigeon in this image. Answer the question with object indexes.",
+ bbox=fine_grained_filtered_proposals["original_xyxy_boxes"][
+ 0
+ ], # box in xyxy format
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ prediction,
+ font_size=15,
+ draw_width=5,
+ )
+ vis_image.save("tests/test_chatrex_detection.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_detection.jpeg")
diff --git a/tests/test_chatrex_grounded_conversation.py b/tests/test_chatrex_grounded_conversation.py
new file mode 100644
index 0000000..c39b381
--- /dev/null
+++ b/tests/test_chatrex_grounded_conversation.py
@@ -0,0 +1,74 @@
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ # load upn model
+ print(f"loading upn model...")
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ model_upn = UPNWrapper(ckpt_path)
+ test_image_path = "tests/images/test_grounded_conversation.jpg"
+
+ # get upn predictions
+ fine_grained_proposals = model_upn.inference(
+ test_image_path, prompt_type="coarse_grained_prompt"
+ )
+ fine_grained_filtered_proposals = model_upn.filter(
+ fine_grained_proposals, min_score=0.3, nms_value=0.8
+ )
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Answer the question in grounded format. This is a photo of my room, and can you tell me what kind of person I am? ",
+ bbox=fine_grained_filtered_proposals["original_xyxy_boxes"][
+ 0
+ ], # box in xyxy format
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ prediction,
+ font_size=30,
+ draw_width=10,
+ )
+ vis_image.save("tests/test_chatrex_grounded_conversation.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_grounded_conversation.jpeg")
diff --git a/tests/test_chatrex_grounded_image_caption.py b/tests/test_chatrex_grounded_image_caption.py
new file mode 100644
index 0000000..0fb9db3
--- /dev/null
+++ b/tests/test_chatrex_grounded_image_caption.py
@@ -0,0 +1,74 @@
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ # load upn model
+ print(f"loading upn model...")
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ model_upn = UPNWrapper(ckpt_path)
+ test_image_path = "tests/images/test_chatrex_grounded_caption.jpg"
+
+ # get upn predictions
+ fine_grained_proposals = model_upn.inference(
+ test_image_path, prompt_type="fine_grained_prompt"
+ )
+ fine_grained_filtered_proposals = model_upn.filter(
+ fine_grained_proposals, min_score=0.3, nms_value=0.8
+ )
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Please breifly describe this image in one sentence and detect all the mentioned objects. Answer the question with grounded answer.",
+ bbox=fine_grained_filtered_proposals["original_xyxy_boxes"][
+ 0
+ ], # box in xyxy format
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ fine_grained_filtered_proposals["original_xyxy_boxes"][0],
+ prediction,
+ font_size=15,
+ draw_width=5,
+ )
+ vis_image.save("tests/test_chatrex_grounded_image_caption.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_grounded_image_caption.jpeg")
diff --git a/tests/test_chatrex_install.py b/tests/test_chatrex_install.py
new file mode 100644
index 0000000..d09343c
--- /dev/null
+++ b/tests/test_chatrex_install.py
@@ -0,0 +1,42 @@
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ # load the model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to('cuda')
+
+ inputs = processor.process(
+ image=Image.open(
+ 'tests/images/test_chatrex_install.jpg'
+ ),
+ question="Can you provide me with a brief description of ?",
+ bbox=[[73.88417,56.62228,227.69223,216.34338]] # box in xyxy format
+ )
+
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ))
+ prediction = model.generate(inputs, gen_config=gen_config, tokenizer=processor.tokenizer)
+ print(f'prediction:', prediction)
+
diff --git a/tests/test_chatrex_region_caption.py b/tests/test_chatrex_region_caption.py
new file mode 100644
index 0000000..4140425
--- /dev/null
+++ b/tests/test_chatrex_region_caption.py
@@ -0,0 +1,60 @@
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from chatrex.tools.visualize import visualize_chatrex_output
+from chatrex.upn import UPNWrapper
+
+if __name__ == "__main__":
+ # load the processor
+ processor = AutoProcessor.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ device_map="cuda",
+ )
+
+ print(f"loading chatrex model...")
+ # load chatrex model
+ model = AutoModelForCausalLM.from_pretrained(
+ "checkpoints/chatrex7b",
+ trust_remote_code=True,
+ use_safetensors=True,
+ ).to("cuda")
+
+ test_image_path = "tests/images/test_chatrex_install.jpg"
+
+ inputs = processor.process(
+ image=Image.open(test_image_path),
+ question="Can you provide a one sentence description of in the image? Answer the question with a one sentence description.",
+ bbox=[[73.88417, 56.62228, 227.69223, 216.34338]],
+ )
+
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+ # perform inference
+ gen_config = GenerationConfig(
+ max_new_tokens=512,
+ do_sample=False,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ pad_token_id=(
+ processor.tokenizer.pad_token_id
+ if processor.tokenizer.pad_token_id is not None
+ else processor.tokenizer.eos_token_id
+ ),
+ )
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+ prediction = model.generate(
+ inputs, gen_config=gen_config, tokenizer=processor.tokenizer
+ )
+ print(f"prediction:", prediction)
+
+ # visualize the prediction
+ vis_image = visualize_chatrex_output(
+ Image.open(test_image_path),
+ [[73.88417, 56.62228, 227.69223, 216.34338]],
+ prediction,
+ font_size=15,
+ draw_width=5,
+ )
+ vis_image.save("tests/test_chatrex_region_caption.jpeg")
+ print(f"prediction is saved at tests/test_chatrex_region_caption.jpeg")
diff --git a/tests/test_install.py b/tests/test_upn_install.py
similarity index 76%
rename from tests/test_install.py
rename to tests/test_upn_install.py
index 7f57f0b..af7147a 100644
--- a/tests/test_install.py
+++ b/tests/test_upn_install.py
@@ -1,12 +1,11 @@
-import torch
from PIL import Image
-from tools.visualize import plot_boxes_to_image
-from upn import UPNWrapper
+from chatrex.tools.visualize import plot_boxes_to_image
+from chatrex.upn import UPNWrapper
if __name__ == "__main__":
- ckpt_path = "checkpoints/upn_large.pth"
- test_image_path = "tests/test_image.jpeg"
+ ckpt_path = "checkpoints/upn_checkpoints/upn_large.pth"
+ test_image_path = "tests/images/test_upn.jpeg"
model = UPNWrapper(ckpt_path)
# fine-grained prompt
@@ -33,6 +32,7 @@
fine_grained_filtered_proposals["scores"][0],
)
fine_grained_vis_image.save("tests/test_image_fine_grained.jpeg")
+ print(f"fine-grained proposal is saved at tests/test_image_fine_grained.jpeg")
coarse_grained_vis_image, _ = plot_boxes_to_image(
image.copy(),
@@ -40,3 +40,4 @@
coarse_grained_filtered_proposals["scores"][0],
)
coarse_grained_vis_image.save("tests/test_image_coarse_grained.jpeg")
+ print(f"coarse-grained proposal is saved at tests/test_image_coarse_grained.jpeg")
diff --git a/tools/visualize.py b/tools/visualize.py
deleted file mode 100644
index aeffb21..0000000
--- a/tools/visualize.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from typing import List
-
-import numpy as np
-from PIL import Image, ImageDraw, ImageFont
-
-
-def plot_boxes_to_image(
- image_pil: Image,
- boxes: List[List[float]],
- scores: List[float],
- return_point: bool = False,
- point_width: float = 10.0,
- return_score=True,
-) -> Image:
- """Plot bounding boxes and labels on an image.
-
- Args:
- image_pil (PIL.Image): The input image as a PIL Image object.
- boxes: A list of bounding boxes in shape (N, 4), in (x1, y1, x2, y2) format.
- scores: A list of scores for each bounding box.
- return_point (bool): Draw center point instead of bounding box. Defaults to False.
-
- Returns:
- Union[PIL.Image, PIL.Image]: A tuple containing the input image and ploted image.
- """
- # Create a PIL ImageDraw object to draw on the input image
- draw = ImageDraw.Draw(image_pil)
- # Create a new binary mask image with the same size as the input image
- mask = Image.new("L", image_pil.size, 0)
- # Create a PIL ImageDraw object to draw on the mask image
- mask_draw = ImageDraw.Draw(mask)
-
- # Draw boxes and masks for each box and label in the target dictionary
- for box, score in zip(boxes, scores):
- # Convert the box coordinates from 0..1 to 0..W, 0..H
- score = score.item()
- # Generate a random color for the box outline
- color = tuple(np.random.randint(0, 255, size=3).tolist())
- # Extract the box coordinates
- x0, y0, x1, y1 = box
- x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
- if return_point:
- ceter_x = int((x0 + x1) / 2)
- ceter_y = int((y0 + y1) / 2)
- # Draw the center point on the input image
- draw.ellipse(
- (
- ceter_x - point_width,
- ceter_y - point_width,
- ceter_x + point_width,
- ceter_y + point_width,
- ),
- fill=color,
- width=point_width,
- )
- else:
- # Draw the box outline on the input image
- draw.rectangle([x0, y0, x1, y1], outline=color, width=int(point_width))
-
- # Draw the label text on the input image
- if return_score:
- text = f"{score:.2f}"
- else:
- text = f""
- font = ImageFont.load_default()
- if hasattr(font, "getbbox"):
- bbox = draw.textbbox((x0, y0), text, font)
- else:
- w, h = draw.textsize(text, font)
- bbox = (x0, y0, w + x0, y0 + h)
- if not return_point:
- draw.rectangle(bbox, fill=color)
- draw.text((x0, y0), text, fill="white")
-
- # Draw the box on the mask image
- mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
- return image_pil, mask
diff --git a/upn/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO b/upn/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO
deleted file mode 100644
index 18ccefb..0000000
--- a/upn/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO
+++ /dev/null
@@ -1,11 +0,0 @@
-Metadata-Version: 2.1
-Name: MultiScaleDeformableAttention
-Version: 1.0
-Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention
-Home-page: https://github.com/fundamentalvision/Deformable-DETR
-Author: Weijie Su
-License: UNKNOWN
-Platform: UNKNOWN
-
-UNKNOWN
-
diff --git a/upn/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt b/upn/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt
deleted file mode 100644
index 676963a..0000000
--- a/upn/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-setup.py
-/comp_robot/jiangqing/projects/2023/research/open-source/UPN/upn/ops/src/vision.cpp
-/comp_robot/jiangqing/projects/2023/research/open-source/UPN/upn/ops/src/cpu/ms_deform_attn_cpu.cpp
-/comp_robot/jiangqing/projects/2023/research/open-source/UPN/upn/ops/src/cuda/ms_deform_attn_cuda.cu
-MultiScaleDeformableAttention.egg-info/PKG-INFO
-MultiScaleDeformableAttention.egg-info/SOURCES.txt
-MultiScaleDeformableAttention.egg-info/dependency_links.txt
-MultiScaleDeformableAttention.egg-info/top_level.txt
-functions/__init__.py
-functions/ms_deform_attn_func.py
-modules/__init__.py
-modules/ms_deform_attn.py
-modules/ms_deform_attn_key_aware.py
\ No newline at end of file
diff --git a/upn/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt b/upn/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt
deleted file mode 100644
index 8b13789..0000000
--- a/upn/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/upn/ops/MultiScaleDeformableAttention.egg-info/top_level.txt b/upn/ops/MultiScaleDeformableAttention.egg-info/top_level.txt
deleted file mode 100644
index 25d8f77..0000000
--- a/upn/ops/MultiScaleDeformableAttention.egg-info/top_level.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-MultiScaleDeformableAttention
-functions
-modules