Skip to content

Commit

Permalink
Merge pull request #165 from mit-submit/feature/801_piazza_service
Browse files Browse the repository at this point in the history
adding piazza service for 801
  • Loading branch information
pmlugato authored Sep 19, 2024
2 parents e26e03d + d99d8bd commit 6e8d154
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/prod-801-ci-cd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ jobs:
touch ${{ github.workspace }}/deploy/prod-801/secrets/hf_token.txt
echo "${{ secrets.HF_TOKEN }}" >> ${{ github.workspace }}/deploy/prod-801/secrets/hf_token.txt
chmod 400 ${{ github.workspace }}/deploy/prod-801/secrets/hf_token.txt
touch ${{ github.workspace }}/deploy/prod-801/secrets/piazza_email.txt
echo "${{ secrets.PROD_801_PIAZZA_EMAIL }}" >> ${{ github.workspace }}/deploy/prod-801/secrets/piazza_email.txt
chmod 400 ${{ github.workspace }}/deploy/prod-801/secrets/piazza_email.txt
touch ${{ github.workspace }}/deploy/prod-801/secrets/piazza_password.txt
echo "${{ secrets.PROD_801_PIAZZA_PASSWORD }}" >> ${{ github.workspace }}/deploy/prod-801/secrets/piazza_password.txt
chmod 400 ${{ github.workspace }}/deploy/prod-801/secrets/piazza_password.txt
touch ${{ github.workspace }}/deploy/prod-801/secrets/slack_webhook.txt
echo "${{ secrets.PROD_801_SLACK_WEBHOOK }}" >> ${{ github.workspace }}/deploy/prod-801/secrets/slack_webhook.txt
chmod 400 ${{ github.workspace }}/deploy/prod-801/secrets/slack_webhook.txt
# create env file to set tag(s) for docker-compose
- name: Create Env File
Expand Down
152 changes: 152 additions & 0 deletions A2rchi/bin/service_piazza_801.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#!/bin/python
from A2rchi.chains.chain import Chain
from A2rchi.interfaces.uploader_app.app import FlaskAppWrapper
from A2rchi.utils.config_loader import Config_Loader
from A2rchi.utils.data_manager import DataManager
from A2rchi.utils.env import read_secret
from A2rchi.utils.scraper import Scraper

from flask import Flask
from piazza_api import Piazza
from threading import Thread

import json
import os
import requests
import time

# DEFINITIONS
SLACK_HEADERS = {'content-type': 'application/json'}
MIN_NEXT_POST_FILE = "/root/data/min_next_post.json"

# set openai
os.environ['OPENAI_API_KEY'] = read_secret("OPENAI_API_KEY")
os.environ['HUGGING_FACE_HUB_TOKEN'] = read_secret("HUGGING_FACE_HUB_TOKEN")
slack_url = read_secret("SLACK_WEBHOOK")
piazza_email = read_secret("PIAZZA_EMAIL")
piazza_password = read_secret("PIAZZA_PASSWORD")
piazza_config = Config_Loader().config["utils"].get("piazza", None)

# scrape data onto the filesystem
scraper = Scraper()
scraper.hard_scrape(verbose=True)
# unresolved_posts = scraper.piazza_scrape(verbose=True)

# update vector store
data_manager = DataManager()
data_manager.update_vectorstore()

# go through unresolved posts and suggest answers

# from this point on; filter feed for new posts and propose answers

# ^also filter for new posts that have been resolved and add to vector store

# for now, just iter through all posts and send replies for unresolved


# login to piazza
piazza = Piazza()
piazza.user_login(email=piazza_email, password=piazza_password)
piazza_net = piazza.network(piazza_config["network_id"])

# create chain
a2rchi_chain = Chain()

def call_chain(chain, post):
# convert post --> history
post_str = "SUBJECT: " + post['history'][-1]['subject'] + "\n\nCONTENT: " + post['history'][-1]['content']
history = [("User", post_str)]

return chain(history)['answer'], post_str


def write_min_next_post(post_nr):
with open(MIN_NEXT_POST_FILE, 'w') as f:
json.dump({"min_next_post_nr": post_nr}, f)


def read_min_next_post():
with open(MIN_NEXT_POST_FILE, 'r') as f:
min_next_post_data = json.load(f)

return int(min_next_post_data['min_next_post_nr'])

# # get generator for all posts
# max_post_nr = 0
# posts = piazza_net.iter_all_posts(sleep=1.5)
# for idx, post in enumerate(posts):
# # update highest post # seen
# max_post_nr = max(post['nr'], max_post_nr)

# # if post has no answer or an unresolved followup, send to A2rchi
# if post.get("no_answer", False): # or post.get("no_answer_followup", False)
# print(f"{idx} PROCESSING POST: {post['nr']}")

