-
-
Notifications
You must be signed in to change notification settings - Fork 116
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(VisualReplayStrategy): adapters: ultralytics, som, anthropic, go…
…ogle; remove_move_before_click; vision.py * add prompts/, adapters/openai.py, strategies/visual.py (wip) * adapters.anthropic * add anthropic.py * prompt with active segment descriptions * Set-of-Mark Prompting Adapter (#612) * Update openadapt/config.py * remove_move_before_click * started_counter; adapters.ultralytics * add vision.py * add openadapt/adapters/google.py * filter_masks_by_size * documentation * update README * add ultralytics * exclude alembic in black/flake8 * exclude .venv in black/flake8 * disable som adapter; remove logging * add adapters.google --------- Co-authored-by: Cody DeVilliers <[email protected]>
- Loading branch information
Showing
32 changed files
with
3,817 additions
and
1,269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
alembic/versions/30a5ba9d6453_add_active_segment_description_and_.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""add active_segment_description and available_segment_descriptions | ||
Revision ID: 30a5ba9d6453 | ||
Revises: 530f0663324e | ||
Create Date: 2024-04-05 12:02:57.843244 | ||
""" | ||
from alembic import op | ||
import sqlalchemy as sa | ||
|
||
|
||
# revision identifiers, used by Alembic. | ||
revision = '30a5ba9d6453' | ||
down_revision = '530f0663324e' | ||
branch_labels = None | ||
depends_on = None | ||
|
||
|
||
def upgrade() -> None:
    """Apply the migration: add segment-description columns to action_event."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('action_event', schema=None) as batch_op:
        batch_op.add_column(
            sa.Column('active_segment_description', sa.String(), nullable=True)
        )
        batch_op.add_column(
            sa.Column('available_segment_descriptions', sa.String(), nullable=True)
        )

    # ### end Alembic commands ###
|
||
|
||
def downgrade() -> None:
    """Revert the migration: drop the segment-description columns again."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('action_event', schema=None) as batch_op:
        # Drop in reverse order of addition.
        batch_op.drop_column('available_segment_descriptions')
        batch_op.drop_column('active_segment_description')

    # ### end Alembic commands ###
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""Adapters for completion and segmentation.""" | ||
|
||
from types import ModuleType | ||
|
||
from openadapt import config | ||
from . import anthropic | ||
from . import openai | ||
from . import replicate | ||
from . import som | ||
from . import ultralytics | ||
from . import google | ||
|
||
|
||
def get_default_prompt_adapter() -> ModuleType:
    """Return the module for the configured default prompt adapter.

    Returns:
        The adapter module selected by ``config.DEFAULT_ADAPTER``
        ("openai", "anthropic", or "google").
    """
    adapters_by_name = {
        "openai": openai,
        "anthropic": anthropic,
        "google": google,
    }
    return adapters_by_name[config.DEFAULT_ADAPTER]
|
||
|
||
def get_default_segmentation_adapter() -> ModuleType:
    """Return the module for the configured default segmentation adapter.

    Returns:
        The adapter module selected by
        ``config.DEFAULT_SEGMENTATION_ADAPTER``
        ("som", "replicate", or "ultralytics").
    """
    adapters_by_name = {
        "som": som,
        "replicate": replicate,
        "ultralytics": ultralytics,
    }
    return adapters_by_name[config.DEFAULT_SEGMENTATION_ADAPTER]
|
||
|
||
__all__ = ["anthropic", "openai", "replicate", "som", "ultralytics", "google"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
"""Adapter for Anthropic API with vision support.""" | ||
|
||
from pprint import pprint | ||
|
||
from loguru import logger | ||
import anthropic | ||
|
||
from openadapt import cache, config | ||
|
||
|
||
MAX_TOKENS = 4096 | ||
# from https://docs.anthropic.com/claude/docs/vision | ||
MAX_IMAGES = 20 | ||
MODEL_NAME = "claude-3-opus-20240229" | ||
|
||
|
||
@cache.cache()
def create_payload(
    prompt: str,
    system_prompt: str | None = None,
    # NOTE: fixed annotation — elements are data-URI strings
    # ("data:<media_type>;base64,<data>"), not tuples; see the split below
    # and the sibling prompt() which passes list[str].
    base64_images: list[str] | None = None,
    model: str = MODEL_NAME,
    max_tokens: int | None = None,
) -> dict:
    """Create the payload for an Anthropic API request with image support.

    Args:
        prompt: The user prompt text.
        system_prompt: Optional system prompt; sent as the top-level
            "system" parameter, not as a message.
        base64_images: Optional list of data URIs of the form
            "data:<media_type>;base64,<data>".
        model: Name of the model to request.
        max_tokens: Maximum tokens to generate; clamped to MAX_TOKENS.

    Returns:
        A dict suitable for ``client.messages.create(**payload)``.
    """
    messages = []

    user_message_content = []

    max_tokens = max_tokens or MAX_TOKENS
    if max_tokens > MAX_TOKENS:
        logger.warning(f"{max_tokens=} > {MAX_TOKENS=}")
        max_tokens = MAX_TOKENS

    # Add base64 encoded images to the user message content
    if base64_images:
        for image_data in base64_images:
            # Extract media type and base64 data from the data URI
            media_type, base64_str = image_data.split(";base64,", 1)
            media_type = media_type.split(":")[-1]  # Remove 'data:' prefix

            user_message_content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": media_type,
                        "data": base64_str,
                    },
                }
            )

    # Add text prompt after the images
    user_message_content.append(
        {
            "type": "text",
            "text": prompt,
        }
    )

    # Construct user message
    messages.append(
        {
            "role": "user",
            "content": user_message_content,
        }
    )

    # Prepare the full payload
    payload = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": messages,
    }

    # Add system_prompt as a top-level parameter if provided
    if system_prompt:
        payload["system"] = system_prompt

    return payload
