Skip to content

Commit 1582c2a

Browse files
Fix to also prepend openrouter/ when muxing (#983)
Moved the prefixing logic inside a function and delegated to the parent implementation
1 parent 0753bd6 commit 1582c2a

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

src/codegate/muxing/adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ def _from_anthropic_to_openai(self, anthropic_body: dict) -> dict:
5858

5959
def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
6060
"""Get the provider formatted URL to use in base_url. Note this value comes from DB"""
61-
if model_route.endpoint.provider_type == db_models.ProviderType.openai:
61+
if model_route.endpoint.provider_type in [
62+
db_models.ProviderType.openai,
63+
db_models.ProviderType.openrouter,
64+
]:
6265
return f"{model_route.endpoint.endpoint}/v1"
6366
return model_route.endpoint.endpoint
6467

src/codegate/providers/openrouter/provider.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi import Header, HTTPException, Request
55
from litellm.types.llms.openai import ChatCompletionRequest
66

7+
from codegate.clients.clients import ClientType
78
from codegate.clients.detector import DetectClient
89
from codegate.pipeline.factory import PipelineFactory
910
from codegate.providers.fim_analyzer import FIMAnalyzer
@@ -34,6 +35,21 @@ def __init__(self, pipeline_factory: PipelineFactory):
3435
def provider_route_name(self) -> str:
3536
return "openrouter"
3637

38+
async def process_request(
39+
self,
40+
data: dict,
41+
api_key: str,
42+
is_fim_request: bool,
43+
client_type: ClientType,
44+
):
45+
# litellm workaround - add openrouter/ prefix to model name to make it openai-compatible
46+
# once we get rid of litellm, this can simply be removed
47+
original_model = data.get("model", "")
48+
if not original_model.startswith("openrouter/"):
49+
data["model"] = f"openrouter/{original_model}"
50+
51+
return await super().process_request(data, api_key, is_fim_request, client_type)
52+
3753
def _setup_routes(self):
3854
@self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions")
3955
@self.router.post(f"/{self.provider_route_name}/chat/completions")
@@ -52,14 +68,8 @@ async def create_completion(
5268

5369
base_url = self._get_base_url()
5470
data["base_url"] = base_url
55-
56-
# litellm workaround - add openrouter/ prefix to model name to make it openai-compatible
57-
# once we get rid of litellm, this can simply be removed
58-
original_model = data.get("model", "")
59-
if not original_model.startswith("openrouter/"):
60-
data["model"] = f"openrouter/{original_model}"
61-
6271
is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
72+
6373
return await self.process_request(
6474
data,
6575
api_key,

tests/providers/openrouter/test_openrouter_provider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from unittest.mock import AsyncMock, MagicMock
2+
from unittest.mock import AsyncMock, MagicMock, patch
33

44
import pytest
55
from fastapi import HTTPException
@@ -26,11 +26,11 @@ def test_get_base_url(provider):
2626

2727

2828
@pytest.mark.asyncio
29-
async def test_model_prefix_added():
29+
@patch("codegate.providers.openai.OpenAIProvider.process_request")
30+
async def test_model_prefix_added(mocked_parent_process_request):
3031
"""Test that model name gets prefixed with openrouter/ when not already present"""
3132
mock_factory = MagicMock(spec=PipelineFactory)
3233
provider = OpenRouterProvider(mock_factory)
33-
provider.process_request = AsyncMock()
3434

3535
# Mock request
3636
mock_request = MagicMock(spec=Request)
@@ -47,7 +47,7 @@ async def test_model_prefix_added():
4747
await create_completion(request=mock_request, authorization="Bearer test-token")
4848

4949
# Verify process_request was called with prefixed model
50-
call_args = provider.process_request.call_args[0]
50+
call_args = mocked_parent_process_request.call_args[0]
5151
assert call_args[0]["model"] == "openrouter/gpt-4"
5252

5353

0 commit comments

Comments (0)