# # generate response
# response, post_str = call_chain(a2rchi_chain, post)
# response = f"====================\nReplying to Post @{post['nr']}\n==========\n\n{post_str}\n==========\n\nA2RCHI RESPONSE: {response}\n====================\n"

# # send response to Slack
# r = requests.post(slack_url, data=json.dumps({"text": response}), headers=SLACK_HEADERS)
# print(r)

# else:
# print(f"{idx} skipping post: {post['nr']}")

# continuously poll for next post
# min_next_post_nr = max_post_nr + 1

# write min next post number if we're initializing for the first time
if not os.path.isfile(MIN_NEXT_POST_FILE):
print("WRITING INITIAL MIN. NEXT POST")
write_min_next_post(44)

# read min next post number
min_next_post_nr = read_min_next_post()

while True:
try:
# get new post(s) and sort them by 'nr'
feed = piazza_net.get_feed(limit=999999, offset=0)
post_nrs = sorted(list(map(lambda post: post['nr'], feed['feed'])))
largest_post_nr = post_nrs[-1]
except Exception as e:
print("ERROR - Failed to parse feed due to the following exception:")
print(str(e))
time.sleep(60)
continue

# keep processing posts >= min_next_post_nr
while len(post_nrs) > 0:
# get next post number
post_nr = post_nrs.pop(-1)

# stop if we've already processed it
if post_nr < min_next_post_nr:
break

try:
# otherwise, process it
post = piazza_net.get_post(post_nr)

# if successful, send to A2rchi
print(f"PROCESSING NEW POST: {post_nr}")
response, post_str = call_chain(a2rchi_chain, post)
response = f"====================\nReplying to Post @{post['nr']}\n==========\n\n{post_str}\n==========\n\nA2RCHI RESPONSE: {response}\n====================\n"

# send response to Slack
r = requests.post(slack_url, data=json.dumps({"text": response}), headers=SLACK_HEADERS)
print(r)
except Exception as e:
print(f"ERROR - Failed to process post {post_nr} due to the following exception:")
print(str(e))

# set min. next post to be one greater than max we just saw
min_next_post_nr = largest_post_nr + 1

# write min_next_post_nr so we don't start over on restart
write_min_next_post(min_next_post_nr)

# sleep for 60s
time.sleep(60)
2 changes: 2 additions & 0 deletions config/prod-801-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,5 @@ utils:
reset_data: True # delete websites and sources.yml in data folder
verify_urls: False # should be true when possible
enable_warnings: False # keeps output clean if verify == False
piazza:
network_id: "m0g3v0ahsqm2lg"
13 changes: 13 additions & 0 deletions deploy/dockerfiles/Dockerfile-piazza
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# syntax=docker/dockerfile:1
# FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
FROM python:3.10
RUN mkdir -p /root/A2rchi
WORKDIR /root/A2rchi
COPY pyproject.toml pyproject.toml
COPY README.md README.md
COPY LICENSE LICENSE
COPY config config
COPY A2rchi A2rchi
RUN pip install --upgrade pip && pip install .

CMD ["python", "-u", "A2rchi/bin/service_piazza.py"]
36 changes: 36 additions & 0 deletions deploy/prod-801/prod-801-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,34 @@
services:
piazza-prod-801:
image: piazza-prod-801:${TAG}
build:
context: ../..
dockerfile: deploy/dockerfiles/Dockerfile-piazza
args:
TAG: ${TAG}
depends_on:
chromadb-prod-801:
condition: service_healthy
environment:
RUNTIME_ENV: prod-801
OPENAI_API_KEY_FILE: /run/secrets/openai_api_key
HUGGING_FACE_HUB_TOKEN_FILE: /run/secrets/hf_token
PIAZZA_EMAIL_FILE: /run/secrets/piazza_email
PIAZZA_PASSWORD_FILE: /run/secrets/piazza_password
SLACK_WEBHOOK_FILE: /run/secrets/slack_webhook
secrets:
- openai_api_key
- hf_token
- piazza_email
- piazza_password
- slack_webhook
volumes:
- a2rchi-prod-801-data:/root/data/
logging:
options:
max-size: 10m
restart: always

chat-prod-801:
image: chat-prod-801:${TAG}
build:
Expand Down Expand Up @@ -91,3 +121,9 @@ secrets:
file: secrets/openai_api_key.txt
hf_token:
file: secrets/hf_token.txt
piazza_email:
file: secrets/piazza_email.txt
piazza_password:
file: secrets/piazza_password.txt
slack_webhook:
file: secrets/slack_webhook.txt
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"overrides==7.3.1",
"pandas==2.1.0",
"peft==0.5.0",
"piazza-api==0.14.0"
"posthog==3.0.1",
"pulsar-client==3.2.0",
"pypdf==3.16.1",
Expand Down

0 comments on commit 6e8d154

Please sign in to comment.