Merge pull request #20 from feast-dev/rag
uploading rag demo
Showing 23 changed files with 4,809 additions and 5 deletions.
Updated ignore file (likely the repository `.gitignore`):

```diff
@@ -11,4 +11,5 @@ terraform.tfstate.backup
 .vscode/*
 **/derby.log
 **/metastore_db/*
 .env
+.idea
```
New one-line ignore file:

```diff
@@ -0,0 +1 @@
+data/*
```
New one-line file pinning the Python version (presumably `.python-version` for pyenv):

```diff
@@ -0,0 +1 @@
+3.9
```
New `Dockerfile`:

```dockerfile
FROM python:3.9

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1

# Set work directory
WORKDIR /code

# Install dependencies
# (ENV so the libmemcached prefix persists across build layers)
ENV LIBMEMCACHED=/opt/local
RUN apt-get update && apt-get install -y \
    libmemcached11 \
    libmemcachedutil2 \
    libmemcached-dev \
    libz-dev \
    curl \
    gettext

ENV PYTHONHASHSEED=random \
    PIP_NO_CACHE_DIR=off \
    PIP_DISABLE_PIP_VERSION_CHECK=on \
    PIP_DEFAULT_TIMEOUT=100 \
    # Poetry's configuration: \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    POETRY_CACHE_DIR='/var/cache/pypoetry' \
    POETRY_HOME='/usr/local' \
    POETRY_VERSION=1.4.1

RUN curl -sSL https://install.python-poetry.org | python3 - --version $POETRY_VERSION

COPY pyproject.toml poetry.lock /code/
RUN poetry install --no-interaction --no-ansi --no-root

COPY . /code/
```
New `README.md`:

This is a demo showing how you can use Feast for retrieval-augmented generation (RAG).

## Installation via PyEnv and Poetry

This demo assumes you have Pyenv (2.3.10) and Poetry (1.4.1) installed on your machine, as well as Python 3.9.

```bash
pyenv local 3.9
poetry shell
poetry install
```
## Setting up the data and Feast

To fetch the data, simply run
```bash
python pull_states.py
```
which will output a file called `city_wikipedia_summaries.csv`.

Then run
```bash
python batch_score_documents.py
```
which will output data to `data/city_wikipedia_summaries_with_embeddings.parquet`.

Next we'll need to do some Feast work and move the data into a repo created by Feast.
## Feast

To get started, make sure you have Feast and PostgreSQL installed.
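The demo's `feature_store.yaml` is not shown in this diff, but for orientation, a config along these lines wires Feast to PostgreSQL with pgvector. This is a sketch only: it assumes a Feast version with `pgvector_enabled` support on the Postgres online store, the project name is hypothetical, and the host, database, and credentials are placeholders.

```yaml
project: rag_demo            # hypothetical project name
registry: data/registry.db
provider: local
online_store:
  type: postgres
  pgvector_enabled: true     # store embeddings in a pgvector column
  vector_len: 384            # all-MiniLM-L6-v2 produces 384-dim vectors
  host: localhost            # placeholder connection details
  port: 5432
  database: feast
  db_schema: public
  user: feast
  password: feast
```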
First run
```bash
cp -r ./data feature_repo/
```
and then open the `module_4.ipynb` notebook and follow the instructions there.

It will walk you through a short tutorial on retrieving the top `k` most similar documents using PGVector.
# Overview

The overview is relatively simple: the goal is to define an architecture to support the following flow:

```mermaid
flowchart TD;
    A[Pull Data] --> B[Batch Score Embeddings];
    B[Batch Score Embeddings] --> C[Materialize Online];
    C[Materialize Online] --> D[Retrieval Augmented Generation];
```
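The "Materialize Online" step is not spelled out in this README; with a standard Feast repo layout it would look something like the following sketch, using the stock Feast CLI from the feature repo directory:

```bash
# Register the feature definitions, then load the scored
# embeddings into the online store up to the current time
feast apply
feast materialize-incremental $(date -u +"%Y-%m-%dT%H:%M:%S")
```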
# Results

The demo runs the code below; the retrieved data is shown afterwards.

```python
import pandas as pd

from feast import FeatureStore
from batch_score_documents import run_model, TOKENIZER, MODEL
from transformers import AutoTokenizer, AutoModel

df = pd.read_parquet("./feature_repo/data/city_wikipedia_summaries_with_embeddings.parquet")

store = FeatureStore(repo_path=".")

# Prepare a query vector
question = "the most populous city in the U.S. state of Texas?"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
model = AutoModel.from_pretrained(MODEL)
query_embedding = run_model(question, tokenizer, model)
query = query_embedding.detach().cpu().numpy().tolist()[0]

# Retrieve top k documents
features = store.retrieve_online_documents(
    feature="city_embeddings:Embeddings",
    query=query,
    top_k=3,
)
features_df = features.to_df()  # convert the online response to a DataFrame
```
And running `features_df.head()` will show:

```
features_df.head()
                                          Embeddings  distance
0  [0.11749928444623947, -0.04684492573142052, 0....  0.935567
1  [0.10329511761665344, -0.07897591590881348, 0....  0.939936
2  [0.11634305864572525, -0.10321836173534393, -0...  0.983343
```
New Flask demo app (presumably the `run.py` referenced by the Compose file):

```python
from flask import (
    Flask,
    request,
    render_template,
)
from flasgger import Swagger
from transformers import AutoTokenizer, AutoModel

from feast import FeatureStore
from batch_score_documents import run_model, TOKENIZER, MODEL

app = Flask(__name__)
swagger = Swagger(app)

# Load the feature store and the embedding model once at startup
store = FeatureStore(repo_path=".")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
model = AutoModel.from_pretrained(MODEL)


@app.route("/get_documents")
def get_documents():
    """Example endpoint returning documents for a question
    This is using docstrings for specifications.
    ---
    parameters:
      - name: question
        type: string
        in: query
        required: true
    responses:
      200:
        description: An HTML page listing the retrieved documents
    """
    # The route is a GET, so the question arrives in the query string
    question = request.args["question"]

    # Embed the question and retrieve the top k most similar documents
    # (mirrors the snippet in the README)
    query_embedding = run_model(question, tokenizer, model)
    query = query_embedding.detach().cpu().numpy().tolist()[0]
    documents = store.retrieve_online_documents(
        feature="city_embeddings:Embeddings",
        query=query,
        top_k=3,
    )
    return render_template("documents.html", documents=documents)


@app.route("/")
def home():
    return render_template("home.html")


if __name__ == "__main__":
    app.run(debug=True)
```
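With the store materialized and the app running, the endpoint can be exercised like this (a sketch assuming Flask's default port 5000; the question is made up):

```bash
curl "http://127.0.0.1:5000/get_documents?question=the+most+populous+city+in+Texas"
```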
New `batch_score_documents.py`:

```python
import os

import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

INPUT_FILENAME = "./data/city_wikipedia_summaries.csv"
EXPORT_FILENAME = "./data/city_wikipedia_summaries_with_embeddings.parquet"
TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
MODEL = "sentence-transformers/all-MiniLM-L6-v2"


def mean_pooling(model_output, attention_mask):
    # First element of model_output contains all token embeddings
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def run_model(sentences, tokenizer, model):
    encoded_input = tokenizer(
        sentences, padding=True, truncation=True, return_tensors="pt"
    )
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings


def score_data() -> None:
    # Skip the work if the scored parquet file already exists
    if not os.path.exists(EXPORT_FILENAME):
        print("scored data not found...generating embeddings...")
        df = pd.read_csv(INPUT_FILENAME)
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
        model = AutoModel.from_pretrained(MODEL)
        embeddings = run_model(df["Wiki Summary"].tolist(), tokenizer, model)
        print(embeddings)
        print("shape = ", df.shape)
        df["Embeddings"] = list(embeddings.detach().cpu().numpy())
        print("embeddings generated...")
        df["event_timestamp"] = pd.to_datetime("today")
        df["item_id"] = df.index
        print(df.head())
        df.to_parquet(EXPORT_FILENAME, index=False)
        print("...data exported. job complete")
    else:
        print("scored data found...skipping generating embeddings.")


if __name__ == "__main__":
    score_data()
```
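As a quick sanity check of the embedding pipeline (a sketch; the example sentence is made up), `run_model` should return unit-normalized 384-dimensional vectors, since all-MiniLM-L6-v2 has a 384-dim hidden size:

```python
from transformers import AutoTokenizer, AutoModel
from batch_score_documents import MODEL, TOKENIZER, run_model

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
model = AutoModel.from_pretrained(MODEL)

emb = run_model(["Houston is the most populous city in Texas."], tokenizer, model)
print(emb.shape)  # torch.Size([1, 384])
```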
New Compose file (presumably `docker-compose.yml`):

```yaml
version: '3.9'

services:
  web:
    env_file:
      - .env
    build: .
    command:
      - /bin/bash
      - -c
      - python3 /code/run.py
    volumes:
      - .:/code
```
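To build the image and start the service, the stock Compose CLI is all that's needed (nothing demo-specific is assumed here):

```bash
docker compose up --build
```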