Skip to content

Commit

Permalink
NL: reject inappropriate words and improve schools (#2952)
Browse files Browse the repository at this point in the history
* The initial bad-word list is located
[here](https://storage.mtls.cloud.google.com/datcom-website-config/nl_bad_words.txt).

* Also, prevent school types from being treated as stop-words. This
change had two ripple effects:
1. It requires special handling in the fallback logic (so as not to report
"schools in sunnyvale" => "sunnyvale")
2. It must not regress the demo query [how big are public
schools in sunnyvale]

Note: making the schools change also uncovered
#2953. This whole
stop-word removal business needs streamlining! Maybe post-fishfood, and
as part of fixing #2953.
 
Screenshot


![image](https://github.com/datacommonsorg/website/assets/4375037/55a824e5-2ed3-4358-b928-e421cb9dd99f)
  • Loading branch information
pradh authored Jul 18, 2023
1 parent 25566e8 commit 09e5654
Show file tree
Hide file tree
Showing 18 changed files with 918 additions and 74 deletions.
19 changes: 6 additions & 13 deletions nl_server/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

BUCKET = 'datcom-nl-models'
TEMP_DIR = '/tmp/'

import os
from pathlib import Path
Expand All @@ -23,21 +22,15 @@
from google.cloud import storage
from sentence_transformers import SentenceTransformer

from shared.lib import gcs as gcs_lib

# Downloads the `embeddings_file` from GCS to TEMP_DIR
# and returns its path.
def download_embeddings(embeddings_file: str) -> str:
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name=BUCKET)
blob = bucket.get_blob(embeddings_file)
# Download
local_embeddings_path = local_path(embeddings_file)
blob.download_to_filename(local_embeddings_path)
return local_embeddings_path

def download_embeddings(embeddings_filename: str) -> str:
return gcs_lib.download_file(bucket=BUCKET, filename=embeddings_filename)


def local_path(embeddings_file: str) -> str:
return os.path.join(TEMP_DIR, embeddings_file)
return os.path.join(gcs_lib.TEMP_DIR, embeddings_file)


def download_model_from_gcs(gcs_bucket: Any, local_dir: str,
Expand Down Expand Up @@ -78,7 +71,7 @@ def download_model_from_gcs(gcs_bucket: Any, local_dir: str,
def download_model_folder(model_folder: str) -> str:
sc = storage.Client()
bucket = sc.bucket(bucket_name=BUCKET)
directory = TEMP_DIR
directory = gcs_lib.TEMP_DIR

# Only download if needed.
model_path = os.path.join(directory, model_folder)
Expand Down
2 changes: 2 additions & 0 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import server.lib.config as libconfig
from server.lib.disaster_dashboard import get_disaster_dashboard_data
import server.lib.i18n as i18n
from server.lib.nl.common import bad_words
import server.lib.util as libutil
import server.services.bigtable as bt
from server.services.discovery import configure_endpoints_from_ingress
Expand Down Expand Up @@ -370,6 +371,7 @@ def create_app():
secret_response = secret_client.access_secret_version(name=secret_name)
app.config['PALM_API_KEY'] = secret_response.payload.data.decode(
'UTF-8')
app.config['NL_BAD_WORDS'] = bad_words.load_bad_words()

# Get and save the blocklisted svgs.
blocklist_svg = []
Expand Down
17 changes: 16 additions & 1 deletion server/integration_tests/nlnext_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def run_sequence(self,
detector='hybrid',
check_place_detection=False,
expected_detectors=[],
place_detector='ner'):
place_detector='ner',
failure=''):
if detector == 'heuristic':
detection_method = 'Heuristic Based'
elif detector == 'llm':
Expand Down Expand Up @@ -76,6 +77,11 @@ def run_sequence(self,
}
infile.write(json.dumps(dbg_to_write, indent=2))
else:
if failure:
self.assertTrue(failure in resp["failure"]), resp["failure"]
self.assertTrue(not resp["config"])
return

if not expected_detectors:
self.assertTrue(dbg.get('detection_type').startswith(detection_method)), \
'Query {q} failed!'
Expand Down Expand Up @@ -135,13 +141,16 @@ def test_demo_cities_feb2023(self):
self.run_sequence(
'demo2_cities_feb2023',
[
# This should list public school entities.
'How big are the public schools in Sunnyvale',
'What is the prevalence of asthma there',
'What is the commute pattern there',
'How does that compare with San Bruno',
# Proxy for parks in magiceye
'Which cities in the SF Bay Area have the highest larceny',
'What countries in Africa had the greatest increase in life expectancy',
# This should list stats about the middle school students.
'How many middle schools are there in Sunnyvale',
])

def test_demo_fallback(self):
Expand Down Expand Up @@ -254,3 +263,9 @@ def test_medium_index(self):
self.run_sequence('medium_index',
['cars per family in california counties'],
idx='medium')

def test_inappropriate_query(self):
self.run_sequence('inappropriate_query',
['how many wise asses live in sunnyvale?'],
idx='medium',
failure='inappropriate words')
Loading

0 comments on commit 09e5654

Please sign in to comment.