Commit
flake8 pipelines examples,ui,restapi (PaddlePaddle#5078)
w5688414 authored Mar 3, 2023
1 parent b4f11f9 commit 485ce9d
Showing 14 changed files with 89 additions and 129 deletions.
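This commit is a mechanical lint pass over pipelines/examples, pipelines/ui, and pipelines/rest_api: unused imports are dropped, import blocks are regrouped and alphabetized, bare except clauses are narrowed, and f-strings without placeholders lose their prefix. A minimal sketch of how such a pass might be reproduced locally with flake8's documented Python API — the line-length option and the target paths here are illustrative assumptions, not values taken from the repository's own config:

# Sketch only: run flake8 over the touched directories.
# Assumes flake8 is installed; options and paths are illustrative.
from flake8.api import legacy as flake8

style_guide = flake8.get_style_guide(max_line_length=120)
report = style_guide.check_files(
    ["pipelines/examples", "pipelines/ui", "pipelines/rest_api"]
)
print(f"remaining flake8 violations: {report.total_errors}")
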
16 changes: 9 additions & 7 deletions pipelines/examples/FAQ/dense_faq_example.py
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-### 城市百科知识智能问答系统
+# 城市百科知识智能问答系统
 import argparse
-import logging
 import os
 
-import paddle
 from pipelines.document_stores import FAISSDocumentStore
-from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, print_documents
-from pipelines.nodes import ErnieRanker, DensePassageRetriever
+from pipelines.nodes import DensePassageRetriever, ErnieRanker
+from pipelines.utils import (
+    convert_files_to_dicts,
+    fetch_archive_from_http,
+    print_documents,
+)
 
 # yapf: disable
 parser = argparse.ArgumentParser()
@@ -82,10 +84,10 @@ def dense_faq_pipeline():
     # save index
     document_store.save(args.index_name)
 
-    ### Ranker
+    # Ranker
     ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)
 
-    # ### Pipeline
+    # Pipeline
     from pipelines import SemanticSearchPipeline
 
     pipe = SemanticSearchPipeline(retriever, ranker)
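The recurring import rewrite in this and the following files applies the usual flake8/isort convention: standard-library imports first, third-party packages second, the groups separated by a blank line, each group alphabetized, and long from-imports wrapped in parentheses with one name per line. In miniature, using names from the diff above:

# Standard library
import argparse
import os

# Third-party (pipelines is the PaddleNLP pipelines package)
from pipelines.document_stores import FAISSDocumentStore
from pipelines.nodes import DensePassageRetriever, ErnieRanker
from pipelines.utils import (
    convert_files_to_dicts,
    fetch_archive_from_http,
    print_documents,
)
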
@@ -13,12 +13,9 @@
 # limitations under the License.
 
 import argparse
-import logging
-import os
 
-import paddle
-from pipelines.nodes import DocOCRProcessor, DocPrompter
 from pipelines import DocPipeline
+from pipelines.nodes import DocOCRProcessor, DocPrompter
 
 # yapf: disable
 parser = argparse.ArgumentParser()
16 changes: 9 additions & 7 deletions pipelines/examples/question-answering/dense_qa_example.py
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-### 城市百科知识智能问答系统
+# 城市百科知识智能问答系统
 import argparse
-import logging
 import os
 
-import paddle
 from pipelines.document_stores import FAISSDocumentStore
-from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, print_answers
-from pipelines.nodes import ErnieReader, ErnieRanker, DensePassageRetriever
+from pipelines.nodes import DensePassageRetriever, ErnieRanker, ErnieReader
+from pipelines.utils import (
+    convert_files_to_dicts,
+    fetch_archive_from_http,
+    print_answers,
+)
 
 # yapf: disable
 parser = argparse.ArgumentParser()
@@ -82,14 +84,14 @@ def dense_qa_pipeline():
     # save index
     document_store.save(args.index_name)
 
-    ### Ranker
+    # Ranker
     ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)
 
     reader = ErnieReader(
         model_name_or_path="ernie-gram-zh-finetuned-dureader-robust", use_gpu=use_gpu, num_processes=1
     )
 
-    # ### Pipeline
+    # Pipeline
     from pipelines import ExtractiveQAPipeline
 
     pipe = ExtractiveQAPipeline(reader, ranker, retriever)
8 changes: 3 additions & 5 deletions pipelines/examples/text_to_image/text_to_image_example.py
@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import argparse
 
-import paddle
-from pipelines.nodes import ErnieTextToImageGenerator
 from pipelines import TextToImagePipeline
+from pipelines.nodes import ErnieTextToImageGenerator
 
 # yapf: disable
 parser = argparse.ArgumentParser()
@@ -26,8 +24,7 @@
 parser.add_argument("--prompt_text", default='宁静的小镇', type=str, help="The prompt_text.")
 parser.add_argument("--output_dir", default='ernievilg_output', type=str, help="The output path.")
 parser.add_argument("--style", default='探索无限', type=str, help="The style text.")
-parser.add_argument("--size", default='1024*1024',
-                    choices=['1024*1024', '1024*1536', '1536*1024'], help="Size of the generation images")
+parser.add_argument("--size", default='1024*1024', choices=['1024*1024', '1024*1536', '1536*1024'], help="Size of the generation images")
 parser.add_argument("--topk", default=5, type=int, help="The top k images.")
 args = parser.parse_args()
 # yapf: enable
@@ -47,6 +44,7 @@ def text_to_image():
             }
         },
     )
+    print(prediction)
     pipe.save_to_yaml("text_to_image.yaml")
 
 
@@ -13,16 +13,10 @@
 # limitations under the License.
 
 import argparse
-import logging
-import os
-from pprint import pprint
 
 import paddle
 from pipelines.nodes import AnswerExtractor, QAFilter, QuestionGenerator
-from pipelines.nodes import ErnieRanker, DensePassageRetriever
-from pipelines.document_stores import FAISSDocumentStore
-from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, print_documents
-from pipelines.pipelines import QAGenerationPipeline, SemanticSearchPipeline
+from pipelines.pipelines import QAGenerationPipeline
 
 # yapf: disable
 parser = argparse.ArgumentParser()
@@ -13,16 +13,19 @@
 # limitations under the License.
 
 import argparse
-import logging
 import os
-from pprint import pprint
 
 import paddle
-from pipelines.nodes import AnswerExtractor, QAFilter, QuestionGenerator
-from pipelines.nodes import ErnieRanker, DensePassageRetriever
 from pipelines.document_stores import FAISSDocumentStore
-from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, print_documents
+from pipelines.nodes import (
+    AnswerExtractor,
+    DensePassageRetriever,
+    ErnieRanker,
+    QAFilter,
+    QuestionGenerator,
+)
 from pipelines.pipelines import QAGenerationPipeline, SemanticSearchPipeline
+from pipelines.utils import convert_files_to_dicts, print_documents
 
 # yapf: disable
 parser = argparse.ArgumentParser()
@@ -84,7 +87,7 @@ def dense_faq_pipeline():
     # save index
     document_store.save(args.index_name)
 
-    ### Ranker
+    # Ranker
     ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)
 
     pipe = SemanticSearchPipeline(retriever, ranker)
6 changes: 3 additions & 3 deletions pipelines/pipelines/utils/export_utils.py
@@ -16,7 +16,7 @@
 import logging
 import pprint
 from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Optional
 
 import pandas as pd
 
@@ -161,7 +161,7 @@ def export_answers_to_csv(agg_results: list, output_file):
     assert "query" in agg_results[0], f"Wrong format used for {agg_results[0]}"
     assert "answers" in agg_results[0], f"Wrong format used for {agg_results[0]}"
 
-    data = {}  # type: Dict[str, List[Any]]
+    data = {}
     data["query"] = []
     data["prediction"] = []
     data["prediction_rank"] = []
@@ -193,7 +193,7 @@ def convert_labels_to_squad(labels_file: str):
     for label in labels:
         labels_grouped_by_documents[label["document_id"]].append(label)
 
-    labels_in_squad_format = {"data": []}  # type: Dict[str, Any]
+    labels_in_squad_format = {"data": []}
     for document_id, labels in labels_grouped_by_documents.items():
         qas = []
         for label in labels:
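The two "# type:" comments deleted above are Python 2-era annotations that read as noise once the code is Python 3-only. If the typing information is still wanted, the modern equivalent is a variable annotation — a sketch of the alternative, not something this commit adds:

from typing import Any, Dict, List

data: Dict[str, List[Any]] = {}
labels_in_squad_format: Dict[str, Any] = {"data": []}
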
22 changes: 11 additions & 11 deletions pipelines/rest_api/application.py
@@ -13,29 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import logging
-
-sys.path.append(".")
-
-logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
-logger = logging.getLogger(__name__)
-logging.getLogger("elasticsearch").setLevel(logging.WARNING)
-logging.getLogger("pipelines").setLevel(logging.INFO)
+import sys
 
 import uvicorn
 from fastapi import FastAPI, HTTPException
-from fastapi.routing import APIRoute
 from fastapi.openapi.utils import get_openapi
+from fastapi.routing import APIRoute
 from starlette.middleware.cors import CORSMiddleware
 
+# flake8: noqa
+sys.path.append(".")
+from rest_api.config import ROOT_PATH
 from rest_api.controller.errors.http_error import http_error_handler
-from rest_api.config import ROOT_PATH, PIPELINE_YAML_PATH
 from rest_api.controller.router import router as api_router
 
+logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
+logger = logging.getLogger(__name__)
+logging.getLogger("elasticsearch").setLevel(logging.WARNING)
+logging.getLogger("pipelines").setLevel(logging.INFO)
+
 try:
     from pipelines import __version__ as pipelines_version
-except:
+except Exception:
     # For development
     pipelines_version = "0.0.0"
 
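The one behavioral nuance in this file is the except: → except Exception: change (flake8 E722). A bare except also traps SystemExit and KeyboardInterrupt, which derive from BaseException; catching Exception keeps the development fallback while letting those control-flow exceptions propagate. A self-contained miniature of the fix, not part of the commit itself:

# Illustration of the E722 fix.
def load_version() -> str:
    try:
        from pipelines import __version__  # may be absent in a dev checkout

        return __version__
    except Exception:  # was a bare "except:", which would also swallow KeyboardInterrupt
        return "0.0.0"

print(load_version())
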
18 changes: 8 additions & 10 deletions pipelines/ui/webapp_docprompt_gradio.py
@@ -14,19 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import os
 import json
-import argparse
-import base64
-from io import BytesIO
-from PIL import Image
-import traceback
+import argparse
+from io import BytesIO
 
-import requests
-import numpy as np
-import gradio as gr
-import fitz
 import cv2
+import fitz
+import gradio as gr
+import numpy as np
+import requests
+from PIL import Image
 
 fitz_tools = fitz.Tools()
 
@@ -171,7 +169,7 @@ def read_content(file_path: str) -> str:
     padding-bottom: 2px !important;
     padding-left: 8px !important;
     padding-right: 8px !important;
-    margin-top: 10px;
+    margin-top: 10px;
 }
 .gradio-container .gr-button-primary {
     background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
21 changes: 7 additions & 14 deletions pipelines/ui/webapp_faq.py
@@ -13,18 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import sys
-import logging
-import pandas as pd
 from json import JSONDecodeError
-from pathlib import Path
+
+import pandas as pd
 import streamlit as st
-from annotated_text import annotation
-from markdown import markdown
 
 sys.path.append("ui")
-from utils import pipelines_is_ready, semantic_search, send_feedback, upload_doc, pipelines_version, get_backlink
+from utils import pipelines_is_ready, semantic_search, upload_doc
 
 # Adjust to a question that you would like users to see in the search bar when they load the UI:
 DEFAULT_QUESTION_AT_STARTUP = os.getenv("DEFAULT_QUESTION_AT_STARTUP", "如何办理企业养老保险?")
@@ -58,7 +56,7 @@ def upload():
     for data_file in data_files:
         # Upload file
         if data_file and data_file.name not in st.session_state.upload_files["uploaded_files"]:
-            raw_json = upload_doc(data_file)
+            upload_doc(data_file)
             st.session_state.upload_files["uploaded_files"].append(data_file.name)
     # Save the uploaded files
     st.session_state.upload_files["uploaded_files"] = list(set(st.session_state.upload_files["uploaded_files"]))
@@ -115,16 +113,11 @@ def reset_results(*args):
     for data_file in st.session_state.upload_files["uploaded_files"]:
         st.sidebar.write(str(data_file) + " &nbsp;&nbsp; ✅ ")
 
-    hs_version = ""
-    try:
-        hs_version = f" <small>(v{pipelines_version()})</small>"
-    except Exception:
-        pass
     # Load csv into pandas dataframe
     try:
         df = pd.read_csv(EVAL_LABELS, sep=";")
     except Exception:
-        st.error(f"The eval file was not found.")
+        st.error("The eval file was not found.")
         sys.exit(f"The eval file was not found under `{EVAL_LABELS}`.")
 
     # Search bar
@@ -181,7 +174,7 @@ def reset_results(*args):
             st.session_state.results, st.session_state.raw_json = semantic_search(
                 question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever
             )
-        except JSONDecodeError as je:
+        except JSONDecodeError:
             st.error("👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
             return
         except Exception as e:
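The Streamlit-side edits above bundle three common flake8 findings: an unused local binding (raw_json, F841), an unused exception alias (as je, also F841), and an f-string with no placeholders (F541). A self-contained miniature of all three fixes — the helper here is a stand-in, not the app's real upload function:

import json
from json import JSONDecodeError


def upload_doc(payload: str) -> dict:
    # Stand-in for the app's upload helper.
    return {"raw": payload}


def run(payload: str) -> None:
    upload_doc(payload)  # was: raw_json = upload_doc(payload) -- raw_json never read (F841)
    try:
        json.loads(payload)
    except JSONDecodeError:  # was: except JSONDecodeError as je -- `je` never used
        print("The eval file was not found.")  # was an f-string with no placeholders (F541)


run("{not json}")
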
18 changes: 7 additions & 11 deletions pipelines/ui/webapp_question_answering.py
@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import sys
-import logging
-import pandas as pd
 from json import JSONDecodeError
 from pathlib import Path
+
+import pandas as pd
 import streamlit as st
 from annotated_text import annotation
 from markdown import markdown
-from ui.utils import pipelines_is_ready, query, send_feedback, upload_doc, pipelines_version, get_backlink
+from ui.utils import get_backlink, pipelines_is_ready, query, upload_doc
 
 # Adjust to a question that you would like users to see in the search bar when they load the UI:
 DEFAULT_QUESTION_AT_STARTUP = os.getenv("DEFAULT_QUESTION_AT_STARTUP", "中国的首都在哪里?")
@@ -54,7 +55,7 @@ def upload():
     for data_file in data_files:
         # Upload file
         if data_file and data_file.name not in st.session_state.upload_files["uploaded_files"]:
-            raw_json = upload_doc(data_file)
+            upload_doc(data_file)
             st.session_state.upload_files["uploaded_files"].append(data_file.name)
     # Save the uploaded files
     st.session_state.upload_files["uploaded_files"] = list(set(st.session_state.upload_files["uploaded_files"]))
@@ -109,7 +110,7 @@ def reset_results(*args):
     try:
         df = pd.read_csv(EVAL_LABELS, sep=";")
     except Exception:
-        st.error(f"The eval file was not found.")
+        st.error("The eval file was not found.")
         sys.exit(f"The eval file was not found under `{EVAL_LABELS}`.")
 
     # File upload block
@@ -122,11 +123,6 @@ def reset_results(*args):
     st.sidebar.button("文件上传", on_click=upload)
     for data_file in st.session_state.upload_files["uploaded_files"]:
         st.sidebar.write(str(data_file) + " &nbsp;&nbsp; ✅ ")
-    hs_version = ""
-    try:
-        hs_version = f" <small>(v{pipelines_version()})</small>"
-    except Exception:
-        pass
 
     # Search bar
     question = st.text_input(
@@ -185,7 +181,7 @@ def reset_results(*args):
             st.session_state.results, st.session_state.raw_json = query(
                 question, top_k_reader=top_k_reader, top_k_ranker=top_k_ranker, top_k_retriever=top_k_retriever
             )
-        except JSONDecodeError as je:
+        except JSONDecodeError:
             st.error("👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
             return
         except Exception as e: