Skip to content

Commit

Permalink
ext::ai: Add Anthropic support
Browse files Browse the repository at this point in the history
Wire in Anthropic models: `claude-3-{haiku,sonnet,opus}`.
  • Loading branch information
elprans committed Apr 10, 2024
1 parent 9ecaa35 commit 832614e
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 0 deletions.
50 changes: 50 additions & 0 deletions edb/lib/ext/ai.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
};
};

# Configuration for the built-in Anthropic provider.  Inherits the
# generic provider schema and pins the identifying fields.
create type ext::ai::AnthropicProviderConfig extending ext::ai::ProviderConfig {
    # Provider name is fixed and non-overridable for a builtin provider.
    alter property name {
        set protected := true;
        set default := 'builtin::anthropic';
    };

    # Human-readable label shown in UIs; also fixed.
    alter property display_name {
        set protected := true;
        set default := 'Anthropic';
    };

    # Base URL of the Anthropic Messages API; overridable (e.g. proxies).
    alter property api_url {
        # Fix: the statement was missing its terminating semicolon.
        set default := 'https://api.anthropic.com/v1';
    };
};

create type ext::ai::Config extending cfg::ExtensionConfig {
create multi link providers: ext::ai::ProviderConfig {
create annotation std::description :=
Expand Down Expand Up @@ -238,6 +254,40 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
ext::ai::text_gen_model_context_window := "8192";
};

# Anthropic models.
create abstract type ext::ai::AnthropicClaude3HaikuModel
extending ext::ai::TextGenerationModel
{
alter annotation
ext::ai::model_name := "claude-3-haiku-20240307";
alter annotation
ext::ai::model_provider := "builtin::anthropic";
alter annotation
ext::ai::text_gen_model_context_window := "200000";
};

# Claude 3 Sonnet text-generation model, served by the builtin
# Anthropic provider; context window per annotation is 200000 tokens.
create abstract type ext::ai::AnthropicClaude3SonnetModel
extending ext::ai::TextGenerationModel
{
alter annotation
ext::ai::model_name := "claude-3-sonnet-20240229";
alter annotation
ext::ai::model_provider := "builtin::anthropic";
alter annotation
ext::ai::text_gen_model_context_window := "200000";
};

# Claude 3 Opus text-generation model, served by the builtin
# Anthropic provider; context window per annotation is 200000 tokens.
create abstract type ext::ai::AnthropicClaude3OpusModel
extending ext::ai::TextGenerationModel
{
alter annotation
ext::ai::model_name := "claude-3-opus-20240229";
alter annotation
ext::ai::model_provider := "builtin::anthropic";
alter annotation
ext::ai::text_gen_model_context_window := "200000";
};

# Distance metric used for vector similarity comparisons.
create scalar type ext::ai::DistanceFunction
extending enum<Cosine, InnerProduct, L2>;

Expand Down
114 changes: 114 additions & 0 deletions edb/server/protocol/ai_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,10 @@ async def _start_chat(
await _start_mistral_chat(
protocol, request, response,
provider, model_name, messages, stream)
elif provider.name == "builtin::anthropic":
await _start_anthropic_chat(
protocol, request, response,
provider, model_name, messages, stream)
else:
raise RuntimeError(f"unsupported model provider: {provider.name}")

Expand Down Expand Up @@ -577,6 +581,116 @@ async def _start_mistral_chat(
stream,
)


async def _start_anthropic_chat(
    protocol: protocol.HttpProtocol,
    request: protocol.HttpRequest,
    response: protocol.HttpResponse,
    provider,
    model_name: str,
    messages: list[dict],
    stream: bool,
) -> None:
    """Proxy a chat request to the Anthropic Messages API.

    In streaming mode the upstream SSE events are relayed to the client
    as text/event-stream; otherwise the full completion text is returned
    as a single JSON body.
    """
    client = httpx.AsyncClient(
        headers={
            "anthropic-version": "2023-06-01",
            "anthropic-beta": "messages-2023-12-15",
            "x-api-key": f"{provider.secret}",
        },
        base_url=provider.api_url,
    )

    try:
        # Anthropic takes the system prompt as a top-level "system"
        # parameter rather than as a message, so split it out here.
        anthropic_messages = []
        system_prompt_parts = []
        for message in messages:
            if message["role"] == "system":
                system_prompt_parts.append(message["content"])
            else:
                anthropic_messages.append(message)

        system_prompt = "\n".join(system_prompt_parts)

        if stream:
            async with httpx_sse.aconnect_sse(
                client,
                method="POST",
                url="/messages",
                json={
                    "model": model_name,
                    "messages": anthropic_messages,
                    "stream": True,
                    "system": system_prompt,
                    "max_tokens": 4096,
                }
            ) as event_source:
                async for sse in event_source.aiter_sse():
                    # Send response headers lazily, on the first event.
                    if not response.sent:
                        response.status = http.HTTPStatus.OK
                        response.content_type = b'text/event-stream'
                        response.close_connection = False
                        response.custom_headers["Cache-Control"] = "no-cache"
                        protocol.write(request, response)

                    if sse.event == "message_start":
                        # Strip usage/content fields, forwarding only the
                        # message envelope metadata.
                        message = sse.json()["message"]
                        for k in tuple(message):
                            if k not in {"id", "type", "role", "model"}:
                                del message[k]
                        message_data = json.dumps(message).encode("utf-8")
                        event = (
                            b'event: message_start\n'
                            + b'data: {"type": "message_start",'
                            + b'"message":' + message_data + b'}\n\n'
                        )
                        protocol.write_raw(event)

                    elif sse.event == "content_block_start":
                        protocol.write_raw(
                            b'event: content_block_start\n'
                            + b'data: ' + sse.data.encode("utf-8") + b'\n\n'
                        )
                    elif sse.event == "content_block_delta":
                        # Fix: this branch previously re-emitted the event
                        # as "content_block_start", corrupting the stream.
                        protocol.write_raw(
                            b'event: content_block_delta\n'
                            + b'data: ' + sse.data.encode("utf-8") + b'\n\n'
                        )
                    elif sse.event == "message_delta":
                        delta = sse.json()["delta"]
                        delta_data = json.dumps(delta).encode("utf-8")
                        # Fix: the "delta" key was emitted without JSON
                        # quotes (b"delta:"), producing invalid JSON.
                        event = (
                            b'event: message_delta\n'
                            + b'data: {"type": "message_delta",'
                            + b'"delta":' + delta_data + b'}\n\n'
                        )
                        protocol.write_raw(event)
                    elif sse.event == "message_stop":
                        event = (
                            b'event: message_stop\n'
                            + b'data: {"type": "message_stop"}\n\n'
                        )
                        protocol.write_raw(event)

            protocol.close()

        else:
            result = await client.post(
                "/messages",
                json={
                    "model": model_name,
                    "messages": anthropic_messages,
                    "system": system_prompt,
                    "max_tokens": 4096,
                }
            )

            response.status = http.HTTPStatus.OK
            response.content_type = b'application/json'
            # NOTE(review): an upstream API error body has no "content"
            # key and would raise KeyError here — consider surfacing
            # upstream errors explicitly.
            response_text = result.json()["content"][0]["text"]
            response.body = json.dumps({
                "response": response_text,
            }).encode("utf-8")
    finally:
        # Fix: the AsyncClient was never closed, leaking connections.
        await client.aclose()


#
# HTTP API
#
Expand Down

0 comments on commit 832614e

Please sign in to comment.