Skip to content

Commit ad8aac6

Browse files
Add mux rules for match according to file and request type
Closes: #885 and #944. This PR introduces the necessary changes for: 1. Muxing to match on a specific filename (`main.py`) or file type (`.py`). 2. Muxing on a specific request type, i.e. chat or FIM. The matching is tightly coupled with the priority of the muxing rules. A match will occur on the first possible rule. Example: Muxing rule 1 -> match all `.py` files -> go to `chat-gpt`; Muxing rule 2 -> match all `.md` files -> go to `ollama`. If in a request we receive a `README.md` and a `main.py` file, the request will be routed to `chat-gpt` since it's the match with the highest priority in the list. To introduce the above changes there were some minor changes: 1. Separate FIM detection into its own class. Before, it used to be part of `BaseProvider`. 2. Detect if a request is FIM or not before calling the `process_request` method. Both of them were necessary in order to re-use the logic to create the matcher for request type.
1 parent d588f36 commit ad8aac6

File tree

16 files changed

+412
-218
lines changed

16 files changed

+412
-218
lines changed

src/codegate/muxing/models.py

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import pydantic
55

6+
from codegate.clients.clients import ClientType
7+
68

79
class MuxMatcherType(str, Enum):
810
"""
@@ -11,6 +13,12 @@ class MuxMatcherType(str, Enum):
1113

1214
# Always match this prompt
1315
catch_all = "catch_all"
16+
# Match based on the filename. It will match if there is a filename
17+
# in the request that matches the matcher, either by extension or by full name (*.py or main.py)
18+
filename_match = "filename_match"
19+
# Match based on the request type. It will match if the request type
20+
# matches the matcher (e.g. FIM or chat)
21+
request_type_match = "request_type_match"
1422

1523

1624
class MuxRule(pydantic.BaseModel):
@@ -25,3 +33,14 @@ class MuxRule(pydantic.BaseModel):
2533
# The actual matcher to use. Note that
2634
# this depends on the matcher type.
2735
matcher: Optional[str] = None
36+
37+
38+
class ThingToMatchMux(pydantic.BaseModel):
39+
"""
40+
Represents the fields we can use to match a mux rule.
41+
"""
42+
43+
body: dict
44+
url_request_path: str
45+
is_fim_request: bool
46+
client_type: ClientType

src/codegate/muxing/router.py

+31-47
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import json
2+
from typing import Optional
23

34
import structlog
45
from fastapi import APIRouter, HTTPException, Request
56

6-
from codegate.clients.clients import ClientType
77
from codegate.clients.detector import DetectClient
8-
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
9-
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
8+
from codegate.muxing import models as mux_models
109
from codegate.muxing import rulematcher
1110
from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
11+
from codegate.providers.fim_analyzer import FIMAnalyzer
1212
from codegate.providers.registry import ProviderRegistry
1313
from codegate.workspaces.crud import WorkspaceCrud
1414

@@ -39,40 +39,20 @@ def get_routes(self) -> APIRouter:
3939
def _ensure_path_starts_with_slash(self, path: str) -> str:
4040
return path if path.startswith("/") else f"/{path}"
4141

42-
def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
42+
async def _get_model_route(
43+
self, thing_to_match: mux_models.ThingToMatchMux
44+
) -> Optional[rulematcher.ModelRoute]:
4345
"""
44-
Extract filenames from the request data.
46+
Get the model route for the given thing_to_match.
4547
"""
46-
try:
47-
body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
48-
return body_extractor.extract_unique_filenames(data)
49-
except BodyCodeSnippetExtractorError as e:
50-
logger.error(f"Error extracting filenames from request: {e}")
51-
return set()
52-
53-
async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]:
54-
"""
55-
Get the model routes for the given filenames.
56-
"""
57-
model_routes = []
5848
mux_registry = await rulematcher.get_muxing_rules_registry()
5949
try:
60-
# Try to get a catch_all route
61-
single_model_route = await mux_registry.get_match_for_active_workspace(
62-
thing_to_match=None
63-
)
64-
model_routes.append(single_model_route)
65-
66-
# Get the model routes for each filename
67-
for filename in filenames:
68-
model_route = await mux_registry.get_match_for_active_workspace(
69-
thing_to_match=filename
70-
)
71-
model_routes.append(model_route)
50+
# Try to get a model route for the active workspace
51+
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
52+
return model_route
7253
except Exception as e:
7354
logger.error(f"Error getting active workspace muxes: {e}")
7455
raise HTTPException(str(e), status_code=404)
75-
return model_routes
7656

7757
def _setup_routes(self):
7858

@@ -88,34 +68,38 @@ async def route_to_dest_provider(
8868
1. Get destination provider from DB and active workspace.
8969
2. Map the request body to the destination provider format.
9070
3. Run pipeline. Selecting the correct destination provider.
91-
4. Transmit the response back to the client in the correct format.
71+
4. Transmit the response back to the client in OpenAI format.
9272
"""
9373

