diff --git a/scripts/caption_generation_inference.py b/scripts/caption_generation_inference.py index 8687258..16409de 100644 --- a/scripts/caption_generation_inference.py +++ b/scripts/caption_generation_inference.py @@ -47,16 +47,35 @@ def infer_gemini( outputs = [] for idx, image_path in enumerate(images): - model = GenerativeModel("") + model = GenerativeModel("gemini-pro-vision") temp = generative_models.Image.load_from_file(image_path) image_data = generative_models.Part.from_image(temp) qid, retrieval_results = retrieval_dict.get(image_path) - message = p_class.prepare_message(retrieval_results) - content = [message, image_data] + + fewshot_images = p_class.get_fewshot_image_data(retrieval_results) + fewshot_captions = p_class.get_fewshot_captions(retrieval_results) + assert len(fewshot_images) == len(fewshot_captions) + + message = p_class.prepare_gemini_message(len(fewshot_images)) + if fewshot_images: + json_data = [message] + content = [message] + for fewshot_image, fewshot_caption, result in zip( + fewshot_images, fewshot_captions, retrieval_results + ): + json_data.append(result[1]) + json_data.append(fewshot_caption) + content.append(fewshot_image) + content.append(fewshot_caption) + json_data.append(image_path) + content.append(image_data) + else: + json_data = [message, image_path] + content = [message, image_data] if idx < 10: with open(f"{samples_dir}/prompt_{idx}.txt", "w") as f: - json.dump(content, f) + json.dump(json_data, f) f.write("\n") try: response = model.generate_content(content) diff --git a/scripts/gemini-caption-prompt-with-examples.txt b/scripts/gemini-caption-prompt-with-examples.txt new file mode 100644 index 0000000..091c897 --- /dev/null +++ b/scripts/gemini-caption-prompt-with-examples.txt @@ -0,0 +1,4 @@ +I have provided you {num} image(s) and their corresponding caption(s) as example(s) and one last image without a caption. 
+ +Generate the caption for the last remaining image based on your understanding of the last image and provided image-caption examples. +Only respond with the caption string; do not say any other words or explain. The answer must be compulsorily in one line. diff --git a/scripts/generator_prompt.py b/scripts/generator_prompt.py index 8c6cdb2..bde2b81 100644 --- a/scripts/generator_prompt.py +++ b/scripts/generator_prompt.py @@ -3,6 +3,8 @@ from mimetypes import guess_type from PIL import Image +from vertexai import generative_models class Prompt: @@ -37,6 +39,14 @@ def prepare_gpt_message(self, num_candidates): else: prompt = self.prompt_template return prompt + + def prepare_gemini_message(self, num_candidates): + num_examples = min(self.k, num_candidates) + if self.k > 0: + prompt = self.prompt_template.format(num=num_examples) + else: + prompt = self.prompt_template.format(num=num_examples) + return prompt def merge_images(self, retrieval_results, query_image_path, dist_images=5): if self.k == 0: @@ -95,6 +105,17 @@ def get_fewshot_image_urls(self, retrieval_results): assert hit[1] image_urls.append(self.encode_image_as_url(hit[1])) return image_urls + + def get_fewshot_image_data(self, retrieval_results): + num_examples = min(self.k, len(retrieval_results)) + image_datas = [] + for index, hit in enumerate(retrieval_results): + if index == num_examples: + break + assert hit[1] + temp = generative_models.Image.load_from_file(hit[1]) + image_datas.append(generative_models.Part.from_image(temp)) + return image_datas def get_fewshot_captions(self, retrieval_results): num_examples = min(self.k, len(retrieval_results))