Commit: Support a schema update mode in stats runner (#344)
* Add schema update mode to stats runner

* Test new behaviors of Db class

* Add tests for new Runner behavior

* Reword comment

* Use human-readable files for unit test database setup

* Make update more explicit with naming and comments
hqpho authored Oct 24, 2024
1 parent fab89e3 commit 02db128
Showing 38 changed files with 800 additions and 100 deletions.
10 changes: 5 additions & 5 deletions run_test.sh
@@ -21,7 +21,7 @@ function run_lint_fix {
echo -e "#### Fixing Python code"
python3 -m venv .env
source .env/bin/activate
pip3 install yapf==0.33.0 -q
pip3 install yapf==0.40.2 -q
if ! command -v isort &> /dev/null
then
pip3 install isort -q
@@ -35,12 +35,12 @@ function run_lint_fix {
function run_lint_test {
python3 -m venv .env
source .env/bin/activate
pip3 install yapf==0.33.0 -q
pip3 install yapf==0.40.2 -q
if ! command -v isort &> /dev/null
then
pip3 install isort -q
fi

echo -e "#### Checking Python style"
if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=.env/*; then
echo "Fix Python lint errors by running ./run_test.sh -f"
@@ -74,9 +74,9 @@ function py_test {

python3 -m venv .env
source .env/bin/activate

cd simple
pip3 install -r requirements.txt
pip3 install -r requirements.txt -q

echo -e "#### Running stats tests"
python3 -m pytest tests/stats/ -s
2 changes: 1 addition & 1 deletion simple/run_stats.sh
@@ -9,7 +9,7 @@ Options:
-c <file> Json config file for stats importer
-i <dir> Input directory to process
-o <dir> Output folder for stats importer. Default: $OUTPUT_DIR
-m <customdc|maindc> Mode of operation for simple importer. Default: $MODE
-m <customdc|schemaupdate|maindc> Mode of operation for simple importer. Default: $MODE
-k <api-key> DataCommons API Key
-j <jar> DC Import java jar file.
Download latest from https://github.com/datacommonsorg/import/releases/
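To apply pending schema changes to an existing database without re-importing any data, the new mode can be invoked as in this sketch (paths are hypothetical; run from the simple/ directory):

# Schema-update-only run: creates/updates tables, skips all data imports.
./run_stats.sh -m schemaupdate -o /path/to/output

# A regular Custom DC import, for comparison:
./run_stats.sh -m customdc -i /path/to/input -o /path/to/output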
63 changes: 45 additions & 18 deletions simple/stats/db.py
@@ -118,11 +118,14 @@

_SELECT_ENTITY_NAMES = "select subject_id, object_value from triples where subject_id in (%s) and predicate = 'name' and object_value <> ''"

_INIT_STATEMENTS = [
_INIT_TABLE_STATEMENTS = [
_CREATE_TRIPLES_TABLE,
_CREATE_OBSERVATIONS_TABLE,
_CREATE_KEY_VALUE_STORE_TABLE,
_CREATE_IMPORTS_TABLE,
]

_CLEAR_TABLE_FOR_IMPORT_STATEMENTS = [
# Clearing tables for now (but not the imports table, since we want to maintain its history).
_DELETE_TRIPLES_STATEMENT,
_DELETE_OBSERVATIONS_STATEMENT,
@@ -195,6 +198,9 @@ class Db:
The "DB" could be a traditional sql db or a file system with the output being files.
"""

def maybe_clear_before_import(self):
pass

def insert_triples(self, triples: list[Triple]):
pass

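Note that maybe_clear_before_import is deliberately a no-op on the base Db class: the runner can call it unconditionally, and only the SQL-backed implementation below overrides it to clear data tables (and, for Cloud SQL, indexes) before a fresh import.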
@@ -285,8 +291,13 @@ class SqlDb(Db):

def __init__(self, config: dict) -> None:
self.engine = create_db_engine(config)
self.engine.init_or_update_tables()
self.num_observations = 0
self.variables: set[str] = set()
self.indexes_cleared = False

def maybe_clear_before_import(self):
self.engine.clear_tables_and_indexes()

def insert_triples(self, triples: list[Triple]):
logging.info("Writing %s triples to [%s]", len(triples), self.engine)
@@ -345,6 +356,12 @@ def from_triple_tuple(tuple: tuple) -> Triple:

class DbEngine:

def init_or_update_tables(self):
pass

def clear_tables_and_indexes(self):
pass

def execute(self, sql: str, parameters=None):
pass

@@ -379,14 +396,8 @@ def __init__(self, db_params: dict) -> None:
logging.info("Connected to SQLite: %s", self.local_db_file_path)

self.cursor = self.connection.cursor()
# Drop indexes first so inserts are faster.
self._drop_indexes()
for statement in _INIT_STATEMENTS:
self.cursor.execute(statement)
# Apply schema updates.
self._schema_updates()

def _schema_updates(self) -> None:
def _maybe_update_schema(self) -> None:
"""
Add any sqlite schema updates here.
Ensure that each schema update checks whether it is necessary before applying it.
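The body of this method is collapsed in the diff. Purely as an illustration of the guarded pattern the docstring asks for, a SQLite update might look like the sketch below (the helper name is hypothetical, and the observations.properties column is borrowed from the Cloud SQL existence check later in this file; the exact statements are assumptions, not the commit's code):

import sqlite3

def maybe_add_properties_column(cursor: sqlite3.Cursor) -> None:
  """Illustrative only: add observations.properties if it is missing."""
  cursor.execute("PRAGMA table_info(observations);")
  columns = [row[1] for row in cursor.fetchall()]  # row[1] is the column name
  # Check before applying: only add the column if it is absent.
  if "properties" not in columns:
    cursor.execute("ALTER TABLE observations ADD COLUMN properties TEXT;")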
Expand Down Expand Up @@ -415,6 +426,15 @@ def _create_indexes(self) -> None:
def __str__(self) -> str:
return f"{TYPE_SQLITE}: {self.db_file_path}"

def init_or_update_tables(self):
for statement in _INIT_TABLE_STATEMENTS:
self.cursor.execute(statement)
self._maybe_update_schema()

def clear_tables_and_indexes(self):
for statement in _CLEAR_TABLE_FOR_IMPORT_STATEMENTS:
self.cursor.execute(statement)

def execute(self, sql: str, parameters=None):
if not parameters:
self.cursor.execute(sql)
@@ -461,8 +481,8 @@ def commit_and_close(self):
_CLOUD_MY_SQL_PARAMS = [CLOUD_MY_SQL_INSTANCE] + _CLOUD_MY_SQL_DB_CONNECT_PARAMS

_CLOUD_MYSQL_PROPERTIES_COLUMN_EXISTS_STATEMENT = """
SELECT 1
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'observations' AND COLUMN_NAME = 'properties';
"""

@@ -486,14 +506,8 @@ def __init__(self, db_params: dict[str, str]) -> None:
db_params[CLOUD_MY_SQL_INSTANCE], db_params[CLOUD_MY_SQL_DB])
self.description = f"{TYPE_CLOUD_SQL}: {db_params[CLOUD_MY_SQL_INSTANCE]} ({db_params[CLOUD_MY_SQL_DB]})"
self.cursor: Cursor = self.connection.cursor()
# Drop indexes first so inserts are faster.
self._drop_indexes()
for statement in _INIT_STATEMENTS:
self.cursor.execute(statement)
# Apply schema updates.
self._schema_updates()

def _schema_updates(self) -> None:
def _maybe_update_schema(self) -> None:
"""
Add any cloud sql schema updates here.
Ensure that each schema update checks whether it is necessary before applying it.
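As above, the body is collapsed here. A hedged sketch of how the _CLOUD_MYSQL_PROPERTIES_COLUMN_EXISTS_STATEMENT defined earlier might gate such an update (the helper name and the ALTER statement are hypothetical):

def maybe_add_properties_column(cursor) -> None:
  """Illustrative only: the SELECT returns a row iff the column exists."""
  cursor.execute(_CLOUD_MYSQL_PROPERTIES_COLUMN_EXISTS_STATEMENT)
  if cursor.fetchone() is None:
    cursor.execute("ALTER TABLE observations ADD COLUMN properties TEXT;")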
@@ -555,6 +569,16 @@ def _db_exists(cursor) -> bool:
def __str__(self) -> str:
return self.description

def init_or_update_tables(self):
for statement in _INIT_TABLE_STATEMENTS:
self.cursor.execute(statement)
self._maybe_update_schema()

def clear_tables_and_indexes(self):
for statement in _CLEAR_TABLE_FOR_IMPORT_STATEMENTS:
self.cursor.execute(statement)
self._drop_indexes()

def execute(self, sql: str, parameters=None):
self.cursor.execute(_pymysql(sql), parameters)

@@ -599,7 +623,10 @@ def create_db_engine(config: dict) -> DbEngine:
assert False


def create_db(config: dict) -> Db:
def create_and_update_db(config: dict) -> Db:
""" Creates and initializes a Db, performing any setup and updates
(e.g. table creation, table schema changes) that are needed.
"""
db_type = config[FIELD_DB_TYPE]
if db_type and db_type == TYPE_MAIN_DC:
return MainDcDb(config)
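Taken together, the lifecycle this file now supports is roughly the following sketch (SQLite target assumed; the database path is hypothetical, and the ImportStatus import path is assumed):

from stats.db import create_and_update_db
from stats.db import create_sqlite_config
from stats.db import ImportStatus  # import path assumed

config = create_sqlite_config("/tmp/datacommons.db")  # hypothetical path
db = create_and_update_db(config)  # creates tables, applies schema updates

# A schema-update run stops here. A full import continues:
db.maybe_clear_before_import()  # SQL only: clear data tables (and indexes)
# ... db.insert_triples(...) and other writes ...
db.insert_import_info(status=ImportStatus.SUCCESS)
db.commit_and_close()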
95 changes: 56 additions & 39 deletions simple/stats/runner.py
@@ -26,7 +26,7 @@
from stats.data import ParentSVG2ChildSpecializedNames
from stats.data import Triple
from stats.data import VerticalSpec
from stats.db import create_db
from stats.db import create_and_update_db
from stats.db import create_main_dc_config
from stats.db import create_sqlite_config
from stats.db import get_cloud_sql_config_from_env
@@ -48,6 +48,7 @@

class RunMode(StrEnum):
CUSTOM_DC = "customdc"
SCHEMA_UPDATE = "schemaupdate"
MAIN_DC = "maindc"


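Since RunMode is a StrEnum, the raw string passed via the shell script's -m flag maps directly onto a member; an illustrative check, assuming standard StrEnum semantics:

assert RunMode("schemaupdate") is RunMode.SCHEMA_UPDATE
assert RunMode.SCHEMA_UPDATE == "schemaupdate"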
@@ -113,59 +114,75 @@ def __init__(self,
self.reporter = ImportReporter(report_fh=self.process_dir_fh.make_file(
constants.REPORT_JSON_FILE_NAME))

# DB setup.
def _get_db_config() -> dict:
if self.mode == RunMode.MAIN_DC:
logging.info("Using Main DC config.")
return create_main_dc_config(self.output_dir_fh.path)
# Attempt to get from env (cloud sql, then sqlite),
# then config file, then default.
db_cfg = get_cloud_sql_config_from_env()
if db_cfg:
logging.info("Using Cloud SQL settings from env.")
return db_cfg
db_cfg = get_sqlite_config_from_env()
if db_cfg:
logging.info("Using SQLite settings from env.")
return db_cfg
logging.info("Using default DB settings.")
return create_sqlite_config(
self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)

self.db = create_db(_get_db_config())
self.nodes = Nodes(self.config)
self.db = None

def run(self):
try:
# Run all data imports.
self._run_imports()
if (self.db is None):
self.db = create_and_update_db(self._get_db_config())

# Generate triples.
triples = self.nodes.triples()
# Write triples to DB.
self.db.insert_triples(triples)
if self.mode == RunMode.SCHEMA_UPDATE:
logging.info("Skipping imports because run mode is schema update.")

# Generate SVG hierarchy.
self._generate_svg_hierarchy()
elif self.mode == RunMode.CUSTOM_DC or self.mode == RunMode.MAIN_DC:
self._run_imports_and_do_post_import_work()

# Generate SVG cache.
self._generate_svg_cache()

# Generate NL sentences for creating embeddings.
self._generate_nl_sentences()

# Write import info to DB.
self.db.insert_import_info(status=ImportStatus.SUCCESS)
else:
raise ValueError(f"Unsupported mode: {self.mode}")

# Commit and close DB.
self.db.commit_and_close()

# Report done.
self.reporter.report_done()
except Exception as e:
logging.exception("Error running import")
logging.exception("Error updating stats")
self.reporter.report_failure(error=str(e))

def _get_db_config(self) -> dict:
if self.mode == RunMode.MAIN_DC:
logging.info("Using Main DC config.")
return create_main_dc_config(self.output_dir_fh.path)
# Attempt to get from env (cloud sql, then sqlite),
# then config file, then default.
db_cfg = get_cloud_sql_config_from_env()
if db_cfg:
logging.info("Using Cloud SQL settings from env.")
return db_cfg
db_cfg = get_sqlite_config_from_env()
if db_cfg:
logging.info("Using SQLite settings from env.")
return db_cfg
logging.info("Using default DB settings.")
return create_sqlite_config(
self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)

def _run_imports_and_do_post_import_work(self):
# (SQL only) Drop data in existing tables (except import metadata).
# Also drop indexes for faster writes.
self.db.maybe_clear_before_import()

# Import data from all input files.
self._run_all_data_imports()

# Generate triples.
triples = self.nodes.triples()
# Write triples to DB.
self.db.insert_triples(triples)

# Generate SVG hierarchy.
self._generate_svg_hierarchy()

# Generate SVG cache.
self._generate_svg_cache()

# Generate NL sentences for creating embeddings.
self._generate_nl_sentences()

# Write import info to DB.
self.db.insert_import_info(status=ImportStatus.SUCCESS)

def _generate_nl_sentences(self):
triples: list[Triple] = []
# Get topic triples if generating topics else get SV triples.
@@ -247,7 +264,7 @@ def _maybe_set_special_fh(self, fh: FileHandler) -> bool:
return True
return False

def _run_imports(self):
def _run_all_data_imports(self):
input_fhs: list[FileHandler] = []
input_mcf_fhs: list[FileHandler] = []
for input_handler in self.input_handlers:
(Diffs for the remaining changed files are not shown.)