9474
body = await request.body()
9575
data = json.loads(body)
76+
is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data)
77+
78+
# 1. Get destination provider from DB and active workspace.
79+
thing_to_match = mux_models.ThingToMatchMux(
80+
body=data,
81+
url_request_path=rest_of_path,
82+
is_fim_request=is_fim_request,
83+
client_type=request.state.detected_client,
84+
)
85+
model_route = await self._get_model_route(thing_to_match)
86+
if not model_route:
87+
raise HTTPException(
88+
"No matching rule found for the active workspace", status_code=404
89+
)
9690

97-
filenames_in_data = self._extract_request_filenames(request.state.detected_client, data)
98-
logger.info(f"Extracted filenames from request: {filenames_in_data}")
99-
100-
model_routes = await self._get_model_routes(filenames_in_data)
101-
if not model_routes:
102-
raise HTTPException("No rule found for the active workspace", status_code=404)
103-
104-
# We still need some logic here to handle the case where we have multiple model routes.
105-
# For the moment since we match all only pick the first.
106-
model_route = model_routes[0]
107-
108-
# Parse the input data and map it to the destination provider format
91+
# 2. Map the request body to the destination provider format.
10992
rest_of_path = self._ensure_path_starts_with_slash(rest_of_path)
11093
new_data = self._body_adapter.map_body_to_dest(model_route, data)
94+
95+
# 3. Run pipeline. Selecting the correct destination provider.
11196
provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)
11297
api_key = model_route.auth_material.auth_blob
113-
114-
# Send the request to the destination provider. It will run the pipeline
11598
response = await provider.process_request(
116-
new_data, api_key, rest_of_path, request.state.detected_client
99+
new_data, api_key, is_fim_request, request.state.detected_client
117100
)
118-
# Format the response to the client always using the OpenAI format
101+
102+
# 4. Transmit the response back to the client in OpenAI format.
119103
return self._response_adapter.format_response_to_client(
120104
response, model_route.endpoint.provider_type
121105
)

src/codegate/muxing/rulematcher.py

+65-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import copy
22
from abc import ABC, abstractmethod
33
from asyncio import Lock
4-
from typing import List, Optional
4+
from typing import Dict, List, Optional
55

6+
import structlog
7+
8+
from codegate.clients.clients import ClientType
69
from codegate.db import models as db_models
10+
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
11+
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
12+
from codegate.muxing import models as mux_models
13+
14+
logger = structlog.get_logger("codegate")
715

816
_muxrules_sgtn = None
917

@@ -40,11 +48,12 @@ def __init__(
4048
class MuxingRuleMatcher(ABC):
4149
"""Base class for matching muxing rules."""
4250

43-
def __init__(self, route: ModelRoute):
51+
def __init__(self, route: ModelRoute, matcher_blob: str):
4452
self._route = route
53+
self._matcher_blob = matcher_blob
4554

4655
@abstractmethod
47-
def match(self, thing_to_match) -> bool:
56+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
4857
"""Return True if the rule matches the thing_to_match."""
4958
pass
5059

@@ -61,23 +70,69 @@ class MuxingMatcherFactory:
6170
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
6271
"""Create a muxing matcher for the given endpoint and model."""
6372

64-
factory = {
65-
"catch_all": CatchAllMuxingRuleMatcher,
73+
factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
74+
mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
75+
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
76+
mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
6677
}
6778

6879
try:
69-
return factory[mux_rule.matcher_type](route)
80+
# Initialize the MuxingRuleMatcher
81+
return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
7082
except KeyError:
7183
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")
7284

7385

7486
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
7587
"""A catch all muxing rule matcher."""
7688

77-
def match(self, thing_to_match) -> bool:
89+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
7890
return True
7991

8092

93+
class FileMuxingRuleMatcher(MuxingRuleMatcher):
94+
"""A file muxing rule matcher."""
95+
96+
def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
97+
"""
98+
Extract filenames from the request data.
99+
"""
100+
try:
101+
body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
102+
return body_extractor.extract_unique_filenames(data)
103+
except BodyCodeSnippetExtractorError as e:
104+
logger.error(f"Error extracting filenames from request: {e}")
105+
return set()
106+
107+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
108+
"""
109+
Return True if there is a filename in the request that matches the matcher_blob.
110+
The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
111+
"""
112+
# If there is no matcher_blob, we don't match
113+
if not self._matcher_blob:
114+
return False
115+
filenames_to_match = self._extract_request_filenames(
116+
thing_to_match.client_type, thing_to_match.body
117+
)
118+
return any(self._matcher_blob in filename for filename in filenames_to_match)
119+
120+
121+
class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
122+
"""A request type muxing rule matcher."""
123+
124+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
125+
"""
126+
Return True if the request type matches the matcher_blob.
127+
The matcher_blob is either "fim" or "chat".
128+
"""
129+
# If there is no matcher_blob, we don't match
130+
if not self._matcher_blob:
131+
return False
132+
incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
133+
return self._matcher_blob == incoming_request_type
134+
135+
81136
class MuxingRulesinWorkspaces:
82137
"""A thread safe dictionary to store the muxing rules in workspaces."""
83138

@@ -111,7 +166,9 @@ async def get_registries(self) -> List[str]:
111166
async with self._lock:
112167
return list(self._ws_rules.keys())
113168

114-
async def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]:
169+
async def get_match_for_active_workspace(
170+
self, thing_to_match: mux_models.ThingToMatchMux
171+
) -> Optional[ModelRoute]:
115172
"""Get the first match for the given thing_to_match."""
116173

