|
| 1 | +import os |
| 2 | +from pathlib import Path |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +from fastapi import FastAPI, Query, Request, Depends, HTTPException, status |
| 6 | +from fastapi.middleware.cors import CORSMiddleware |
| 7 | +from fastapi.responses import HTMLResponse |
| 8 | +from fastapi.templating import Jinja2Templates |
| 9 | +from datetime import datetime, timedelta |
| 10 | +from vector_embedding_server.auth import authenticate_user, create_access_token, FAKE_USERS_DB, OAuth2PasswordRequestForm, get_user |
| 11 | + |
| 12 | +from vector_embedding_server.openai_like_api_models import ( |
| 13 | + EmbeddingResponse, |
| 14 | + EmbeddingInput, |
| 15 | + ModelName, |
| 16 | + EmbeddingData, |
| 17 | + Usage, |
| 18 | +) |
| 19 | +from vector_embedding_server import e5_large_v2 |
| 20 | +from pydantic import BaseModel |
| 21 | +from fastapi import Depends, HTTPException, status |
| 22 | +from fastapi.security import OAuth2PasswordBearer |
| 23 | +from jose import JWTError, jwt |
| 24 | + |
| 25 | + |
| 26 | +BASE_DIR = Path(__file__).resolve().parent |
| 27 | + |
| 28 | +templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) |
| 29 | + |
| 30 | +app = FastAPI(title="Vector Embedding Server", docs_url=None) |
| 31 | + |
| 32 | + |
| 33 | +class Credentials(BaseModel): |
| 34 | + |
| 35 | + username: str |
| 36 | + password: str |
| 37 | + |
| 38 | + |
| 39 | +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
| 40 | + |
| 41 | + |
| 42 | +def get_current_user(token: str = Depends(oauth2_scheme)): |
| 43 | + credentials_exception = HTTPException( |
| 44 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 45 | + detail="Could not validate credentials", |
| 46 | + headers={"WWW-Authenticate": "Bearer"}, |
| 47 | + ) |
| 48 | + try: |
| 49 | + payload = jwt.decode(token, "SECRET_KEY", algorithms=["HS256"]) |
| 50 | + print(payload) |
| 51 | + username: str = payload.get("sub") |
| 52 | + if username is None: |
| 53 | + raise credentials_exception |
| 54 | + except JWTError: |
| 55 | + raise credentials_exception |
| 56 | + print("username:udirtae", username) |
| 57 | + user = get_user(username) # Hier sollten Sie die Funktion zur Abrufung des Benutzers aus Ihrer Benutzerdatenbank aufrufen |
| 58 | + if user is None: |
| 59 | + raise credentials_exception |
| 60 | + return user |
| 61 | + |
| 62 | +@app.post("/token") |
| 63 | +def login(credentials: Credentials): |
| 64 | + user = authenticate_user(FAKE_USERS_DB, credentials.username, credentials.password) |
| 65 | + if not user: |
| 66 | + raise HTTPException( |
| 67 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 68 | + detail="Incorrect username or password", |
| 69 | + headers={"WWW-Authenticate": "Bearer"}, |
| 70 | + ) |
| 71 | + access_token_expires = timedelta(minutes=15) |
| 72 | + access_token = create_access_token( |
| 73 | + data={"sub": user["username"]}, expires_delta=access_token_expires |
| 74 | + ) |
| 75 | + return {"access_token": access_token} |
| 76 | + |
| 77 | + |
| 78 | +@app.post("/v1/embeddings", response_model=EmbeddingResponse) |
| 79 | +async def create_embedding(embedding_input: EmbeddingInput, current_user: str = Depends(get_current_user)): |
| 80 | + if embedding_input.model == ModelName.e5_large_v2: |
| 81 | + embedding, prompt_tokens = e5_large_v2.predict(embedding_input.input) |
| 82 | + else: |
| 83 | + raise NotImplemented |
| 84 | + |
| 85 | + embedding_data = EmbeddingData( |
| 86 | + object="embedding", |
| 87 | + embedding=embedding, |
| 88 | + index=0, |
| 89 | + ) |
| 90 | + |
| 91 | + usage = Usage(prompt_tokens=prompt_tokens, total_tokens=prompt_tokens) |
| 92 | + |
| 93 | + embedding_response = EmbeddingResponse( |
| 94 | + model=embedding_input.model.name, |
| 95 | + object="list", |
| 96 | + data=[embedding_data], |
| 97 | + usage=usage, |
| 98 | + ) |
| 99 | + return embedding_response |
| 100 | + |
| 101 | + |
| 102 | +@app.get("/docs", response_class=HTMLResponse) |
| 103 | +async def docs(request: Request): # type: ignore |
| 104 | + return templates.TemplateResponse( |
| 105 | + "stoplight-element-api-doc.html", {"request": request} |
| 106 | + ) |
0 commit comments