Commit

format
am committed Oct 29, 2024
1 parent 385a359 commit 0427288
Showing 8 changed files with 50 additions and 91 deletions.
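
The diffs below are consistent with running an automated code formatter over the repository. As an illustration only (the commit does not record the tool or its settings), black with an assumed 120-character line length reproduces the kind of reflow seen here, e.g. exploding the templates dict in prompts.py:

# Sketch: reproduce this commit's reflow with black.
# Assumption (not confirmed by the commit): black, line length 120.
import black

src = (
    'templates = {"multi_choice": multi_choice_prompt, "binary": binary_prompt, '
    '"binary_conv": binary_conv_prompt, "text_only": text_only_prompt}\n'
)

# The single-line dict exceeds 120 characters, so black splits it into one
# key per line with a trailing comma -- matching the new version in the diff.
print(black.format_str(src, mode=black.Mode(line_length=120)))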
12 changes: 11 additions & 1 deletion m3/demo/gradio_m3.py
@@ -676,6 +676,7 @@ def update_sys_message(sys_message, sv):
sv.sys_msg = sys_message
return sv


def update_modality_prompt(modality_prompt, sv):
"""Update the modality prompt"""
logger.debug(f"Updating the modality prompt")
@@ -785,7 +786,16 @@ def create_demo(source, model_path, conv_mode, server_port):
clear_btn.click(
fn=clear_all_convs,
inputs=[sv],
outputs=[sv, prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, model_cards_text, modality_prompt_dropdown],
outputs=[
sv,
prompt_edit,
chat_history,
history_text,
history_text_full,
sys_prompt_text,
model_cards_text,
modality_prompt_dropdown,
],
)

# States
12 changes: 8 additions & 4 deletions m3/eval/scripts/classification/prompts.py
@@ -20,8 +20,7 @@
"incorporate the additional results generated by an expert classification model:\n"
)
text_only_expert_prefix = (
"When answering the question, please "
"incorporate the results generated by an expert classification model:\n"
"When answering the question, please incorporate the results generated by an expert classification model:\n"
)
multi_choice_prompt = [
[
@@ -98,7 +97,7 @@ def has_placeholder(prompts, placeholder):
return False


_first_prompt = "Question: is there <class_name> according to the image?\n" "Please reply with yes or no.\n"
_first_prompt = "Question: is there <class_name> according to the image?\n Please reply with yes or no.\n"
binary_conv_prompt = [
[0, model_list + f"<image> This is a CXR image.\n{_first_prompt}"],
[1, "This looks like a chest x-ray. Let me first trigger <CXR()>."],
Expand All @@ -110,4 +109,9 @@ def has_placeholder(prompts, placeholder):
[1, None],
]

templates = {"multi_choice": multi_choice_prompt, "binary": binary_prompt, "binary_conv": binary_conv_prompt, "text_only": text_only_prompt}
templates = {
"multi_choice": multi_choice_prompt,
"binary": binary_prompt,
"binary_conv": binary_conv_prompt,
"text_only": text_only_prompt,
}
2 changes: 0 additions & 2 deletions m3/eval/scripts/print_metrics.py
@@ -71,7 +71,6 @@ def main():
except:
results["report_mimiccxr_old_green"] = "?"


try:
json_data = load_json(os.path.join(args.input, "report_mimiccxr/results_clean.json"))
results["report_mimiccxr_clean_bleu4"] = json_data["BLEU_4"]
Expand All @@ -95,7 +94,6 @@ def main():
except:
results["report_mimiccxr_clean_expert_bleu4"] = "?"
results["report_mimiccxr_clean_expert_rougel"] = "?"


try:
json_data = load_json(os.path.join(args.input, "report_mimiccxr/result_clean_expert_green.json"))
19 changes: 4 additions & 15 deletions m3/eval/scripts/report_updated/cleaning.py
@@ -101,20 +101,15 @@ def capitalize_first_letter(sentences):
# Capitalize the first letter
refined_sentences.append(sentence[0].upper() + sentence[1:])
else:
print(
f"First character is not a letter in sentence: '{sentence}'"
)
print(f"First character is not a letter in sentence: '{sentence}'")
continue

return refined_sentences


def remove_before_colon(sentences):
# Replace only when a colon is followed by a space
refined_sentences = [
sentence.split(": ", 1)[-1] if ": " in sentence else sentence
for sentence in sentences
]
refined_sentences = [sentence.split(": ", 1)[-1] if ": " in sentence else sentence for sentence in sentences]

return refined_sentences

@@ -166,9 +161,7 @@ def remove_duplicate_sentences(sentences):

for sentence in sentences:
if sentence not in seen:
unique_sentences.append(
sentence
) # Add sentence to the result if not seen before
unique_sentences.append(sentence) # Add sentence to the result if not seen before
seen.add(sentence) # Mark this sentence as seen

return unique_sentences
@@ -209,11 +202,7 @@ def replace_appearance(text, target_word, replacement_word, nth_app):
return text

# Rebuild the string: join the parts, but replace the second occurrence
modified_text = (
target_word.join(parts[:nth_app])
+ replacement_word
+ target_word.join(parts[nth_app:])
)
modified_text = target_word.join(parts[:nth_app]) + replacement_word + target_word.join(parts[nth_app:])

return modified_text

40 changes: 10 additions & 30 deletions m3/eval/scripts/report_updated/inference.py
@@ -10,16 +10,14 @@
# limitations under the License.

import argparse
import re
from io import BytesIO
import math
import os
import os.path as osp
import re
from io import BytesIO

import math
import requests
import torch
from PIL import Image

from llava.constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
Expand All @@ -28,14 +26,10 @@
IMAGE_TOKEN_INDEX,
)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
KeywordsStoppingCriteria,
get_model_name_from_path,
process_images,
tokenizer_image_token,
)
from llava.mm_utils import KeywordsStoppingCriteria, get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from PIL import Image


def load_filenames(file_path):
@@ -80,29 +74,23 @@ def eval_model(args):
output_folder = args.output_folder

model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
args.model_path, model_name, args.model_base
)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, model_name, args.model_base)

for img_filename in image_filenames:
image_path = osp.join(images_folder, img_filename)
image = load_image(image_path)

query = "Describe the image in detail."

image_token_se = (
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
)
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
if IMAGE_PLACEHOLDER in query:
if model.config.mm_use_im_start_end:
query = re.sub(IMAGE_PLACEHOLDER, image_token_se, query)
else:
query = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, query)
else:
if DEFAULT_IMAGE_TOKEN not in query:
print(
f"no <image> tag found in input. Automatically append one at the beginning of text."
)
print(f"no <image> tag found in input. Automatically append one at the beginning of text.")
if model.config.mm_use_im_start_end:
query = image_token_se + "\n" + query
else:
@@ -133,16 +121,8 @@
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

images_tensor = process_images([image], image_processor, model.config).to(
model.device, dtype=torch.float16
)
input_ids = (
tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.cuda()
)
images_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
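The import reshuffle above (math joining the standard-library block, from PIL import Image landing after the llava imports) matches what isort produces: straight imports before from-imports within each section, sorted case-insensitively, standard library ahead of third-party packages. A minimal sketch, assuming isort with a black-compatible profile (the actual configuration is not part of this diff):

# Sketch: the kind of import grouping shown in inference.py's new version.
# Assumption: isort with profile="black" and line length 120.
import isort

src = """import argparse
import re
from io import BytesIO

import math
import requests
import torch
from PIL import Image
"""

# isort moves math into the stdlib group, sorts each group alphabetically,
# and separates stdlib from third-party (requests, torch, PIL) with a blank line.
print(isort.code(src, profile="black", line_length=120))
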
37 changes: 10 additions & 27 deletions m3/eval/scripts/report_updated/inference_expert.py
@@ -10,17 +10,15 @@
# limitations under the License.

import argparse
import re
from io import BytesIO
import json
import math
import os
import os.path as osp
import re
from io import BytesIO

import json
import math
import requests
import torch
from PIL import Image

from llava.constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
Expand All @@ -29,15 +27,10 @@
IMAGE_TOKEN_INDEX,
)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
KeywordsStoppingCriteria,
get_model_name_from_path,
process_images,
tokenizer_image_token,
)
from llava.mm_utils import KeywordsStoppingCriteria, get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

from PIL import Image

model_list = (
"Here is a list of available expert models:\n"
@@ -108,9 +101,7 @@ def eval_model(args):
output_folder = args.output_folder

model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
args.model_path, model_name, args.model_base
)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, model_name, args.model_base)

for img_filename in image_filenames:
image_path = osp.join(images_folder, img_filename)
@@ -146,22 +137,14 @@
conv.append_message(
conv.roles[0],
f"The resulting predictions are:\n{preds}. Analyze the image and take these predictions "
f"into account when responding to this prompt.\n{query}"
f"into account when responding to this prompt.\n{query}",
)
conv.append_message(conv.roles[1], None)
print(conv)
prompt = conv.get_prompt()

images_tensor = process_images([image], image_processor, model.config).to(
model.device, dtype=torch.float16
)
input_ids = (
tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.cuda()
)
images_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
10 changes: 3 additions & 7 deletions m3/eval/scripts/report_updated/metric_green_score.py
@@ -17,6 +17,7 @@

from green_score import GREEN


def load_text_files_from_directory(directory):
"""
Load the content of all text files from a directory and return it as a list of strings.
Expand All @@ -26,15 +27,12 @@ def load_text_files_from_directory(directory):
filepath = os.path.join(directory, filename)
if os.path.isfile(filepath) and filename.endswith(".txt"):
with open(filepath, "r") as file:
content = (
file.read().strip()
) # Read the file content and remove any surrounding whitespace
content = file.read().strip() # Read the file content and remove any surrounding whitespace
texts.append(content)
return texts


def run_inference(refs, hyps):

# Initialize the GREEN model (assumes correct GPU has been set in the environment)
model = GREEN(
model_id_or_path="StanfordAIMI/GREEN-radllama2-7b",
@@ -99,9 +97,7 @@ def read_files(directory):
refs_dir = sys.argv[1] # refs directory path
hyps_dir = sys.argv[2] # hyps directory path
num_partitions = int(sys.argv[3]) # Total number of partitions (e.g., total GPUs)
partition_index = int(
sys.argv[4]
) # Index of this partition (e.g., current GPU/process index)
partition_index = int(sys.argv[4]) # Index of this partition (e.g., current GPU/process index)

ground_truths = read_files(refs_dir)
predictions = read_files(hyps_dir)
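metric_green_score.py takes a partition count and a partition index on the command line so that several GPUs can each score a slice of the reports. The slicing itself happens outside the lines shown; a hypothetical sketch of one common chunking scheme (the partition helper below is illustrative, not from the source):

import math

def partition(items, num_partitions, partition_index):
    # Hypothetical helper: contiguous chunks, with the last partition
    # absorbing any remainder. Not taken from metric_green_score.py.
    size = math.ceil(len(items) / num_partitions)
    return items[partition_index * size : (partition_index + 1) * size]

reports = [f"report_{i}.txt" for i in range(10)]
print(partition(reports, num_partitions=4, partition_index=2))  # ['report_6.txt', 'report_7.txt', 'report_8.txt']
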
9 changes: 4 additions & 5 deletions m3/eval/scripts/report_updated/metrics.py
@@ -15,9 +15,9 @@
import re

from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer


@@ -60,7 +60,7 @@ def compute_scores(gts, res):
:param res: Dictionary with the image ids and their generated captions
:print: Evaluation score (the mean of the scores of all the instances) for each measure
"""
print('tokenization...')
print("tokenization...")
tokenizer = PTBTokenizer()
gts = tokenizer.tokenize(gts)
res = tokenizer.tokenize(res)
@@ -122,9 +122,7 @@ def main():

_text = predictions[filename]
_text = normalize_spaces(_text.replace("\n", " "))
prediction_data.append(
{"image_id": idx, "caption": normalize_spaces(_text.replace("\n", " "))}
)
prediction_data.append({"image_id": idx, "caption": normalize_spaces(_text.replace("\n", " "))})

print(f"found {len(prediction_data)} prediction data points.")

Expand All @@ -141,5 +139,6 @@ def main():
with open(args.output, "w") as f:
json.dump(eval_res, f)


if __name__ == "__main__":
main()
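
compute_scores in metrics.py feeds PTB-tokenized captions to the pycocoevalcap scorers imported at the top of the file. A minimal usage sketch of that API with toy data (values are illustrative):

# Sketch: how the imported pycocoevalcap scorers are typically called.
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge

# Both sides map image_id -> list of caption strings (already tokenized).
gts = {0: ["the heart size is normal"], 1: ["no focal consolidation is seen"]}
res = {0: ["heart size is normal"], 1: ["no consolidation"]}

bleu, _ = Bleu(4).compute_score(gts, res)  # list: [BLEU-1, BLEU-2, BLEU-3, BLEU-4]
rouge_l, _ = Rouge().compute_score(gts, res)  # corpus-level ROUGE-L
print(bleu[3], rouge_l)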
