Skip to content

Commit a093546

Browse files
committed
add github workflow
1 parent 5d5fac5 commit a093546

File tree

10 files changed

+400
-1
lines changed

10 files changed

+400
-1
lines changed
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Continuous integration: lint, type-check and test the project on pushes to
# main and on newly opened or updated pull requests.
on:
  push:
    branches: [main]
  pull_request:
    types: [opened, synchronize]

name: linting and testing

jobs:
  test:
    name: Automatic linting and testing
    runs-on: ubuntu-latest

    steps:
      # Current major versions; the v2 releases target a retired Node.js runtime.
      - uses: actions/checkout@v4

      - name: Setup Python 3.10.5
        uses: actions/setup-python@v5
        with:
          # Quoted so the version is always read as a string.
          python-version: "3.10.5"

      - name: Setup Poetry 1.5.1
        run: pip install poetry==1.5.1

      - name: Check pyproject.toml
        run: poetry check

      - name: Install python dependencies with Poetry
        run: poetry install

      - name: Check files with isort
        run: poetry run isort . --check

      - name: Check files with Black
        run: poetry run black . --check

      - name: Lint files with flake8
        run: poetry run flake8

      - name: Check types with mypy
        run: poetry run mypy

      - name: Run tests
        run: poetry run pytest

Dockerfile

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Multi-stage build: "builder" installs the dependencies and the package wheel
# into /venv; "final" ships only that virtualenv plus the server entry point.
FROM python:3.10.5-slim AS base

# Python settings
ENV PYTHONFAULTHANDLER=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONHASHSEED=random \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app


FROM base AS builder

ENV \
    # Pip settings
    PIP_DEFAULT_TIMEOUT=100 \
    PIP_DISABLE_PIP_VERSION_CHECK=1 \
    PIP_NO_CACHE_DIR=1 \
    # Poetry settings
    POETRY_NO_INTERACTION=1 \
    POETRY_VERSION=1.5.1

# Setup poetry
RUN pip install "poetry==$POETRY_VERSION"
RUN python -m venv /venv

# Copy relevant files. With several sources the destination must be a
# directory, hence the trailing slash.
COPY pyproject.toml poetry.lock ./
COPY vector_embedding_server ./vector_embedding_server
# Poetry needs a README.md to exist (presumably because pyproject.toml
# declares it as the readme -- confirm); an empty placeholder is enough.
RUN touch /app/README.md