117174
# We iterate over all the rules and return the first match

src/codegate/pipeline/secrets/secrets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from abc import abstractmethod
33
from typing import List, Optional, Tuple
44

5-
from codegate.extract_snippets.factory import MessageCodeExtractorFactory
65
import structlog
76
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage, ModelResponse
87
from litellm.types.utils import Delta, StreamingChoices
98

109
from codegate.config import Config
10+
from codegate.extract_snippets.factory import MessageCodeExtractorFactory
1111
from codegate.pipeline.base import (
1212
AlertSeverity,
1313
CodeSnippet,

src/codegate/providers/anthropic/provider.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
1212
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
1313
from codegate.providers.base import BaseProvider, ModelFetchError
14+
from codegate.providers.fim_analyzer import FIMAnalyzer
1415
from codegate.providers.litellmshim import anthropic_stream_generator
1516

1617

@@ -57,10 +58,9 @@ async def process_request(
5758
self,
5859
data: dict,
5960
api_key: str,
60-
request_url_path: str,
61+
is_fim_request: bool,
6162
client_type: ClientType,
6263
):
63-
is_fim_request = self._is_fim_request(request_url_path, data)
6464
try:
6565
stream = await self.complete(data, api_key, is_fim_request, client_type)
6666
except Exception as e:
@@ -98,10 +98,11 @@ async def create_message(
9898

9999
body = await request.body()
100100
data = json.loads(body)
101+
is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
101102

102103
return await self.process_request(
103104
data,
104105
x_api_key,
105-
request.url.path,
106+
is_fim_request,
106107
request.state.detected_client,
107108
)

src/codegate/providers/base.py

+1-56
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def process_request(
7979
self,
8080
data: dict,
8181
api_key: str,
82-
request_url_path: str,
82+
is_fim_request: bool,
8383
client_type: ClientType,
8484
):
8585
pass
@@ -173,61 +173,6 @@ async def _run_input_pipeline(
173173

174174
return result
175175

176-
def _is_fim_request_url(self, request_url_path: str) -> bool:
177-
"""
178-
Checks the request URL to determine if a request is FIM or chat completion.
179-
Used by: llama.cpp
180-
"""
181-
# Evaluate first a larger substring.
182-
if request_url_path.endswith("/chat/completions"):
183-
return False
184-
185-
# /completions is for OpenAI standard. /api/generate is for ollama.
186-
if request_url_path.endswith("/completions") or request_url_path.endswith("/api/generate"):
187-
return True
188-
189-
return False
190-
191-
def _is_fim_request_body(self, data: Dict) -> bool:
192-
"""
193-
Determine from the raw incoming data if it's a FIM request.
194-
Used by: OpenAI and Anthropic
195-
"""
196-
messages = data.get("messages", [])
197-
if not messages:
198-
return False
199-
200-
first_message_content = messages[0].get("content")
201-
if first_message_content is None:
202-
return False
203-
204-
fim_stop_sequences = ["</COMPLETION>", "<COMPLETION>", "</QUERY>", "<QUERY>"]
205-
if isinstance(first_message_content, str):
206-
msg_prompt = first_message_content
207-
elif isinstance(first_message_content, list):
208-
msg_prompt = first_message_content[0].get("text", "")
209-
else:
210-
logger.warning(f"Could not determine if message was FIM from data: {data}")
211-
return False
212-
return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
213-
214-
def _is_fim_request(self, request_url_path: str, data: Dict) -> bool:
215-
"""
216-
Determine if the request is FIM by the URL or the data of the request.
217-
"""
218-
# first check if we are in specific tools to discard FIM
219-
prompt = data.get("prompt", "")
220-
tools = ["cline", "kodu", "open interpreter"]
221-
for tool in tools:
222-
if tool in prompt.lower():
223-
# those tools can never be FIM
224-
return False
225-
# Avoid more expensive inspection of body by just checking the URL.
226-
if self._is_fim_request_url(request_url_path):
227-
return True
228-
229-
return self._is_fim_request_body(data)
230-
231176
async def _cleanup_after_streaming(
232177
self, stream: AsyncIterator[ModelResponse], context: PipelineContext
233178
) -> AsyncIterator[ModelResponse]:

0 commit comments

Comments
 (0)