From 10f56ec7818e5812ed84bb53aee5a8b04e142360 Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 15 Oct 2024 17:42:40 -0400 Subject: [PATCH] from_args for VQA ingredient --- blendsql/ingredients/builtin/vqa/main.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/blendsql/ingredients/builtin/vqa/main.py b/blendsql/ingredients/builtin/vqa/main.py index bf276334..3b6cf85b 100644 --- a/blendsql/ingredients/builtin/vqa/main.py +++ b/blendsql/ingredients/builtin/vqa/main.py @@ -1,9 +1,11 @@ from typing import List, Tuple +from attr import attrs, attrib from blendsql.models import Model from blendsql._program import Program from blendsql.ingredients.ingredient import MapIngredient from blendsql._exceptions import IngredientException +from blendsql.ingredients.utils import partialclass class ImageCaptionProgram(Program): @@ -21,11 +23,17 @@ def __call__( return ([output[0]["generated_text"].strip() for output in model_output], "") +@attrs class ImageCaption(MapIngredient): DESCRIPTION = """ If we need to generate a caption for an image stored in the database, we can use the scalar function to map to a new column: `{{ImageCaption('table::column')}}` """ + model: Model = attrib(default=None) + + @classmethod + def from_args(cls, model: Model = None): + return partialclass(cls, model=model) def run(self, model: Model, values: List[bytes], **kwargs): """Generates a caption for all byte images passed to it."""