RUN . /venv/bin/activate && \
    # Install dependencies
    poetry install -n --only main --no-root && \
    # Install root package
    poetry build -f wheel -n && \
    pip install --no-deps dist/*.whl && \
    rm -rf dist *.egg-info


FROM base AS final

ENV PATH="/venv/bin:$PATH"

COPY --from=builder /venv /venv
COPY ./vector_embedding_server/server.py ./server.py
COPY ./vector_embedding_server/templates ./templates
# Exec form: uvicorn runs as PID 1 and receives stop signals directly instead
# of being wrapped in a shell.
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8080"]

README.md

+18-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,18 @@
1-
# vector-embedding-server
1+
# vector-embedding-server
2+
3+
4+
https://huggingface.co/intfloat/e5-large-v2/blob/main/README.md
5+
6+
7+
8+
9+
10+
11+
pip install fastapi-security python-jose[cryptography] passlib[bcrypt]
12+
13+
14+
-H "Authorization: Bearer $OPENAI_API_KEY" \
15+
16+
17+
18+
TODO include fastapi-jwt

docker-compose.yml

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
version: "3"
2+
3+
services:
4+
vector-embedding-server:
5+
build:
6+
context: .
7+
dockerfile: Dockerfile
8+
image: vector-embedding-server:latest
9+
environment:
10+
TZ: Europe/Berlin
11+
ports:
12+
- "8080:8080"

vector_embedding_server/__init__py

Whitespace-only changes.

vector_embedding_server/auth.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Optional
2+
from datetime import datetime, timedelta
3+
from jose import jwt
4+
from passlib.context import CryptContext
5+
from fastapi import Depends, HTTPException, status
6+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
7+
8+
9+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
10+
11+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
12+
13+
FAKE_USERS_DB = {
14+
"BCH": {
15+
"username": "BCH",
16+
"hashed_password": "$2b$12$YP6UgESiJ6.3c0EwnxNEnu9Ts075Jz82AcqawG7fxvFiMSUgs6cWK",
17+
"disabled": False,
18+
}
19+
}
20+
21+
def authenticate_user(fake_db, username: str, password: str):
    """Verify a username/password pair against a user store.

    Args:
        fake_db: Mapping of username to a user record. Each record must
            provide a ``"hashed_password"`` entry.
        username: Login name to look up.
        password: Plain-text password to check against the stored hash.

    Returns:
        The matching user record on success, otherwise ``False``.
    """
    # Use the store supplied by the caller rather than a module-level global.
    # Credentials are deliberately never printed or logged here.
    user = fake_db.get(username)
    if not user:
        return False
    if not pwd_context.verify(password, user["hashed_password"]):
        return False
    return user
30+
31+
def get_user(username: str):
    """Return the record stored for ``username`` in ``FAKE_USERS_DB``.

    Returns ``None`` when no such user exists.
    """
    return FAKE_USERS_DB.get(username)
35+
36+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed HS256 JWT from ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token; the dict itself is not modified.
        expires_delta: Token lifetime. When omitted (or falsy) the token
            expires 15 minutes after creation.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    # NOTE(review): the signing key is the literal placeholder "SECRET_KEY";
    # replace it with your own secret key.
    return jwt.encode(claims, "SECRET_KEY", algorithm="HS256")
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# copied mostly from https://huggingface.co/intfloat/e5-large-v2/blob/main/README.md
2+
3+
import torch.nn.functional as F
4+
from torch import Tensor
5+
from transformers import AutoModel, AutoTokenizer
6+
7+
8+
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Average the hidden states over the positions kept by the mask.

    Positions whose mask entry is zero are zeroed out before summing along
    dimension 1; the sum is then divided by the per-row mask total.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    masked_states = last_hidden_states.masked_fill(~keep, 0.0)
    kept_per_row = attention_mask.sum(dim=1).unsqueeze(-1)
    return masked_states.sum(dim=1) / kept_per_row
11+
12+
13+
_MODEL_NAME = "intfloat/e5-large-v2"
# Holds the loaded (tokenizer, model) pair so the weights are read only once.
_LOADED: dict = {}


def _load_tokenizer_and_model():
    """Return the cached (tokenizer, model) pair, loading it on first use."""
    if "pair" not in _LOADED:
        _LOADED["pair"] = (
            AutoTokenizer.from_pretrained(_MODEL_NAME),
            AutoModel.from_pretrained(_MODEL_NAME),
        )
    return _LOADED["pair"]


def predict(input_text: str) -> tuple[list[float], int]:
    """Embed ``input_text`` with the e5-large-v2 model.

    Args:
        input_text: Text to embed; it is truncated to at most 512 tokens.

    Returns:
        A tuple ``(embedding, prompt_tokens)``: the mean-pooled embedding as
        a list of floats, and the number of tokens the tokenizer produced
        for the input (the sum of its attention mask).
    """
    # Local import: the file only binds ``torch.nn.functional`` and ``Tensor``.
    import torch

    tokenizer, model = _load_tokenizer_and_model()

    # Tokenize the input text
    batch_dict = tokenizer(
        [input_text], max_length=512, padding=True, truncation=True, return_tensors="pt"
    )

    # Inference only, so gradient tracking is unnecessary overhead.
    with torch.no_grad():
        outputs = model(**batch_dict)

    embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
    # Count the tokens of this input, not the tokenizer's special-token ids.
    prompt_tokens = int(batch_dict["attention_mask"].sum().item())
    return embeddings.tolist()[0], prompt_tokens
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import List
2+
from pydantic import BaseModel
3+
from enum import Enum
4+
5+
6+
class EmbeddingData(BaseModel):
    """One embedding entry inside an embeddings response."""

    # Type tag of this entry.
    object: str
    # The embedding vector.
    embedding: List[float]
    # Position of this entry in the response's data list.
    index: int
10+
11+
12+
class Usage(BaseModel):
    """Token accounting reported alongside an embeddings response."""

    # Tokens attributed to the prompt.
    prompt_tokens: int
    # Total tokens attributed to the request.
    total_tokens: int
15+
16+
17+
class EmbeddingResponse(BaseModel):
    """Response body of the embeddings endpoint."""

    # Type tag of the response.
    object: str
    # The embedding entries produced for the request.
    data: List[EmbeddingData]
    # Identifier of the model that produced the embeddings.
    model: str
    # Token usage for the request.
    usage: Usage
22+
23+
24+
class ModelName(str, Enum):
    """Closed set of model identifiers accepted in embedding requests."""

    e5_large_v2 = "e5-large-v2"
26+
27+
28+
class EmbeddingInput(BaseModel):
    """Request body of the embeddings endpoint."""

    # Which model to embed with; restricted to the members of ModelName.
    model: ModelName
    # The text to embed.
    input: str

vector_embedding_server/server.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import os
2+
from pathlib import Path
3+
from typing import Optional
4+
5+
from fastapi import FastAPI, Query, Request, Depends, HTTPException, status
6+
from fastapi.middleware.cors import CORSMiddleware
7+
from fastapi.responses import HTMLResponse
8+
from fastapi.templating import Jinja2Templates
9+
from datetime import datetime, timedelta
10+
from vector_embedding_server.auth import authenticate_user, create_access_token, FAKE_USERS_DB, OAuth2PasswordRequestForm, get_user
11+
12+
from vector_embedding_server.openai_like_api_models import (
13+
EmbeddingResponse,
14+
EmbeddingInput,
15+
ModelName,
16+
EmbeddingData,
17+
Usage,
18+
)
19+
from vector_embedding_server import e5_large_v2
20+
from pydantic import BaseModel
21+
from fastapi import Depends, HTTPException, status
22+
from fastapi.security import OAuth2PasswordBearer
23+
from jose import JWTError, jwt
24+
25+
26+
BASE_DIR = Path(__file__).resolve().parent
27+
28+
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
29+
30+
app = FastAPI(title="Vector Embedding Server", docs_url=None)
31+
32+
33+
class Credentials(BaseModel):
    """Login payload: the username/password pair sent to the token endpoint."""

    # Login name of the user.
    username: str
    # Plain-text password supplied by the client.
    password: str
37+
38+
39+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
40+
41+
42+
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token of a request to a stored user record.

    Args:
        token: JWT extracted from the request by ``oauth2_scheme``.

    Returns:
        The user record returned by ``get_user`` for the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            ``sub`` claim, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Only decoding can raise JWTError, so keep the try body minimal.
    # NOTE(review): the key is the literal placeholder "SECRET_KEY"; it must
    # match the key the token was signed with.
    try:
        payload = jwt.decode(token, "SECRET_KEY", algorithms=["HS256"])
    except JWTError:
        raise credentials_exception

    # Token contents are intentionally not printed: they identify the user.
    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Look the user up in the user database.
    user = get_user(username)
    if user is None:
        raise credentials_exception
    return user
61+
62+
@app.post("/token")
def login(credentials: Credentials):
    """Exchange a username/password pair for a short-lived access token.

    Args:
        credentials: JSON body carrying ``username`` and ``password``.

    Returns:
        A dict with the signed ``access_token`` (valid for 15 minutes) and
        its ``token_type``.

    Raises:
        HTTPException: 401 when the username or password is wrong.
    """
    user = authenticate_user(FAKE_USERS_DB, credentials.username, credentials.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    access_token = create_access_token(
        data={"sub": user["username"]},
        expires_delta=timedelta(minutes=15),
    )
    # Bearer-token clients expect the token type next to the token itself.
    return {"access_token": access_token, "token_type": "bearer"}
76+
77+
78+
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embedding(embedding_input: EmbeddingInput, current_user: str = Depends(get_current_user)):
    """Compute the embedding of the request's input text.

    Args:
        embedding_input: Request body naming the model and the text to embed.
        current_user: Result of ``get_current_user``; unused in the body and
            present only so that the endpoint requires a valid token.

    Returns:
        An ``EmbeddingResponse`` holding a single embedding entry and the
        token usage reported by the model wrapper.

    Raises:
        HTTPException: 501 when the requested model has no implementation.
    """
    if embedding_input.model != ModelName.e5_large_v2:
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it would fail with a TypeError. Report a proper HTTP error.
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail=f"Model {embedding_input.model.value!r} is not implemented",
        )

    embedding, prompt_tokens = e5_large_v2.predict(embedding_input.input)

    embedding_data = EmbeddingData(
        object="embedding",
        embedding=embedding,
        index=0,
    )

    usage = Usage(prompt_tokens=prompt_tokens, total_tokens=prompt_tokens)

    # NOTE(review): ``.name`` yields the enum member name ("e5_large_v2"),
    # not the requested value ("e5-large-v2") -- confirm this is intended.
    return EmbeddingResponse(
        model=embedding_input.model.name,
        object="list",
        data=[embedding_data],
        usage=usage,
    )
100+
101+
102+
@app.get("/docs", response_class=HTMLResponse)
async def docs(request: Request):  # type: ignore
    """Render the API documentation page from its HTML template."""
    context = {"request": request}
    return templates.TemplateResponse("stoplight-element-api-doc.html", context)

0 commit comments

Comments
 (0)