Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
etwk committed Sep 3, 2024
1 parent ee3a450 commit 9110b06
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 34 deletions.
5 changes: 2 additions & 3 deletions datasets/wiki_dpr/prepare_files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, subprocess
import os
import subprocess
import shutil
from huggingface_hub import snapshot_download
from tenacity import retry, stop_after_attempt, wait_fixed
Expand All @@ -23,8 +24,6 @@
]
revision = "main"

import os

def check_exists(folder_path):
# Check if the folder exists
if os.path.exists(folder_path) and os.path.isdir(folder_path):
Expand Down
2 changes: 2 additions & 0 deletions src/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
__all__ = ['ReadUrl', 'SearchWeb']

from .read import ReadUrl
from .search import SearchWeb
1 change: 0 additions & 1 deletion src/api/read.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import httpx
import json
from tenacity import retry, stop_after_attempt, wait_fixed

import utils
Expand Down
1 change: 0 additions & 1 deletion src/api/search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import httpx
import json
from tenacity import retry, stop_after_attempt, wait_fixed
Expand Down
2 changes: 2 additions & 0 deletions src/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
__all__ = ['InfinityEmbedding', 'OllamaEmbedding']

from .infinity_embedding import InfinityEmbedding
from .ollama_embedding import OllamaEmbedding
11 changes: 5 additions & 6 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio
import json
import logging
from fastapi import FastAPI, HTTPException, Request, Header
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse, RedirectResponse, StreamingResponse
from fastapi import FastAPI, HTTPException, Header
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse, StreamingResponse

import pipeline, utils, web
from modules import Search
import pipeline
import utils
import web
from settings import settings

logging.basicConfig(
Expand Down
16 changes: 9 additions & 7 deletions src/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
__all__ = ['Citation', 'ContextVerdict', 'LlamaIndexRM', 'Search', 'SearchQuery', 'Statements']

from .citation import Citation
from .context_verdict import ContextVerdict
from .retrieve import LlamaIndexRM
from .search import Search
from .search_query import SearchQuery
from .statements import Statements

import dspy

from settings import settings
Expand All @@ -8,10 +17,3 @@

# LM with higher token limits
llm_long = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=500, stop='\n\n')

from .citation import Citation
from .context_verdict import ContextVerdict
from .retrieve import LlamaIndexRM
from .search import Search
from .search_query import SearchQuery
from .statements import Statements
13 changes: 4 additions & 9 deletions src/modules/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import concurrent.futures
import dspy
import logging
from typing import Optional

from llama_index.core import (
Document,
Settings,
StorageContext,
VectorStoreIndex,
Expand All @@ -13,15 +12,13 @@
from llama_index.core.indices.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.llms import MockLLM

Settings.llm = MockLLM(max_tokens=256) # retrieve only, do not use LLM for synthesize
from llama_index.postprocessor.jinaai_rerank import JinaRerank

import utils
from integrations import InfinityEmbedding, OllamaEmbedding
from settings import settings

from llama_index.postprocessor.jinaai_rerank import JinaRerank

from integrations import InfinityEmbedding
Settings.llm = MockLLM(max_tokens=256) # retrieve only, do not use LLM for synthesize

if settings.EMBEDDING_MODEL_DEPLOY == "local":
embed_model="local:" + settings.EMBEDDING_MODEL_NAME
Expand Down Expand Up @@ -132,8 +129,6 @@ def retrieve(self, query):

return contexts

import dspy

NO_TOP_K_WARNING = "The underlying LlamaIndex retriever does not support top k retrieval. Ignoring k value."

class LlamaIndexRM(dspy.Retrieve):
Expand Down
9 changes: 5 additions & 4 deletions src/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dspy
import logging
import os
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from tenacity import retry, stop_after_attempt, wait_fixed
from urllib.parse import urlparse
Expand Down Expand Up @@ -69,12 +70,12 @@ async def _pipe_source(self, data_source, statement):
# update docs
_task_docs = []
for _, data_doc in data_source['docs'].items():
if not data_doc.get('doc') and data_doc.get('valid') != False: # TODO: better way to decide if update doc
if not data_doc.get('doc') and data_doc.get('valid') is not False: # TODO: better way to decide if update doc
_task_docs.append(asyncio.create_task(self.update_doc(data_doc)))
await asyncio.gather(*_task_docs) # finish all docs processing

# update retriever
docs = [v['doc'] for v in data_source['docs'].values() if v.get('valid') != False]
docs = [v['doc'] for v in data_source['docs'].values() if v.get('valid') is not False]
if docs:
data_source["retriever"] = await run_in_threadpool(LlamaIndexRM, docs=docs)

Expand Down Expand Up @@ -129,7 +130,7 @@ async def update_doc(self, data_doc):
"""Update doc (URL content for now)"""
try:
_rep = await ReadUrl(url=data_doc['url']).get()
except:
except Exception:
data_doc['valid'] = False
logging.warning(f"Failed to read URL, mark as invalid: {data_doc['url']}")
return
Expand Down Expand Up @@ -176,7 +177,7 @@ def update_summary(self, data_statement):
}

for hostname, verdict in data_statement['sources'].items():
if verdict.get('valid') == False:
if verdict.get('valid') is False:
continue
weight_total += 1
v = verdict['verdict'].lower()
Expand Down
5 changes: 3 additions & 2 deletions src/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, ast
import ast
import os

class Settings:
def __init__(self):
Expand Down Expand Up @@ -28,7 +29,7 @@ def __init__(self):
# set Index chunk sizes
try:
self.INDEX_CHUNK_SIZES = ast.literal_eval(os.environ.get("INDEX_CHUNK_SIZES"))
except:
except (ValueError, SyntaxError):
self.INDEX_CHUNK_SIZES = [1024, 256]

"""
Expand Down
2 changes: 2 additions & 0 deletions src/web/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
__all__ = ['get_homepage', 'html_browser']

from .homepage import get_homepage
from .html import html_browser
2 changes: 1 addition & 1 deletion src/web/homepage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def get_homepage():
# get tech stack
stack = utils.get_stack()
md = f"## Tech stack\n"
md = "## Tech stack\n"
lines = [md]
lines.extend([f"**{key}**: {value}" for key, value in stack.items()])
md = "\n\n".join(lines)
Expand Down

0 comments on commit 9110b06

Please sign in to comment.