# send_request.py
import aiohttp
import asyncio
import json
import logging
from typing import List, Union, Optional, Dict, Any
#from json_repair import repair_json

# Provider-specific request helpers
from .anthropic_api import send_anthropic_request
from .ollama_api import send_ollama_request, create_ollama_embedding
from .openai_api import send_openai_request, create_openai_compatible_embedding, generate_image, generate_image_variations, edit_image
from .xai_api import send_xai_request
from .kobold_api import send_kobold_request
from .groq_api import send_groq_request
from .lms_api import send_lmstudio_request
from .textgen_api import send_textgen_request
from .llamacpp_api import send_llama_cpp_request
from .mistral_api import send_mistral_request
from .vllm_api import send_vllm_request
from .gemini_api import send_gemini_request
from .transformers_api import TransformersModelManager  # Manager for local transformers models
# Note: .utils also exports a format_response helper, but this module defines
# its own version below, which would shadow the import; only the image
# converter is imported here.
from .utils import convert_images_for_api

# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize the TransformersModelManager once at module load
_transformers_manager = TransformersModelManager()

def run_async(coroutine):
    """Helper function to run coroutines in a new event loop if necessary"""
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coroutine)
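
# Illustrative note (not in the original file): run_async lets synchronous
# callers drive the async helpers below, e.g.
#
#     result = run_async(send_request(llm_provider="ollama", ...))
#
# where "..." stands for the remaining required arguments of send_request.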


async def send_request(
    llm_provider: str,
    base_ip: str,
    port: str,
    images: List[str],
    llm_model: str,
    system_message: str,
    user_message: str,
    messages: List[Dict[str, Any]],
    seed: Optional[int],
    temperature: float,
    max_tokens: int,
    random: bool,
    top_k: int,
    top_p: float,
    repeat_penalty: float,
    stop: Optional[List[str]],
    keep_alive: bool,
    llm_api_key: Optional[str] = None,
    tools: Optional[Any] = None,
    tool_choice: Optional[Any] = None,
    precision: Optional[str] = "fp16",
    attention: Optional[str] = "sdpa",
    aspect_ratio: Optional[str] = "1:1",
    strategy: Optional[str] = "normal",
    batch_count: Optional[int] = 4,
    mask: Optional[List[str]] = None,
) -> Union[str, Dict[str, Any]]:
"""
Sends a request to the specified LLM provider and returns a unified response.
Args:
llm_provider (str): The LLM provider to use.
base_ip (str): Base IP address for the API.
port (int): Port number for the API.
base64_images (List[str]): List of images encoded in base64.
llm_model (str): The model to use.
system_message (str): System message for the LLM.
user_message (str): User message for the LLM.
messages (List[Dict[str, Any]]): Conversation messages.
seed (Optional[int]): Random seed.
temperature (float): Temperature for randomness.
max_tokens (int): Maximum tokens to generate.
random (bool): Whether to use randomness.
top_k (int): Top K for sampling.
top_p (float): Top P for sampling.
repeat_penalty (float): Penalty for repetition.
stop (Optional[List[str]]): Stop sequences.
keep_alive (bool): Whether to keep the session alive.
llm_api_key (Optional[str], optional): API key for the LLM provider.
tools (Optional[Any], optional): Tools to be used.
tool_choice (Optional[Any], optional): Tool choice.
precision (Optional[str], optional): Precision for the model.
attention (Optional[str], optional): Attention mechanism for the model.
aspect_ratio (Optional[str], optional): Desired aspect ratio for image generation/editing.
Options: "1:1", "4:5", "3:4", "5:4", "16:9", "9:16". Defaults to "1:1".
image_mode (Optional[str], optional): Mode for image processing.
Options: "create", "edit", "variations". Defaults to "create".
Returns:
Union[str, Dict[str, Any]]: Unified response format.
"""
    try:
        # Map aspect ratios to the pixel sizes expected by the image APIs
        aspect_ratio_mapping = {
            "1:1": "1024x1024",
            "4:5": "1024x1280",
            "3:4": "1024x1365",
            "5:4": "1280x1024",
            "16:9": "1600x900",
            "9:16": "900x1600",
        }
        # Get the size for the provided aspect_ratio; default to square if invalid
        size = aspect_ratio_mapping.get(aspect_ratio.lower(), "1024x1024")

        if llm_provider == "transformers":
            # Local transformers models consume PIL images directly
            formatted_images = convert_images_for_api(images, target_format='pil') if images is not None and len(images) > 0 else None
            response = await _transformers_manager.send_transformers_request(
                model_name=llm_model,
                system_message=system_message,
                user_message=user_message,
                messages=messages,
                max_new_tokens=max_tokens,
                images=formatted_images,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings_list=stop,
                repetition_penalty=repeat_penalty,
                seed=seed,
                keep_alive=keep_alive,
                precision=precision,
                attention=attention,
            )
            return response
        else:
            # Remote providers expect base64; convert only if images exist
            formatted_images = convert_images_for_api(images, target_format='base64') if images is not None and len(images) > 0 else None
            #formatted_masks = convert_images_for_api(mask, target_format='base64') if mask is not None and len(mask) > 0 else None

            api_functions = {
                "groq": send_groq_request,
                "anthropic": send_anthropic_request,
                "openai": send_openai_request,
                "xai": send_xai_request,
                "kobold": send_kobold_request,
                "ollama": send_ollama_request,
                "lmstudio": send_lmstudio_request,
                "textgen": send_textgen_request,
                "llamacpp": send_llama_cpp_request,
                "mistral": send_mistral_request,
                "vllm": send_vllm_request,
                "gemini": send_gemini_request,
            }

            # "transformers" is dispatched above, so any provider missing
            # from this table is invalid.
            if llm_provider not in api_functions:
                raise ValueError(f"Invalid llm_provider: {llm_provider}")

            api_function = api_functions[llm_provider]

            # Prepare API-specific keyword arguments
            kwargs = {}
            if llm_provider == "ollama":
                api_url = f"http://{base_ip}:{port}/api/chat"
                kwargs = dict(
                    api_url=api_url,
                    base64_images=formatted_images,
                    model=llm_model,
                    system_message=system_message,
                    user_message=user_message,
                    messages=messages,
                    seed=seed,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    random=random,
                    top_k=top_k,
                    top_p=top_p,
                    repeat_penalty=repeat_penalty,
                    stop=stop,
                    keep_alive=keep_alive,
                    tools=tools,
                    tool_choice=tool_choice,
                )
            elif llm_provider in ["kobold", "lmstudio", "textgen", "llamacpp", "vllm"]:
                # These backends expose an OpenAI-compatible chat endpoint
                api_url = f"http://{base_ip}:{port}/v1/chat/completions"
                kwargs = {
                    "api_url": api_url,
                    "base64_images": formatted_images,
                    "model": llm_model,
                    "system_message": system_message,
                    "user_message": user_message,
                    "messages": messages,
                    "seed": seed,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "top_k": top_k,
                    "top_p": top_p,
                    "repeat_penalty": repeat_penalty,
                    "stop": stop,
                    "tools": tools,
                    "tool_choice": tool_choice,
                }
                if llm_provider == "llamacpp":
                    # send_llama_cpp_request does not accept tool_choice
                    kwargs.pop("tool_choice", None)
                elif llm_provider == "vllm":
                    kwargs["api_key"] = llm_api_key
elif llm_provider == "gemini":
kwargs = {
"base64_images": formatted_images,
"model": llm_model,
"system_message": system_message,
"user_message": user_message,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_k": top_k,
"top_p": top_p,
"stop": stop,
"api_key": llm_api_key,
"tools": tools,
"tool_choice": tool_choice,
}
elif llm_provider == "openai":
if llm_model.startswith("dall-e"):
try:
# Handle image formatting for edit/variations
formatted_image = None
formatted_mask = None
if images is not None and (strategy == "edit" or strategy == "variations"):
# Convert to base64 and take first image only
formatted_images = convert_images_for_api(images[0:1], target_format='base64')
if formatted_images:
formatted_image = formatted_images[0]
# Handle mask for edit strategy
if strategy == "edit" and mask is not None:
formatted_masks = convert_images_for_api(mask[0:1], target_format='base64')
if formatted_masks:
formatted_mask = formatted_masks[0]
# Make appropriate API call based on strategy
if strategy == "create":
response = await generate_image(
prompt=user_message,
model=llm_model,
n=batch_count,
size=size,
api_key=llm_api_key
)
elif strategy == "edit":
response = await edit_image(
image_base64=formatted_image,
mask_base64=formatted_mask,
prompt=user_message,
model=llm_model,
n=batch_count,
size=size,
api_key=llm_api_key
)
elif strategy == "variations":
response = await generate_image_variations(
image_base64=formatted_image,
model=llm_model,
n=batch_count,
size=size,
api_key=llm_api_key
)
else:
raise ValueError(f"Invalid strategy: {strategy}")
# Return the response directly - it will be a list of base64 strings
return {"images": response}
except Exception as e:
error_msg = f"Error in DALL·E {strategy}: {str(e)}"
logger.error(error_msg)
return {"error": error_msg}
                else:
                    api_url = "https://api.openai.com/v1/chat/completions"
                    kwargs = {
                        "api_url": api_url,
                        "base64_images": formatted_images,
                        "model": llm_model,
                        "system_message": system_message,
                        "user_message": user_message,
                        "messages": messages,
                        "api_key": llm_api_key,
                        "seed": seed if random else None,
                        "temperature": temperature,
                        "max_tokens": max_tokens,
                        "top_p": top_p,
                        "repeat_penalty": repeat_penalty,
                        "tools": tools,
                        "tool_choice": tool_choice,
                    }
elif llm_provider == "xai":
api_url = f"https://api.x.ai/v1/chat/completions"
kwargs = {
"api_url": api_url,
"base64_images": formatted_images,
"model": llm_model,
"system_message": system_message,
"user_message": user_message,
"messages": messages,
"api_key": llm_api_key,
"seed": seed if random else None,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"repeat_penalty": repeat_penalty,
"tools": tools,
"tool_choice": tool_choice,
}
elif llm_provider == "anthropic":
kwargs = {
"api_key": llm_api_key,
"model": llm_model,
"system_message": system_message,
"user_message": user_message,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"base64_images": formatted_images,
"tools": tools,
"tool_choice": tool_choice
}
elif llm_provider == "groq":
kwargs = {
"base64_images": formatted_images,
"model": llm_model,
"system_message": system_message,
"user_message": user_message,
"messages": messages,
"api_key": llm_api_key,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"tools": tools,
"tool_choice": tool_choice,
}
elif llm_provider == "mistral":
kwargs = {
"base64_images": formatted_images,
"model": llm_model,
"system_message": system_message,
"user_message": user_message,
"messages": messages,
"api_key": llm_api_key,
"seed": seed if random else None,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"tools": tools,
"tool_choice": tool_choice,
}
else:
raise ValueError(f"Unsupported llm_provider: {llm_provider}")

            response = await api_function(**kwargs)

            # Defensive: await again if a helper handed back a coroutine
            if asyncio.iscoroutine(response):
                response = await response

            # Log provider-reported errors without interrupting the flow
            if isinstance(response, dict):
                choices = response.get("choices", [])
                if choices and "content" in choices[0].get("message", {}):
                    content = choices[0]["message"]["content"]
                    if isinstance(content, str) and content.startswith("Error:"):
                        logger.error(f"Error from {llm_provider} API: {content}")

            if tools:
                # Tool calls need the full structured response
                return response
            else:
                try:
                    return response["choices"][0]["message"]["content"]
                except (KeyError, IndexError, TypeError) as e:
                    error_msg = f"Error formatting response: {str(e)}"
                    logger.error(error_msg)
                    return {"choices": [{"message": {"content": error_msg}}]}
    except Exception as e:
        logger.error(f"Exception in send_request: {str(e)}", exc_info=True)
        return {"choices": [{"message": {"content": f"Exception: {str(e)}"}}]}


def format_response(response, tools):
    """Helper function to format the response consistently"""
    if tools:
        return response
    try:
        if isinstance(response, dict) and "choices" in response:
            return response["choices"][0]["message"]["content"]
        return response
    except (KeyError, IndexError, TypeError) as e:
        error_msg = f"Error formatting response: {str(e)}"
        logger.error(error_msg)
        return {"choices": [{"message": {"content": error_msg}}]}


async def create_embedding(
    embedding_provider: str,
    api_base: str,
    embedding_model: str,
    input: Union[str, List[str]],
    embedding_api_key: Optional[str] = None,
) -> Union[List[float], None]:
    """Create an embedding for the given input via the selected provider."""
    if embedding_provider == "ollama":
        return await create_ollama_embedding(api_base, embedding_model, input)
    elif embedding_provider in ["openai", "lmstudio", "llamacpp", "textgen", "mistral", "xai"]:
        try:
            return await create_openai_compatible_embedding(api_base, embedding_model, input, embedding_api_key)
        except ValueError as e:
            logger.error(f"Error creating embedding: {e}")
            return None
    else:
        raise ValueError(f"Unsupported embedding_provider: {embedding_provider}")