feat(VisualReplayStrategy): adapters: ultralytics, som, anthropic, google; remove_move_before_click; vision.py

* add prompts/, adapters/openai.py, strategies/visual.py (wip)

* adapters.anthropic

* add anthropic.py

* prompt with active segment descriptions

* Set-of-Mark Prompting Adapter (#612)

* Update openadapt/config.py

* remove_move_before_click

* started_counter; adapters.ultralytics

* add vision.py

* add openadapt/adapters/google.py

* filter_masks_by_size

* documentation

* update README

* add ultralytics

* exclude alembic in black/flake8

* exclude .venv in black/flake8

* disable som adapter; remove logging

* add adapters.google

---------

Co-authored-by: Cody DeVilliers <[email protected]>
abrichr and Cody-DV authored Apr 16, 2024
1 parent cc645c4 commit 250943f
Showing 32 changed files with 3,817 additions and 1,269 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -4,4 +4,4 @@ exclude =
     .venv
 docstring-convention = google
 max-line-length = 88
-extend-ignore = ANN101
+extend-ignore = ANN101, E203
13 changes: 5 additions & 8 deletions .github/workflows/main.yml
@@ -25,6 +25,9 @@ jobs:
     steps:
       - name: Checkout code
        uses: actions/checkout@v3
+       with:
+         ref: ${{ env.BRANCH }}
+         repository: ${{ env.REPO }}

      - name: Set up Python
        uses: actions/setup-python@v3
@@ -35,12 +38,6 @@ jobs:
        if: matrix.os == 'macos-latest'
        run: sh install/install_openadapt.sh

-     - name: Checkout code
-       uses: actions/checkout@v3
-       with:
-         ref: ${{ env.BRANCH }}
-         repository: ${{ env.REPO }}
-
      - name: Install poetry
        uses: snok/install-poetry@v1
        with:
@@ -63,7 +60,7 @@ jobs:
        if: steps.cache-deps.outputs.cache-hit == 'true'

      - name: Check formatting with Black
-       run: poetry run black --preview --check .
+       run: poetry run black --preview --check . --exclude '/(alembic|\.venv)/'

      - name: Run Flake8
-       run: poetry run flake8
+       run: poetry run flake8 --exclude=alembic,.venv
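
To reproduce these CI checks locally, the same commands from the workflow can be run directly (assuming Poetry and the dev dependencies are installed):

```
poetry run black --preview --check . --exclude '/(alembic|\.venv)/'
poetry run flake8 --exclude=alembic,.venv
```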
5 changes: 5 additions & 0 deletions README.md
@@ -189,6 +189,11 @@ python -m openadapt.replay NaiveReplayStrategy
 Other replay strategies include:

 - [`StatefulReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/stateful.py): Proof-of-concept which uses the OpenAI GPT-4 API with prompts constructed via OS-level window data.
+- [`VisualReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/visual.py): Uses [Fast Segment Anything Model (FastSAM)](https://github.com/CASIA-IVA-Lab/FastSAM) to segment the active window. Accepts an "instructions" parameter that is used to modify the recording, e.g.:
+
+```
+python -m openadapt.replay VisualReplayStrategy --instructions "Multiply 9x5 instead of 6x8"
+```

See https://github.com/OpenAdaptAI/OpenAdapt/tree/main/openadapt/strategies for a complete list. More ReplayStrategies coming soon! (see [Contributing](#Contributing)).

@@ -0,0 +1,34 @@
"""add active_segment_description and available_segment_descriptions
Revision ID: 30a5ba9d6453
Revises: 530f0663324e
Create Date: 2024-04-05 12:02:57.843244
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '30a5ba9d6453'
down_revision = '530f0663324e'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('action_event', schema=None) as batch_op:
batch_op.add_column(sa.Column('active_segment_description', sa.String(), nullable=True))
batch_op.add_column(sa.Column('available_segment_descriptions', sa.String(), nullable=True))

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('action_event', schema=None) as batch_op:
batch_op.drop_column('available_segment_descriptions')
batch_op.drop_column('active_segment_description')

# ### end Alembic commands ###
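
To apply this migration to a local database, the standard Alembic invocation (run from the repository root, assuming the project's alembic.ini is configured as usual) would be:

```
alembic upgrade head
```

and `alembic downgrade -1` reverts the two new columns.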
40 changes: 40 additions & 0 deletions openadapt/adapters/__init__.py
@@ -0,0 +1,40 @@
"""Adapters for completion and segmentation."""

from types import ModuleType

from openadapt import config
from . import anthropic
from . import openai
from . import replicate
from . import som
from . import ultralytics
from . import google


def get_default_prompt_adapter() -> ModuleType:
"""Returns the default prompt adapter module.
Returns:
The module corresponding to the default prompt adapter.
"""
return {
"openai": openai,
"anthropic": anthropic,
"google": google,
}[config.DEFAULT_ADAPTER]


def get_default_segmentation_adapter() -> ModuleType:
"""Returns the default image segmentation adapter module.
Returns:
The module corresponding to the default segmentation adapter.
"""
return {
"som": som,
"replicate": replicate,
"ultralytics": ultralytics,
}[config.DEFAULT_SEGMENTATION_ADAPTER]


__all__ = ["anthropic", "openai", "replicate", "som", "ultralytics", "google"]
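
A rough usage sketch for the new registry follows; it assumes `config.DEFAULT_ADAPTER` names one of the three prompt adapters, that the `openai` module exposes the same `prompt()` signature as the two shown in this commit, and that the prompt text is hypothetical. Note that the committed `anthropic.prompt()` calls `len()` on `base64_images`, so a list (possibly empty) should be passed rather than None:

```
from openadapt import adapters

# Resolve whichever prompt adapter config.DEFAULT_ADAPTER names and query it.
adapter = adapters.get_default_prompt_adapter()
response = adapter.prompt(
    "Describe the next GUI action to take.",  # hypothetical prompt text
    system_prompt="You are a GUI process automation assistant.",
    base64_images=[],  # pass a list: anthropic's prompt() calls len() on it
)
print(response)
```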
146 changes: 146 additions & 0 deletions openadapt/adapters/anthropic.py
@@ -0,0 +1,146 @@
"""Adapter for Anthropic API with vision support."""

from pprint import pprint

from loguru import logger
import anthropic

from openadapt import cache, config


MAX_TOKENS = 4096
# from https://docs.anthropic.com/claude/docs/vision
MAX_IMAGES = 20
MODEL_NAME = "claude-3-opus-20240229"


@cache.cache()
def create_payload(
prompt: str,
system_prompt: str | None = None,
base64_images: list[tuple[str, str]] | None = None,
model: str = MODEL_NAME,
max_tokens: int | None = None,
) -> dict:
"""Creates the payload for the Anthropic API request with image support."""
messages = []

user_message_content = []

max_tokens = max_tokens or MAX_TOKENS
if max_tokens > MAX_TOKENS:
logger.warning(f"{max_tokens=} > {MAX_TOKENS=}")
max_tokens = MAX_TOKENS

# Add base64 encoded images to the user message content
if base64_images:
for image_data in base64_images:
# Extract media type and base64 data
media_type, base64_str = image_data.split(";base64,", 1)
media_type = media_type.split(":")[-1] # Remove 'data:' prefix

user_message_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_str,
},
}
)

# Add text prompt
user_message_content.append(
{
"type": "text",
"text": prompt,
}
)

# Construct user message
messages.append(
{
"role": "user",
"content": user_message_content,
}
)

# Prepare the full payload
payload = {
"model": model,
"max_tokens": max_tokens,
"messages": messages,
}

# Add system_prompt as a top-level parameter if provided
if system_prompt:
payload["system"] = system_prompt

return payload


client = anthropic.Anthropic(api_key=config.ANTHROPIC_API_KEY)


@cache.cache()
def get_completion(payload: dict) -> str:
"""Sends a request to the Anthropic API and returns the response."""
try:
response = client.messages.create(**payload)
except Exception as exc:
logger.exception(exc)
import ipdb

ipdb.set_trace()
"""
Message(
id='msg_01L55ai2A9q92687mmjMSch3',
content=[
ContentBlock(
text='{
"action": [
{
"name": "press",
"key_name": "cmd",
"canonical_key_name": "cmd"
},
...
]
}',
type='text'
)
],
model='claude-3-opus-20240229',
role='assistant',
stop_reason='end_turn',
stop_sequence=None,
type='message',
usage=Usage(input_tokens=4379, output_tokens=109))
"""
texts = [content_block.text for content_block in response.content]
return "\n".join(texts)


def prompt(
prompt: str,
system_prompt: str | None = None,
base64_images: list[str] | None = None,
max_tokens: int | None = None,
) -> str:
"""Public method to get a response from the Anthropic API with image support."""
if len(base64_images) > MAX_IMAGES:
# XXX TODO handle this
raise Exception(
f"{len(base64_images)=} > {MAX_IMAGES=}. Use a different adapter."
)
payload = create_payload(
prompt,
system_prompt,
base64_images,
max_tokens=max_tokens,
)
# pprint(f"payload=\n{payload}") # Log payload for debugging
result = get_completion(payload)
pprint(f"result=\n{result}") # Log result for debugging
return result
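
A minimal sketch of calling this adapter directly (assuming `ANTHROPIC_API_KEY` is set in the config, and a hypothetical screenshot path); note that `create_payload()` expects each image as a data URI, since it splits on `";base64,"`:

```
import base64

from openadapt.adapters import anthropic

# Encode an image as the "data:<media_type>;base64,<data>" string the adapter parses.
with open("screenshot.png", "rb") as f:  # hypothetical image path
    encoded = base64.b64encode(f.read()).decode()

result = anthropic.prompt(
    "Describe the visible UI elements in this screenshot.",
    system_prompt="Respond in JSON.",
    base64_images=[f"data:image/png;base64,{encoded}"],
)
print(result)
```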
74 changes: 74 additions & 0 deletions openadapt/adapters/google.py
@@ -0,0 +1,74 @@
"""Adapter for Google Gemini.
See https://ai.google.dev/tutorials/python_quickstart for documentation.
"""

from pprint import pprint

from PIL import Image
import fire
import google.generativeai as genai

from openadapt import cache, config, utils


MAX_TOKENS = 2**20 # 1048576
MODEL_NAME = [
"gemini-pro-vision",
"models/gemini-1.5-pro-latest",
][-1]
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts
MAX_IMAGES = {
"gemini-pro-vision": 16,
"models/gemini-1.5-pro-latest": 3000,
}[MODEL_NAME]


@cache.cache()
def prompt(
prompt: str,
system_prompt: str | None = None,
base64_images: list[str] | None = None,
# max_tokens: int | None = None,
model_name: str = MODEL_NAME,
) -> str:
"""Public method to get a response from the Google API with image support."""
full_prompt = "\n\n###\n\n".join([s for s in (system_prompt, prompt) if s])
# HACK
full_prompt += "\nWhen responding in JSON, you MUST use double quotes around keys."

# TODO: modify API across all adapters to accept PIL.Image
images = (
[utils.utf82image(base64_image) for base64_image in base64_images]
if base64_images
else []
)

genai.configure(api_key=config.GOOGLE_API_KEY)
model = genai.GenerativeModel(model_name)
response = model.generate_content([full_prompt] + images)
response.resolve()
pprint(f"response=\n{response}") # Log response for debugging
return response.text


def main(text: str, image_path: str | None = None) -> None:
"""Prompt Google Gemini with text and a path to an image."""
if image_path:
with Image.open(image_path) as img:
# Convert image to RGB if it's RGBA (to remove alpha channel)
if img.mode in ("RGBA", "LA") or (
img.mode == "P" and "transparency" in img.info
):
img = img.convert("RGB")
base64_image = utils.image2utf8(img)
else:
base64_image = None

base64_images = [base64_image] if base64_image else None
output = prompt(text, base64_images=base64_images)
print(output)


if __name__ == "__main__":
fire.Fire(main)
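
Since `main` is wired through `fire.Fire`, this adapter can also be exercised from the command line, e.g. (hypothetical image path):

```
python -m openadapt.adapters.google "What is shown in this screenshot?" --image_path=screenshot.png
```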