1
1
import json
2
+ from typing import Optional
2
3
3
4
import structlog
4
5
from fastapi import APIRouter , HTTPException , Request
5
6
6
- from codegate .clients .clients import ClientType
7
7
from codegate .clients .detector import DetectClient
8
- from codegate .extract_snippets .body_extractor import BodyCodeSnippetExtractorError
9
- from codegate .extract_snippets .factory import BodyCodeExtractorFactory
8
+ from codegate .muxing import models as mux_models
10
9
from codegate .muxing import rulematcher
11
10
from codegate .muxing .adapter import BodyAdapter , ResponseAdapter
11
+ from codegate .providers .fim_analyzer import FIMAnalyzer
12
12
from codegate .providers .registry import ProviderRegistry
13
13
from codegate .workspaces .crud import WorkspaceCrud
14
14
@@ -39,40 +39,20 @@ def get_routes(self) -> APIRouter:
39
39
def _ensure_path_starts_with_slash (self , path : str ) -> str :
40
40
return path if path .startswith ("/" ) else f"/{ path } "
41
41
42
- def _extract_request_filenames (self , detected_client : ClientType , data : dict ) -> set [str ]:
42
+ async def _get_model_route (
43
+ self , thing_to_match : mux_models .ThingToMatchMux
44
+ ) -> Optional [rulematcher .ModelRoute ]:
43
45
"""
44
- Extract filenames from the request data .
46
+ Get the model route for the given things_to_match .
45
47
"""
46
- try :
47
- body_extractor = BodyCodeExtractorFactory .create_snippet_extractor (detected_client )
48
- return body_extractor .extract_unique_filenames (data )
49
- except BodyCodeSnippetExtractorError as e :
50
- logger .error (f"Error extracting filenames from request: { e } " )
51
- return set ()
52
-
53
- async def _get_model_routes (self , filenames : set [str ]) -> list [rulematcher .ModelRoute ]:
54
- """
55
- Get the model routes for the given filenames.
56
- """
57
- model_routes = []
58
48
mux_registry = await rulematcher .get_muxing_rules_registry ()
59
49
try :
60
- # Try to get a catch_all route
61
- single_model_route = await mux_registry .get_match_for_active_workspace (
62
- thing_to_match = None
63
- )
64
- model_routes .append (single_model_route )
65
-
66
- # Get the model routes for each filename
67
- for filename in filenames :
68
- model_route = await mux_registry .get_match_for_active_workspace (
69
- thing_to_match = filename
70
- )
71
- model_routes .append (model_route )
50
+ # Try to get a model route for the active workspace
51
+ model_route = await mux_registry .get_match_for_active_workspace (thing_to_match )
52
+ return model_route
72
53
except Exception as e :
73
54
logger .error (f"Error getting active workspace muxes: { e } " )
74
55
raise HTTPException (str (e ), status_code = 404 )
75
- return model_routes
76
56
77
57
def _setup_routes (self ):
78
58
@@ -88,34 +68,38 @@ async def route_to_dest_provider(
88
68
1. Get destination provider from DB and active workspace.
89
69
2. Map the request body to the destination provider format.
90
70
3. Run pipeline. Selecting the correct destination provider.
91
- 4. Transmit the response back to the client in the correct format.
71
+ 4. Transmit the response back to the client in OpenAI format.
92
72
"""
93
73
94
74
body = await request .body ()
95
75
data = json .loads (body )
76
+ is_fim_request = FIMAnalyzer .is_fim_request (rest_of_path , data )
77
+
78
+ # 1. Get destination provider from DB and active workspace.
79
+ thing_to_match = mux_models .ThingToMatchMux (
80
+ body = data ,
81
+ url_request_path = rest_of_path ,
82
+ is_fim_request = is_fim_request ,
83
+ client_type = request .state .detected_client ,
84
+ )
85
+ model_route = await self ._get_model_route (thing_to_match )
86
+ if not model_route :
87
+ raise HTTPException (
88
+ "No matching rule found for the active workspace" , status_code = 404
89
+ )
96
90
97
- filenames_in_data = self ._extract_request_filenames (request .state .detected_client , data )
98
- logger .info (f"Extracted filenames from request: { filenames_in_data } " )
99
-
100
- model_routes = await self ._get_model_routes (filenames_in_data )
101
- if not model_routes :
102
- raise HTTPException ("No rule found for the active workspace" , status_code = 404 )
103
-
104
- # We still need some logic here to handle the case where we have multiple model routes.
105
- # For the moment since we match all only pick the first.
106
- model_route = model_routes [0 ]
107
-
108
- # Parse the input data and map it to the destination provider format
91
+ # 2. Map the request body to the destination provider format.
109
92
rest_of_path = self ._ensure_path_starts_with_slash (rest_of_path )
110
93
new_data = self ._body_adapter .map_body_to_dest (model_route , data )
94
+
95
+ # 3. Run pipeline. Selecting the correct destination provider.
111
96
provider = self ._provider_registry .get_provider (model_route .endpoint .provider_type )
112
97
api_key = model_route .auth_material .auth_blob
113
-
114
- # Send the request to the destination provider. It will run the pipeline
115
98
response = await provider .process_request (
116
- new_data , api_key , rest_of_path , request .state .detected_client
99
+ new_data , api_key , is_fim_request , request .state .detected_client
117
100
)
118
- # Format the response to the client always using the OpenAI format
101
+
102
+ # 4. Transmit the response back to the client in OpenAI format.
119
103
return self ._response_adapter .format_response_to_client (
120
104
response , model_route .endpoint .provider_type
121
105
)
0 commit comments