Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Adds type-checking annotations auto-generated via Instagram's MonkeyType package #108

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# monkeytype
monkeytype.sqlite3

nvenv/*


Expand Down
4 changes: 2 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ flake8 = "==3.7.9"
hypothesis = "==5.3.1"
pytest = "==5.3.4"
## type-checking
pyre-check = "==0.0.41"
pyre-check = "*"
## like the Unix `make` but better
invoke = "==1.4.1"

monkeytype = "*"

[packages]
# REST API
Expand Down
61 changes: 41 additions & 20 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 12 additions & 10 deletions QA.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from database_wrapper import NimbusMySQLAlchemy
from pandas import read_csv

from functools import partial
from typing import Dict, List
Extracted_Vars = Dict[str, Any]
DB_Data = Dict[str, Any]
DB_Query = Callable[[Extracted_Vars], DB_Data]
Expand All @@ -33,7 +35,7 @@ class QA:
A class for wrapping functions used to answer a question.
"""

def __init__(self, q_format, db_query, format_answer):
def __init__(self, q_format: str, db_query: partial, format_answer: partial) -> None:
"""
Args:
q_format (str): Question format string
Expand All @@ -49,13 +51,13 @@ def __init__(self, q_format, db_query, format_answer):
self.db_query = db_query
self.format_answer = format_answer

def _get_data_from_db(self, extracted_vars):
def _get_data_from_db(self, extracted_vars: Dict[str, str]) -> str:
return self.db_query(extracted_vars)

def _format_answer(self, extracted_vars, db_data):
def _format_answer(self, extracted_vars: Dict[str, str], db_data: str) -> str:
return self.format_answer(extracted_vars, db_data)

def answer(self, extracted_vars):
def answer(self, extracted_vars: Dict[str, str]) -> str:
db_data = self._get_data_from_db(extracted_vars)
return self._format_answer(extracted_vars, db_data)

Expand All @@ -66,7 +68,7 @@ def __hash__(self):
return hash(self.q_format)


def create_qa_mapping(qa_list):
def create_qa_mapping(qa_list: List[QA]) -> Dict[str, QA]:
"""
Creates a dictionary whose values are QA objects and keys are the question
formats of those QA objects.
Expand Down Expand Up @@ -146,18 +148,18 @@ def create_qa_mapping(qa_list):
# return functools.partial(_single_var_string_sub, a_format)


def _string_sub(a_format, extracted_info, db_data):
def _string_sub(a_format: str, extracted_info: Dict[str, str], db_data: str) -> str:
if db_data is None:
return None
else:
return a_format.format(ex=extracted_info['normalized entity'], db=db_data)


def string_sub(a_format):
def string_sub(a_format: str) -> partial:
return functools.partial(_string_sub, a_format)


def _get_property(prop, extracted_info):
def _get_property(prop: str, extracted_info: Dict[str, str]) -> str:
ent_string = extracted_info["normalized entity"]
ent = tag_lookup[extracted_info['tag']]
try:
Expand All @@ -168,7 +170,7 @@ def _get_property(prop, extracted_info):
return value


def get_property(prop):
def get_property(prop: str) -> partial:
return functools.partial(_get_property, prop)


Expand All @@ -186,7 +188,7 @@ def yes_no(a_format, pred=None):
return functools.partial(_yes_no, a_format, pred)


def generate_fact_QA(csv):
def generate_fact_QA(csv: str) -> List[QA]:
df = read_csv(csv)
text_in_brackets = r'\[[^\[\]]*\]'
qa_objs = []
Expand Down
6 changes: 3 additions & 3 deletions database_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def __safe_create(SQLAlchemy_object):
__safe_create(self.Locations)
__safe_create(self.QuestionAnswerPair)

def _create_database_session(self):
def _create_database_session(self) -> None:
Session = sessionmaker(bind=self.engine)
self.session = Session()
print("initialized database session")
Expand All @@ -435,13 +435,13 @@ def return_qa_pair_csv(self):
def partial_fuzzy_match(self, tag_value, identifier):
return fuzz.partial_ratio(tag_value, identifier)

def full_fuzzy_match(self, tag_value, identifier):
def full_fuzzy_match(self, tag_value: str, identifier: str) -> int:
return fuzz.ratio(tag_value, identifier)

def get_property_from_entity(
self, prop: str, entity: UNION_ENTITIES, identifier: str,
tag_column_map: dict = default_tag_column_dict
):
) -> str:
"""
This function implements the abstractmethod to get a column of values
from a NimbusDatabase entity.
Expand Down
17 changes: 10 additions & 7 deletions nimbus_nlp/NIMBUS_NLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
# Temporary import for the classifier
from nimbus_nlp.question_classifier import QuestionClassifier

from google.cloud.automl_v1.types import PredictResponse
from monkeytype.encoding import DUMMY_NAME
from typing import Dict
class NIMBUS_NLP:

@staticmethod
def predict_question(input_question):
def predict_question(input_question: str) -> Dict[str, str]:
'''
Runs through variable extraction and the question classifier to
predict the intended question.
Expand Down Expand Up @@ -55,7 +58,7 @@ def predict_question(input_question):

class Variable_Extraction:

def __init__(self, config_file: str = "config.json"):
def __init__(self, config_file: str = "config.json") -> None:

with open(config_file) as json_data_file:
config = json.load(json_data_file)
Expand All @@ -70,7 +73,7 @@ def __init__(self, config_file: str = "config.json"):
# TODO: consider does this even do anything useful?
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credential_path

def inline_text_payload(self, sent):
def inline_text_payload(self, sent: str) -> Dict[str, DUMMY_NAME]:
'''
Converts the input sentence into GCP's callable format

Expand All @@ -82,7 +85,7 @@ def inline_text_payload(self, sent):

return {'text_snippet': {'content': sent, 'mime_type': 'text/plain'} }

def get_prediction(self, sent):
def get_prediction(self, sent: str) -> PredictResponse:
'''
Obtains the prediction from the input sentence and returns the
normalized sentence
Expand All @@ -109,7 +112,7 @@ def get_prediction(self, sent):
# Return the output of the API call
return request

def extract_variables(self, sent):
def extract_variables(self, sent: str) -> Dict[str, str]:
'''
Takes the prediction and replaces the entity with its corresponding tag

Expand Down Expand Up @@ -146,7 +149,7 @@ def extract_variables(self, sent):
}

@staticmethod
def excess_word_removal(entity, tag):
def excess_word_removal(entity: str, tag: str) -> str:
'''
Checks the tag and determines which excess word removal function to use

Expand All @@ -163,7 +166,7 @@ def excess_word_removal(entity, tag):
return entity

@staticmethod
def strip_titles(entity):
def strip_titles(entity: str) -> str:
'''
Strips titles from input entities

Expand Down
20 changes: 11 additions & 9 deletions nimbus_nlp/question_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# TODO: move the functionality in this module into class(es), so that it can be more easily used as a dependency


from spacy.tokens.token import Token
from typing import Dict, List, Tuple
class QuestionClassifier:

def __init__(self):
def __init__(self) -> None:
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
Expand All @@ -38,12 +40,12 @@ def train_model(self):
save_model(self.classifier, "nlp-model")


def load_latest_classifier(self):
def load_latest_classifier(self) -> None:
self.classifier = load_latest_model()
with open(PROJECT_DIR+ '/models/features/overall_features.json', 'r') as fp:
self.overall_features = json.load(fp)

def get_question_features(self, question):
def get_question_features(self, question: str) -> Dict[str, int]:
# print("using new algorithm")
"""
Method to extract features from each individual question.
Expand Down Expand Up @@ -121,18 +123,18 @@ def get_question_features_old_algorithm(self, question):
# Note: this method of extracting the main verb is not perfect, but
# for single sentence questions that should have no ambiguity about the main verb,
# it should be sufficient.
def extract_main_verb(self, question):
def extract_main_verb(self, question: str) -> Token:
doc = self.nlp(question)
sents = list(doc.sents)
if len(sents) == 0:
raise ValueError("Empty question")

return sents[0].root

def get_lemmas(self, words):
def get_lemmas(self, words: List[str]) -> List[str]:
return [self.nlp(word)[0].lemma_ for word in words]

def is_wh_word(self, pos):
def is_wh_word(self, pos: str) -> bool:
return pos in self.WH_WORDS

def build_question_classifier(self):
Expand Down Expand Up @@ -174,7 +176,7 @@ def build_question_classifier(self):

return new_classifier

def filterWHTags(self, question):
def filterWHTags(self, question: str) -> List[Tuple[str, str]]:
# ADD ALL VARIABLES TO THE FEATURE DICT WITH A WEIGHT OF 90
matches = re.findall(r'(\[(.*?)\])', question)
for match in matches:
Expand All @@ -193,7 +195,7 @@ def filterWHTags(self, question):
tag for tag in question_tags if self.is_wh_word(tag[1])]
return question_tags

def validate_WH(self, test_question, predicted_question):
def validate_WH(self, test_question: str, predicted_question: str) -> bool:
"""
Assumes that only 1 WH word exists
Returns True if the WH word in the test question equals the
Expand Down Expand Up @@ -221,7 +223,7 @@ def validate_WH(self, test_question, predicted_question):
i += 1
return wh_match

def classify_question(self, test_question):
def classify_question(self, test_question: str) -> str:
"""
Match a user query with a question in the database based on the classifier we trained and overall features we calculated.
Return relevant question.
Expand Down
Loading