Commit: Merge branch 'main' into issue-209

yrobla authored Feb 6, 2025
2 parents 6cafe76 + 8da7955 commit c6af226
Showing 19 changed files with 233 additions and 70 deletions.
3 changes: 2 additions & 1 deletion api/openapi.json
@@ -1607,7 +1607,8 @@
           "vllm",
           "ollama",
           "lm_studio",
-          "llamacpp"
+          "llamacpp",
+          "openai"
         ],
         "title": "ProviderType",
         "description": "Represents the different types of providers we support."
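Note: the new "openai" entry in the schema comes from the OpenRouter support added in this merge. The ProviderType change in src/codegate/db/models.py below deliberately reuses the value "openai" for the new openrouter member, so that value is what surfaces in the generated OpenAPI document.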
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -12,7 +12,7 @@ PyYAML = "==6.0.2"
 fastapi = "==0.115.8"
 uvicorn = "==0.34.0"
 structlog = "==25.1.0"
-litellm = "==1.60.2"
+litellm = "==1.60.4"
 llama_cpp_python = "==0.3.5"
 cryptography = "==44.0.0"
 sqlalchemy = "==2.0.37"
@@ -41,7 +41,7 @@ ruff = "==0.9.4"
 bandit = "==1.8.2"
 build = "==1.2.2.post1"
 wheel = "==0.45.1"
-litellm = "==1.60.2"
+litellm = "==1.60.4"
 pytest-asyncio = "==0.25.3"
 llama_cpp_python = "==0.3.5"
 scikit-learn = "==1.6.1"
5 changes: 3 additions & 2 deletions src/codegate/api/v1.py
@@ -8,6 +8,7 @@
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError
 
+import codegate.muxing.models as mux_models
 from codegate import __version__
 from codegate.api import v1_models, v1_processing
 from codegate.db.connection import AlreadyExistsError, DbReader
@@ -477,7 +478,7 @@ async def delete_workspace_custom_instructions(workspace_name: str):
 )
 async def get_workspace_muxes(
     workspace_name: str,
-) -> List[v1_models.MuxRule]:
+) -> List[mux_models.MuxRule]:
     """Get the mux rules of a workspace.
 
     The list is ordered in order of priority. That is, the first rule in the list
@@ -501,7 +502,7 @@ async def get_workspace_muxes(
 )
 async def set_workspace_muxes(
     workspace_name: str,
-    request: List[v1_models.MuxRule],
+    request: List[mux_models.MuxRule],
 ):
     """Set the mux rules of a workspace."""
     try:
23 changes: 0 additions & 23 deletions src/codegate/api/v1_models.py
@@ -267,26 +267,3 @@ class ModelByProvider(pydantic.BaseModel):
 
     def __str__(self):
         return f"{self.provider_name} / {self.name}"
-
-
-class MuxMatcherType(str, Enum):
-    """
-    Represents the different types of matchers we support.
-    """
-
-    # Always match this prompt
-    catch_all = "catch_all"
-
-
-class MuxRule(pydantic.BaseModel):
-    """
-    Represents a mux rule for a provider.
-    """
-
-    provider_id: str
-    model: str
-    # The type of matcher to use
-    matcher_type: MuxMatcherType
-    # The actual matcher to use. Note that
-    # this depends on the matcher type.
-    matcher: Optional[str] = None
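The two deleted classes are moved, not dropped: they reappear verbatim in the new module src/codegate/muxing/models.py below, which api/v1.py now imports as mux_models.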
1 change: 1 addition & 0 deletions src/codegate/config.py
@@ -17,6 +17,7 @@
 # Default provider URLs
 DEFAULT_PROVIDER_URLS = {
     "openai": "https://api.openai.com/v1",
+    "openrouter": "https://openrouter.ai/api/v1",
     "anthropic": "https://api.anthropic.com/v1",
     "vllm": "http://localhost:8000",  # Base URL without /v1 path
     "ollama": "http://localhost:11434",  # Default Ollama server URL
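With this default in place, the shared _get_base_url helper added to BaseProvider below resolves OpenRouter's endpoint from config. A minimal sketch of that lookup, mirroring the base.py hunk (and assuming provider_urls merges these defaults with user overrides, as the existing vLLM code suggests):

from codegate.config import Config

config = Config.get_config()
# For a provider whose provider_route_name is "openrouter":
base_url = config.provider_urls.get("openrouter") if config else ""
# -> "https://openrouter.ai/api/v1" unless overridden in user config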
1 change: 1 addition & 0 deletions src/codegate/db/models.py
@@ -128,6 +128,7 @@ class ProviderType(str, Enum):
     ollama = "ollama"
     lm_studio = "lm_studio"
     llamacpp = "llamacpp"
+    openrouter = "openai"
 
 
 class GetPromptWithOutputsRow(BaseModel):
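The value "openai" on the new openrouter member is deliberate, not a typo: the muxing adapter below treats OpenRouter as a dialect of OpenAI. Assuming ProviderType also defines openai = "openai" (server.py's registration of ProviderType.openai suggests it does), Python's Enum makes the new member an alias of the existing one. A minimal sketch of the resulting semantics:

from enum import Enum

class ProviderType(str, Enum):
    openai = "openai"
    openrouter = "openai"  # duplicate value -> alias of `openai`, not a new member

assert ProviderType.openrouter is ProviderType.openai
assert [m.name for m in ProviderType] == ["openai"]  # plain iteration skips aliases
assert list(ProviderType.__members__) == ["openai", "openrouter"]  # __members__ keeps both

Schema generators that enumerate __members__ rather than only the canonical members will emit the value twice, which is consistent with the extra "openai" entry appearing in api/openapi.json above.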
2 changes: 2 additions & 0 deletions src/codegate/muxing/adapter.py
@@ -106,6 +106,8 @@ def __init__(self):
             db_models.ProviderType.anthropic: self._format_antropic,
             # Our Lllamacpp provider emits OpenAI chunks
             db_models.ProviderType.llamacpp: self._format_openai,
+            # OpenRouter is a dialect of OpenAI
+            db_models.ProviderType.openrouter: self._format_openai,
         }
 
     def _format_ollama(self, chunk: str) -> str:
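For context, a dispatch table like this is typically consumed with a single lookup; a hypothetical sketch (the names here are illustrative, not adapter.py's actual attributes):

def format_chunk(format_funcs: dict, provider_type, chunk: str) -> str:
    # All OpenAI-compatible providers (llamacpp, openrouter) share one formatter,
    # so the muxed stream reaches the client in OpenAI format regardless of origin.
    return format_funcs[provider_type](chunk)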
27 changes: 27 additions & 0 deletions src/codegate/muxing/models.py
@@ -0,0 +1,27 @@
+from enum import Enum
+from typing import Optional
+
+import pydantic
+
+
+class MuxMatcherType(str, Enum):
+    """
+    Represents the different types of matchers we support.
+    """
+
+    # Always match this prompt
+    catch_all = "catch_all"
+
+
+class MuxRule(pydantic.BaseModel):
+    """
+    Represents a mux rule for a provider.
+    """
+
+    provider_id: str
+    model: str
+    # The type of matcher to use
+    matcher_type: MuxMatcherType
+    # The actual matcher to use. Note that
+    # this depends on the matcher type.
+    matcher: Optional[str] = None
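As a usage sketch, a catch-all rule that routes every prompt in a workspace to a single model could look like this (the provider_id and model values are placeholders):

from codegate.muxing.models import MuxMatcherType, MuxRule

rule = MuxRule(
    provider_id="1f2a-...",  # placeholder: ID of a registered provider endpoint
    model="anthropic/claude-3.5-sonnet",  # placeholder model name
    matcher_type=MuxMatcherType.catch_all,
    # `matcher` stays None: catch_all matches every prompt, so no pattern is needed
)
print(rule)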
6 changes: 5 additions & 1 deletion src/codegate/muxing/router.py
@@ -3,6 +3,7 @@
 import structlog
 from fastapi import APIRouter, HTTPException, Request
 
+from codegate.clients.detector import DetectClient
 from codegate.muxing import rulematcher
 from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
 from codegate.providers.registry import ProviderRegistry
@@ -38,6 +39,7 @@ def _ensure_path_starts_with_slash(self, path: str) -> str:
     def _setup_routes(self):
 
         @self.router.post(f"/{self.route_name}/{{rest_of_path:path}}")
+        @DetectClient()
         async def route_to_dest_provider(
             request: Request,
             rest_of_path: str = "",
@@ -73,7 +75,9 @@ async def route_to_dest_provider(
             api_key = model_route.auth_material.auth_blob
 
             # Send the request to the destination provider. It will run the pipeline
-            response = await provider.process_request(new_data, api_key, rest_of_path)
+            response = await provider.process_request(
+                new_data, api_key, rest_of_path, request.state.detected_client
+            )
             # Format the response to the client always using the OpenAI format
             return self._response_adapter.format_response_to_client(
                 response, model_route.endpoint.provider_type
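The @DetectClient() decorator is what populates request.state.detected_client; the extra argument threaded into process_request here mirrors the new four-argument signature also used by the OpenRouter provider below.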
2 changes: 2 additions & 0 deletions src/codegate/providers/__init__.py
@@ -2,13 +2,15 @@
 from codegate.providers.base import BaseProvider
 from codegate.providers.ollama.provider import OllamaProvider
 from codegate.providers.openai.provider import OpenAIProvider
+from codegate.providers.openrouter.provider import OpenRouterProvider
 from codegate.providers.registry import ProviderRegistry
 from codegate.providers.vllm.provider import VLLMProvider
 
 __all__ = [
     "BaseProvider",
     "ProviderRegistry",
     "OpenAIProvider",
+    "OpenRouterProvider",
     "AnthropicProvider",
     "VLLMProvider",
     "OllamaProvider",
8 changes: 8 additions & 0 deletions src/codegate/providers/base.py
@@ -12,6 +12,7 @@
 
 from codegate.clients.clients import ClientType
 from codegate.codegate_logging import setup_logging
+from codegate.config import Config
 from codegate.db.connection import DbRecorder
 from codegate.pipeline.base import (
     PipelineContext,
@@ -88,6 +89,13 @@ async def process_request(
     def provider_route_name(self) -> str:
         pass
 
+    def _get_base_url(self) -> str:
+        """
+        Get the base URL from config with proper formatting
+        """
+        config = Config.get_config()
+        return config.provider_urls.get(self.provider_route_name) if config else ""
+
     async def _run_output_stream_pipeline(
         self,
         input_context: PipelineContext,
3 changes: 3 additions & 0 deletions src/codegate/providers/openai/__init__.py
@@ -0,0 +1,3 @@
+from codegate.providers.openai.provider import OpenAIProvider
+
+__all__ = ["OpenAIProvider"]
47 changes: 47 additions & 0 deletions src/codegate/providers/openrouter/provider.py
@@ -0,0 +1,47 @@
+import json
+
+from fastapi import Header, HTTPException, Request
+
+from codegate.clients.detector import DetectClient
+from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.openai import OpenAIProvider
+
+
+class OpenRouterProvider(OpenAIProvider):
+    def __init__(self, pipeline_factory: PipelineFactory):
+        super().__init__(pipeline_factory)
+
+    @property
+    def provider_route_name(self) -> str:
+        return "openrouter"
+
+    def _setup_routes(self):
+        @self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions")
+        @self.router.post(f"/{self.provider_route_name}/chat/completions")
+        @DetectClient()
+        async def create_completion(
+            request: Request,
+            authorization: str = Header(..., description="Bearer token"),
+        ):
+            if not authorization.startswith("Bearer "):
+                raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+            api_key = authorization.split(" ")[1]
+            body = await request.body()
+            data = json.loads(body)
+
+            base_url = self._get_base_url()
+            data["base_url"] = base_url
+
+            # litellm workaround - add openrouter/ prefix to model name to make it openai-compatible
+            # once we get rid of litellm, this can simply be removed
+            original_model = data.get("model", "")
+            if not original_model.startswith("openrouter/"):
+                data["model"] = f"openrouter/{original_model}"
+
+            return await self.process_request(
+                data,
+                api_key,
+                request.url.path,
+                request.state.detected_client,
+            )
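Once registered (see server.py below), the provider accepts OpenAI-style chat payloads on either route. A hypothetical smoke test against a local instance, where host, port, key, and model are all placeholders:

import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8989/openrouter/api/v1/chat/completions",  # placeholder host/port
    data=json.dumps(
        {
            "model": "anthropic/claude-3.5-sonnet",  # forwarded as "openrouter/anthropic/..."
            "messages": [{"role": "user", "content": "hello"}],
        }
    ).encode(),
    headers={
        "Authorization": "Bearer sk-or-...",  # placeholder OpenRouter key
        "Content-Type": "application/json",
    },
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())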
4 changes: 1 addition & 3 deletions src/codegate/providers/vllm/provider.py
@@ -9,7 +9,6 @@
 
 from codegate.clients.clients import ClientType
 from codegate.clients.detector import DetectClient
-from codegate.config import Config
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
@@ -39,8 +38,7 @@ def _get_base_url(self) -> str:
         """
         Get the base URL from config with proper formatting
         """
-        config = Config.get_config()
-        base_url = config.provider_urls.get("vllm") if config else ""
+        base_url = super()._get_base_url()
         if base_url:
             base_url = base_url.rstrip("/")
             # Add /v1 if not present
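This refactor is a small template-method cleanup: BaseProvider._get_base_url now owns the config lookup, and subclasses only post-process the result. A self-contained sketch of the vLLM post-processing — the /v1-append tail is an assumption inferred from the "# Add /v1 if not present" comment, since the rest of the method is elided from the hunk:

def vllm_base_url(configured: str) -> str:
    # Mirror of the override above: normalize the configured URL...
    base_url = (configured or "").rstrip("/")
    # ...and append /v1 if not present (assumed tail of the elided method)
    if base_url and not base_url.endswith("/v1"):
        base_url += "/v1"
    return base_url

assert vllm_base_url("http://localhost:8000") == "http://localhost:8000/v1"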
5 changes: 5 additions & 0 deletions src/codegate/server.py
@@ -18,6 +18,7 @@
 from codegate.providers.lm_studio.provider import LmStudioProvider
 from codegate.providers.ollama.provider import OllamaProvider
 from codegate.providers.openai.provider import OpenAIProvider
+from codegate.providers.openrouter.provider import OpenRouterProvider
 from codegate.providers.registry import ProviderRegistry, get_provider_registry
 from codegate.providers.vllm.provider import VLLMProvider
 
@@ -75,6 +76,10 @@ async def log_user_agent(request: Request, call_next):
         ProviderType.openai,
         OpenAIProvider(pipeline_factory),
     )
+    registry.add_provider(
+        ProviderType.openrouter,
+        OpenRouterProvider(pipeline_factory),
+    )
     registry.add_provider(
         ProviderType.anthropic,
         AnthropicProvider(