-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathutils.py
105 lines (93 loc) · 2.91 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import importlib
import asyncio
import openai
from recbole.utils import get_model as recbole_get_model
import os
# Lowercase aliases for the boolean singletons.
# NOTE(review): presumably kept so that JSON-style ``true``/``false`` literals
# resolve when text is evaluated against this module's namespace — confirm
# with callers before removing.
true, false = True, False
def check_path(path):
    """Ensure the directory ``path`` exists, creating missing parents.

    If ``path`` already exists (as a directory or as any other file) nothing
    is done.

    Args:
        path: Filesystem path of the directory to create.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the race where another process creates the
        # directory between the exists() check and makedirs(); without it
        # makedirs() would raise FileExistsError in that window.
        os.makedirs(path, exist_ok=True)
def get_model(model_name):
    """Return the model class called ``model_name``.

    A project-local module ``model.<model_name lowercased>`` is tried first;
    if it exists, the attribute named ``model_name`` is returned from it.
    Otherwise the lookup is delegated to RecBole's ``get_model``.

    Args:
        model_name: Class name of the model, e.g. as given in a config.

    Returns:
        The model class object.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is
    # loaded, so import it explicitly before using importlib.util.find_spec.
    import importlib.util

    module_path = f'model.{model_name.lower()}'
    try:
        # find_spec imports the parent package ``model`` for a dotted name and
        # raises (rather than returning None) if that package is missing.
        spec = importlib.util.find_spec(module_path, __name__)
    except ModuleNotFoundError as err:
        # Only a missing ``model`` package means "no local override"; any
        # other missing module is a genuine import error and is re-raised.
        if err.name != 'model':
            raise
        spec = None

    if spec is None:
        return recbole_get_model(model_name)

    model_module = importlib.import_module(module_path, __name__)
    return getattr(model_module, model_name)
async def dispatch_openai_requests(
    messages_list,
    model: str,
    temperature: float,
    **kwargs,
):
    """Send several chat-completion requests to the OpenAI API concurrently.

    One ``openai.ChatCompletion.acreate`` call is created per element of
    ``messages_list``, and all of them are awaited together with
    ``asyncio.gather``.

    NOTE: ``openai.ChatCompletion`` belongs to the pre-1.0 ``openai`` SDK.

    Args:
        messages_list: Iterable whose elements are each passed as the
            ``messages`` argument of one ChatCompletion request.
        model: OpenAI model to use.
        temperature: Sampling temperature used for every request.
        **kwargs: Extra keyword arguments (e.g. ``max_tokens``, ``top_p``)
            forwarded unchanged to every ``acreate`` call.

    Returns:
        List of responses, in the same order as ``messages_list``.
    """
    async_responses = [
        openai.ChatCompletion.acreate(
            model=model,
            messages=messages,
            temperature=temperature,
            **kwargs,
        )
        for messages in messages_list
    ]
    # gather() returns results in the order the awaitables were passed.
    return await asyncio.gather(*async_responses)
def dispatch_single_openai_requests(
    message,
    model: str,
    temperature: float,
    **kwargs,
):
    """Send one chat-completion request to the OpenAI API and wait for it.

    This is the synchronous, single-request counterpart of
    ``dispatch_openai_requests``; it calls ``openai.ChatCompletion.create``
    once and blocks until the response arrives.

    NOTE: ``openai.ChatCompletion`` belongs to the pre-1.0 ``openai`` SDK.

    Args:
        message: Value passed as the ``messages`` argument of the request.
        model: OpenAI model to use.
        temperature: Sampling temperature for the request.
        **kwargs: Extra keyword arguments (e.g. ``max_tokens``, ``top_p``)
            forwarded unchanged to ``create``.

    Returns:
        The single response object returned by ``create``.
    """
    responses = openai.ChatCompletion.create(
        model=model,
        messages=message,
        temperature=temperature,
        **kwargs,
    )
    return responses
# Short dataset alias -> full Amazon category name.
# NOTE(review): the values appear to match the category file prefixes of the
# Amazon Review Data (2018) release — confirm against the data loader.
amazon_dataset2fullname = dict(
    Beauty='All_Beauty',
    Fashion='AMAZON_FASHION',
    Appliances='Appliances',
    Arts='Arts_Crafts_and_Sewing',
    Automotive='Automotive',
    Books='Books',
    CDs='CDs_and_Vinyl',
    Cell='Cell_Phones_and_Accessories',
    Clothing='Clothing_Shoes_and_Jewelry',
    Music='Digital_Music',
    Electronics='Electronics',
    Gift='Gift_Cards',
    Food='Grocery_and_Gourmet_Food',
    Home='Home_and_Kitchen',
    Scientific='Industrial_and_Scientific',
    Kindle='Kindle_Store',
    Luxury='Luxury_Beauty',
    Magazine='Magazine_Subscriptions',
    Movies='Movies_and_TV',
    Instruments='Musical_Instruments',
    Office='Office_Products',
    Garden='Patio_Lawn_and_Garden',
    Pantry='Prime_Pantry',
    Pet='Pet_Supplies',
    Software='Software',
    Sports='Sports_and_Outdoors',
    Tools='Tools_and_Home_Improvement',
    Toys='Toys_and_Games',
    Games='Video_Games',
)