Skip to content

Commit

Permalink
🐛 Bug: Fix the bug where vercel cannot set app.state.config.
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Oct 19, 2024
1 parent 6719e81 commit 2c0a348
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 68 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ There are other statistical data that you can query yourself by writing SQL in t

[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)

After clicking the one-click deployment button, set the environment variable `CONFIG_URL` to the direct link of the configuration file, and set `DISABLE_DATABASE` to true, then click Create to create the project.

## Docker local deployment

Start the container
Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ yym68686/uni-api:latest

[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)

点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链,然后点击 Create 创建项目。
点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链, `DISABLE_DATABASE` 为 true,然后点击 Create 创建项目。

## Docker 本地部署

Expand Down
125 changes: 58 additions & 67 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,6 @@ async def lifespan(app: FastAPI):
verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
follow_redirects=True, # 自动跟随重定向
)
# app.state.client = httpx.AsyncClient(timeout=timeout)
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)

for item in app.state.api_keys_db:
if item.get("role") == "admin":
app.state.admin_api_key = item.get("api")
if not hasattr(app.state, "admin_api_key"):
if len(app.state.api_keys_db) >= 1:
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
else:
raise Exception("No admin API key found")

yield
# 关闭时的代码
Expand Down Expand Up @@ -224,6 +213,41 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
# 返回精确到15位小数的结果
return total_cost.quantize(Decimal('0.000000000000001'))

async def update_stats(current_info):
if DISABLE_DATABASE:
return
# 这里添加更新数据库的逻辑
async with async_session() as session:
async with session.begin():
try:
columns = [column.key for column in RequestStat.__table__.columns]
filtered_info = {k: v for k, v in current_info.items() if k in columns}
new_request_stat = RequestStat(**filtered_info)
session.add(new_request_stat)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Error updating stats: {str(e)}")

async def update_channel_stats(request_id, provider, model, api_key, success):
if DISABLE_DATABASE:
return
async with async_session() as session:
async with session.begin():
try:
channel_stat = ChannelStat(
request_id=request_id,
provider=provider,
model=model,
api_key=api_key,
success=success,
)
session.add(channel_stat)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Error updating channel stats: {str(e)}")

class LoggingStreamingResponse(Response):
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
Expand Down Expand Up @@ -263,31 +287,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

process_time = time() - self.current_info["start_time"]
self.current_info["process_time"] = process_time
await self.update_stats()

async def update_stats(self):
# 这里添加更新数据库的逻辑
# print("current_info2")
if DISABLE_DATABASE:
return
async with async_session() as session:
async with session.begin():
try:
columns = [column.key for column in RequestStat.__table__.columns]
filtered_info = {k: v for k, v in self.current_info.items() if k in columns}
new_request_stat = RequestStat(**filtered_info)
session.add(new_request_stat)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Error updating stats: {str(e)}")
await update_stats(self.current_info)

async def _logging_iterator(self):
try:
async for chunk in self.body_iterator:
if isinstance(chunk, str):
chunk = chunk.encode('utf-8')
line = chunk.decode()
line = chunk.decode('utf-8')
if is_debug:
logger.info(f"{line}")
if line.startswith("data:"):
Expand Down Expand Up @@ -435,41 +442,6 @@ async def dispatch(self, request: Request, call_next):
# print("current_request_info", current_request_info)
request_info.reset(current_request_info)

async def update_stats(self, current_info):
if DISABLE_DATABASE:
return
# 这里添加更新数据库的逻辑
async with async_session() as session:
async with session.begin():
try:
columns = [column.key for column in RequestStat.__table__.columns]
filtered_info = {k: v for k, v in current_info.items() if k in columns}
new_request_stat = RequestStat(**filtered_info)
session.add(new_request_stat)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Error updating stats: {str(e)}")

async def update_channel_stats(self, request_id, provider, model, api_key, success):
if DISABLE_DATABASE:
return
async with async_session() as session:
async with session.begin():
try:
channel_stat = ChannelStat(
request_id=request_id,
provider=provider,
model=model,
api_key=api_key,
success=success,
)
session.add(channel_stat)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Error updating channel stats: {str(e)}")

async def moderate_content(self, content, token):
moderation_request = ModerationRequest(input=content)

Expand Down Expand Up @@ -500,6 +472,23 @@ async def moderate_content(self, content, token):

app.add_middleware(StatsMiddleware)

@app.middleware("http")
async def ensure_config(request: Request, call_next):
if not hasattr(app.state, 'config'):
logger.warning("Config not found, attempting to reload")
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)

for item in app.state.api_keys_db:
if item.get("role") == "admin":
app.state.admin_api_key = item.get("api")
if not hasattr(app.state, "admin_api_key"):
if len(app.state.api_keys_db) >= 1:
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
else:
raise Exception("No admin API key found")

return await call_next(request)

# 在 process_request 函数中更新成功和失败计数
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
url = provider['base_url']
Expand Down Expand Up @@ -581,14 +570,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
# response = JSONResponse(first_element)

# 更新成功计数和首次响应时间
await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
current_info["first_response_time"] = first_response_time
current_info["success"] = True
current_info["provider"] = provider['provider']

return response
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)

raise e

Expand Down

0 comments on commit 2c0a348

Please sign in to comment.