|
||
|
||
client = anthropic.Anthropic(api_key=config.ANTHROPIC_API_KEY)


@cache.cache()
def get_completion(payload: dict) -> str:
    """Send a request to the Anthropic API and return the response text.

    Args:
        payload: Request body, as produced by ``create_payload()``.

    Returns:
        The text of all content blocks in the response, joined by newlines.
        (response.content is a list of ContentBlock(text=..., type='text').)

    Raises:
        Exception: Propagates any error raised by the Anthropic client.
    """
    try:
        response = client.messages.create(**payload)
    except Exception as exc:
        # Log and re-raise: the previous ipdb.set_trace() debugger hook
        # would fall through to an undefined `response` below (NameError).
        logger.exception(exc)
        raise
    texts = [content_block.text for content_block in response.content]
    return "\n".join(texts)
|
||
|
||
def prompt(
    prompt: str,
    system_prompt: str | None = None,
    base64_images: list[str] | None = None,
    max_tokens: int | None = None,
) -> str:
    """Public method to get a response from the Anthropic API with image support.

    Args:
        prompt: The user prompt text.
        system_prompt: Optional system prompt.
        base64_images: Optional list of data-URI encoded images.
        max_tokens: Maximum tokens to generate.

    Returns:
        The response text from the API.

    Raises:
        Exception: If more than MAX_IMAGES images are provided.
    """
    # Guard the None default: len(None) would raise TypeError.
    num_images = len(base64_images) if base64_images else 0
    if num_images > MAX_IMAGES:
        # XXX TODO handle this
        raise Exception(
            f"{len(base64_images)=} > {MAX_IMAGES=}. Use a different adapter."
        )
    payload = create_payload(
        prompt,
        system_prompt,
        base64_images,
        max_tokens=max_tokens,
    )
    # pprint(f"payload=\n{payload}")  # Log payload for debugging
    result = get_completion(payload)
    pprint(f"result=\n{result}")  # Log result for debugging
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Adapter for Google Gemini. | ||
See https://ai.google.dev/tutorials/python_quickstart for documentation. | ||
""" | ||
|
||
from pprint import pprint | ||
|
||
from PIL import Image | ||
import fire | ||
import google.generativeai as genai | ||
|
||
from openadapt import cache, config, utils | ||
|
||
|
||
MAX_TOKENS = 2**20 # 1048576 | ||
MODEL_NAME = [ | ||
"gemini-pro-vision", | ||
"models/gemini-1.5-pro-latest", | ||
][-1] | ||
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts | ||
MAX_IMAGES = { | ||
"gemini-pro-vision": 16, | ||
"models/gemini-1.5-pro-latest": 3000, | ||
}[MODEL_NAME] | ||
|
||
|
||
@cache.cache()
def prompt(
    prompt: str,
    system_prompt: str | None = None,
    base64_images: list[str] | None = None,
    # max_tokens: int | None = None,
    model_name: str = MODEL_NAME,
) -> str:
    """Public method to get a response from the Google API with image support.

    Args:
        prompt: The user prompt text.
        system_prompt: Optional system prompt, prepended to the user prompt
            (Gemini has no separate system-message channel here).
        base64_images: Optional list of base64/UTF-8 encoded images.
        model_name: Gemini model to use.

    Returns:
        The response text from the API.
    """
    segments = [segment for segment in (system_prompt, prompt) if segment]
    full_prompt = "\n\n###\n\n".join(segments)
    # HACK: nudge the model toward valid JSON output.
    full_prompt += "\nWhen responding in JSON, you MUST use double quotes around keys."

    # TODO: modify API across all adapters to accept PIL.Image
    images = []
    if base64_images:
        images = [utils.utf82image(encoded) for encoded in base64_images]

    genai.configure(api_key=config.GOOGLE_API_KEY)
    model = genai.GenerativeModel(model_name)
    response = model.generate_content([full_prompt] + images)
    response.resolve()
    pprint(f"response=\n{response}")  # Log response for debugging
    return response.text
|
||
|
||
def main(text: str, image_path: str | None = None) -> None:
    """Prompt Google Gemini with text and a path to an image."""
    base64_image = None
    if image_path:
        with Image.open(image_path) as img:
            # Convert image to RGB if it's RGBA (to remove alpha channel)
            has_transparency = img.mode in ("RGBA", "LA") or (
                img.mode == "P" and "transparency" in img.info
            )
            if has_transparency:
                img = img.convert("RGB")
            base64_image = utils.image2utf8(img)

    base64_images = [base64_image] if base64_image else None
    output = prompt(text, base64_images=base64_images)
    print(output)
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire(main) |
Oops, something went wrong.