diff --git a/.codecov.yml b/.codecov.yml index 3687acc9e..05789d61f 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -2,7 +2,7 @@ coverage: ignore: - */tests/* - qcfractal/dashboard/* # early state - - qcfractal/migrations/* # difficult to test + - qcfractal/alembic/* # difficult to test - qcfractal/_version.py - setup.py status: diff --git a/.lgtm.yml b/.lgtm.yml index 92973b8a4..4bf3585fa 100644 --- a/.lgtm.yml +++ b/.lgtm.yml @@ -9,6 +9,7 @@ path_classifiers: - versioneer.py # Set Versioneer.py to an external "library" (3rd party code) - devtools/* - qcfractal/dashboard/* # Very early state, some conditions forcing LGTM issues + - qcfractal/alembic/* # One-shot, from templates generated: - qcfractal/_version.py queries: diff --git a/devtools/conda-envs/adapters.yaml b/devtools/conda-envs/adapters.yaml index fef4d9879..6eb60c566 100644 --- a/devtools/conda-envs/adapters.yaml +++ b/devtools/conda-envs/adapters.yaml @@ -5,6 +5,7 @@ channels: dependencies: - python - numpy + - msgpack-python>=0.6.1 - pandas - tornado - requests @@ -17,6 +18,7 @@ dependencies: - psycopg2>=2.7 - postgresql - alembic + - tqdm # Test depends - pytest @@ -33,8 +35,8 @@ dependencies: - parsl>=0.8.0 # QCArchive includes - - qcengine>=0.8.2 - - qcelemental>=0.5.0 + - qcengine>=0.9.0 + - qcelemental>=0.6.0 # Pip includes - pip: diff --git a/devtools/conda-envs/base.yaml b/devtools/conda-envs/base.yaml index 4afe9ca38..37a9d2623 100644 --- a/devtools/conda-envs/base.yaml +++ b/devtools/conda-envs/base.yaml @@ -5,6 +5,7 @@ channels: dependencies: - python - numpy + - msgpack-python>=0.6.1 - pandas - tornado - requests @@ -17,6 +18,7 @@ dependencies: - psycopg2>=2.7 - postgresql - alembic + - tqdm # Test depends - pytest @@ -24,5 +26,5 @@ dependencies: - codecov # QCArchive includes - - qcengine>=0.8.2 - - qcelemental>=0.5.0 + - qcengine>=0.9.0 + - qcelemental>=0.6.0 diff --git a/devtools/conda-envs/dev_head.yaml b/devtools/conda-envs/dev_head.yaml index cce8464ae..acf4aba04 100644 --- a/devtools/conda-envs/dev_head.yaml +++ b/devtools/conda-envs/dev_head.yaml @@ -5,6 +5,7 @@ channels: dependencies: - python - numpy + - msgpack-python>=0.6.1 - pandas - tornado - requests @@ -17,6 +18,7 @@ dependencies: - psycopg2>=2.7 - postgresql - alembic + - tqdm # Test depends - pytest diff --git a/devtools/conda-envs/generate_envs.py b/devtools/conda-envs/generate_envs.py index aeaab0c03..a723ff1b9 100644 --- a/devtools/conda-envs/generate_envs.py +++ b/devtools/conda-envs/generate_envs.py @@ -15,6 +15,7 @@ dependencies: - python - numpy + - msgpack-python>=0.6.1 - pandas - tornado - requests @@ -27,13 +28,14 @@ - psycopg2>=2.7 - postgresql - alembic + - tqdm # Test depends - pytest - pytest-cov - codecov """ -qca_ecosystem_template = ["qcengine>=0.8.2", "qcelemental>=0.5.0"] +qca_ecosystem_template = ["qcengine>=0.9.0", "qcelemental>=0.6.0"] pip_depends_template = [] diff --git a/devtools/conda-envs/openff.yaml b/devtools/conda-envs/openff.yaml index 79343b249..b23a4daa0 100644 --- a/devtools/conda-envs/openff.yaml +++ b/devtools/conda-envs/openff.yaml @@ -6,6 +6,7 @@ channels: dependencies: - python - numpy + - msgpack-python>=0.6.1 - pandas - tornado - requests @@ -18,6 +19,7 @@ dependencies: - psycopg2>=2.7 - postgresql - alembic + - tqdm # Test depends - pytest @@ -31,5 +33,5 @@ dependencies: - torsiondrive # QCArchive includes - - qcengine>=0.8.2 - - qcelemental>=0.5.0 + - qcengine>=0.9.0 + - qcelemental>=0.6.0 diff --git a/devtools/scripts/create_staging.py b/devtools/scripts/create_staging.py index 64a935c8b..7ca900e3f 100644 --- a/devtools/scripts/create_staging.py +++ b/devtools/scripts/create_staging.py @@ -5,7 +5,7 @@ """ from qcfractal.storage_sockets import storage_socket_factory -from qcfractal.storage_sockets.models import (BaseResultORM, ResultORM, CollectionORM, +from qcfractal.storage_sockets.sql_models import (BaseResultORM, ResultORM, CollectionORM, OptimizationProcedureORM, GridOptimizationProcedureORM, TorsionDriveProcedureORM, TaskQueueORM) from qcfractal.interface.models import (ResultRecord, OptimizationRecord, @@ -13,9 +13,10 @@ # production_uri = "postgresql+psycopg2://qcarchive:mypass@localhost:5432/test_qcarchivedb" production_uri = "postgresql+psycopg2://postgres:@localhost:11711/qcarchivedb" -staging_uri = "postgresql+psycopg2://qcarchive:mypass@localhost:5432/staging_qcarchivedb" +staging_uri = "postgresql+psycopg2://localhost:5432/staging_qcarchivedb" SAMPLE_SIZE = 0.0001 # 0.1 is 10% MAX_LIMIT = 10000 +VERBOSE = False def connect_to_DBs(staging_uri, production_uri, max_limit): @@ -34,7 +35,7 @@ def connect_to_DBs(staging_uri, production_uri, max_limit): def get_number_to_copy(total_size, sample_size): to_copy = int(total_size*sample_size) if to_copy: - return to_copy + return max(to_copy, 10) else: return 1 # avoid zero because zero means no limit in storage @@ -48,12 +49,15 @@ def copy_molecules(staging_storage, prod_storage, prod_ids): print('----Total # of Molecules to copy: ', len(prod_ids)) ret = prod_storage.get_molecules(id=prod_ids) - print('Get from prod:', ret) + if VERBOSE: + print('Get from prod:', ret) staging_ids = staging_storage.add_molecules(ret['data']) - print('Add to staging:', staging_ids) + if VERBOSE: + print('Add to staging:', staging_ids) map = {m1: m2 for m1, m2 in zip(prod_ids, staging_ids['data'])} - print('MAP: ', map) + if VERBOSE: + print('MAP: ', map) print('---- Done copying molecules\n\n') @@ -71,12 +75,14 @@ def copy_keywords(staging_storage, prod_storage, prod_ids): ret = prod_storage.get_keywords(id=prod_ids) - print('Get from prod:', ret) + if VERBOSE: + print('Get from prod:', ret) staging_ids = staging_storage.add_keywords(ret['data']) print('Add to staging:', staging_ids) map = {m1: m2 for m1, m2 in zip(prod_ids, staging_ids['data'])} - print('MAP: ', map) + if VERBOSE: + print('MAP: ', map) print('---- Done copying keywords\n\n') @@ -94,12 +100,15 @@ def copy_kv_store(staging_storage, prod_storage, prod_ids): ret = prod_storage.get_kvstore(id=prod_ids) - print('Get from prod:', ret) + if VERBOSE: + print('Get from prod:', ret) staging_ids = staging_storage.add_kvstore(ret['data'].values()) - print('Add to staging:', staging_ids) + if VERBOSE: + print('Add to staging:', staging_ids) map = {m1: m2 for m1, m2 in zip(prod_ids, staging_ids['data'])} - print('MAP: ', map) + if VERBOSE: + print('MAP: ', map) print('---- Done copying KV_store \n\n') @@ -113,7 +122,8 @@ def copy_users(staging_storage, prod_storage): print('-----Total # of Users in the DB is: ', len(prod_users)) sql_insered = staging_storage._copy_users(prod_users)['data'] - print('Inserted in SQL:', len(sql_insered)) + if VERBOSE: + print('Inserted in SQL:', len(sql_insered)) print('---- Done copying Users\n\n') @@ -130,7 +140,8 @@ def copy_managers(staging_storage, prod_storage, mang_list): sql_insered = staging_storage._copy_managers(prod_mangers)['data'] - print('Inserted in SQL:', len(sql_insered)) + if VERBOSE: + print('Inserted in SQL:', len(sql_insered)) print('---- Done copying Queue Manager\n\n') @@ -149,7 +160,8 @@ def copy_collections(staging_storage, production_storage, SAMPLE_SIZE=0): for col in prod_results: ret = staging_storage.add_collection(col)['data'] sql_insered += 1 - print('Inserted in SQL:', sql_insered) + if VERBOSE: + print('Inserted in SQL:', sql_insered) print('---- Done copying Collections\n\n') @@ -204,7 +216,8 @@ def copy_results(staging_storage, production_storage, SAMPLE_SIZE=0, results_ids results_py = [ResultRecord(**res) for res in prod_results] staging_ids = staging_storage.add_results(results_py)['data'] - print('Inserted in SQL:', len(staging_ids)) + if VERBOSE: + print('Inserted in SQL:', len(staging_ids)) print('---- Done copying Results\n\n') @@ -265,7 +278,8 @@ def copy_optimization_procedure(staging_storage, production_storage, SAMPLE_SIZE procedures_py = [OptimizationRecord(**proc) for proc in prod_proc] staging_ids = staging_storage.add_procedures(procedures_py)['data'] - print('Inserted in SQL:', len(staging_ids)) + if VERBOSE: + print('Inserted in SQL:', len(staging_ids)) print('---- Done copying Optimization procedures\n\n') @@ -325,7 +339,8 @@ def copy_torsiondrive_procedure(staging_storage, production_storage, SAMPLE_SIZE procedures_py = [TorsionDriveRecord(**proc) for proc in prod_proc] staging_ids = staging_storage.add_procedures(procedures_py)['data'] - print('Inserted in SQL:', len(staging_ids)) + if VERBOSE: + print('Inserted in SQL:', len(staging_ids)) print('---- Done copying Torsiondrive procedures\n\n') @@ -450,7 +465,8 @@ def copy_task_queue(staging_storage, production_storage, SAMPLE_SIZE=None): raise Exception('Result not found!', rec.base_result.id) staging_ids = staging_storage._copy_task_to_queue(prod_tasks)['data'] - print('Inserted in SQL:', len(staging_ids)) + if VERBOSE: + print('Inserted in SQL:', len(staging_ids)) print('---- Done copying Task Queue\n\n') @@ -488,6 +504,11 @@ def main(): print('Exit without creating the DB.') return + # Copy metadata + #with production_storage.session_scope() as session: + # alembic = session.execute("select * from alembic_version") + # version = alembic.first()[0] + # copy all users, small tables, no need for sampling copy_users(staging_storage, production_storage) @@ -514,5 +535,6 @@ def main(): copy_alembic(staging_storage, production_storage) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docs/qcfractal/source/setup_quickstart.rst b/docs/qcfractal/source/setup_quickstart.rst index 80f9abf8b..ea6277d8f 100644 --- a/docs/qcfractal/source/setup_quickstart.rst +++ b/docs/qcfractal/source/setup_quickstart.rst @@ -209,7 +209,7 @@ You may optionally provide a TLS certificate to enable host verification for the using the ``--tls-cert`` and ``--tls-key`` options. If a TLS certificate is not provided, communications with the server will still be encrypted, but host verification will be unavailable -(and :term:`Managers ` and clients will need to specify ``--verify False``). +(and :term:`Managers ` and clients will need to specify ``verify=False``). Next, add users for admin, the :term:`Manager`, and a user (you may choose whatever usernames you like):: diff --git a/qcfractal/__init__.py b/qcfractal/__init__.py index c881f12ee..1e269e935 100644 --- a/qcfractal/__init__.py +++ b/qcfractal/__init__.py @@ -8,6 +8,7 @@ from .storage_sockets import storage_socket_factory # Handle top level object imports +from .postgres_harness import PostgresHarness, TemporaryPostgres from .server import FractalServer from .snowflake import FractalSnowflake, FractalSnowflakeHandler from .queue import QueueManager diff --git a/qcfractal/alembic/versions/05ceea11b78a_base_records_msgpack_2.py b/qcfractal/alembic/versions/05ceea11b78a_base_records_msgpack_2.py new file mode 100644 index 000000000..231aee35e --- /dev/null +++ b/qcfractal/alembic/versions/05ceea11b78a_base_records_msgpack_2.py @@ -0,0 +1,34 @@ +"""Msgpack Base Results Phase 2 + +Revision ID: 05ceea11b78a +Revises: 8b0cd9accaf2 +Create Date: 2019-08-11 22:30:51.453746 + +""" +from alembic import op +import sqlalchemy as sa + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) +from migration_helpers import msgpack_migrations +from qcelemental.util import msgpackext_dumps, msgpackext_loads + +# revision identifiers, used by Alembic. +revision = '05ceea11b78a' +down_revision = '8b0cd9accaf2' +branch_labels = None +depends_on = None + +table_name = "base_result" +update_columns = {"extras"} + +nullable = {"extras"} + + +def upgrade(): + msgpack_migrations.json_to_msgpack_table_altercolumns(table_name, update_columns, nullable_true=nullable) + + +def downgrade(): + raise ValueError("Cannot downgrade json to msgpack conversions") diff --git a/qcfractal/alembic/versions/1134312ad4a3_results_msgpack_2.py b/qcfractal/alembic/versions/1134312ad4a3_results_msgpack_2.py new file mode 100644 index 000000000..c488f183f --- /dev/null +++ b/qcfractal/alembic/versions/1134312ad4a3_results_msgpack_2.py @@ -0,0 +1,33 @@ +"""Msgpack Results Phase 2 + +Revision ID: 1134312ad4a3 +Revises: 84c94a48e491 +Create Date: 2019-08-11 17:21:43.328492 + +""" +from alembic import op +import sqlalchemy as sa + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) +from migration_helpers import msgpack_migrations + +# revision identifiers, used by Alembic. +revision = '1134312ad4a3' +down_revision = '84c94a48e491' +branch_labels = None +depends_on = None + +table_name = "result" +update_columns = {"return_result"} + +nullable = {"return_result"} + + +def upgrade(): + msgpack_migrations.json_to_msgpack_table_altercolumns(table_name, update_columns, nullable_true=nullable) + + +def downgrade(): + raise ValueError("Cannot downgrade json to msgpack conversions") diff --git a/qcfractal/alembic/versions/84c94a48e491_results_msgpack_1.py b/qcfractal/alembic/versions/84c94a48e491_results_msgpack_1.py new file mode 100644 index 000000000..348d5ba8c --- /dev/null +++ b/qcfractal/alembic/versions/84c94a48e491_results_msgpack_1.py @@ -0,0 +1,50 @@ +"""Msgpack Results Phase 1 + +Revision ID: 84c94a48e491 +Revises: d56ac42b9a43 +Create Date: 2019-08-11 17:21:40.264688 + +""" +from alembic import op +import sqlalchemy as sa +import numpy as np + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) +from migration_helpers import msgpack_migrations +from qcelemental.util import msgpackext_dumps, msgpackext_loads + +# revision identifiers, used by Alembic. +revision = '84c94a48e491' +down_revision = 'd56ac42b9a43' +branch_labels = None +depends_on = None + +block_size = 100 +table_name = "result" + + +def transformer(old_data): + + arr = old_data["return_result"] + if arr is None: + pass + elif old_data["driver"] == "gradient": + arr = np.array(arr, dtype=float).reshape(-1, 3) + elif old_data["driver"] == "hessian": + arr = np.array(arr, dtype=float) + arr.shape = (-1, int(arr.shape[0]**0.5)) + + return {"return_result_": msgpackext_dumps(arr)} + + +update_columns = {"return_result"} + + +def upgrade(): + msgpack_migrations.json_to_msgpack_table(table_name, block_size, update_columns, transformer, read_columns={"driver": sa.String}) + + +def downgrade(): + msgpack_migrations.json_to_msgpack_table_dropcols(table_name, block_size, update_columns) diff --git a/qcfractal/alembic/versions/8b0cd9accaf2_base_records_msgpack_1.py b/qcfractal/alembic/versions/8b0cd9accaf2_base_records_msgpack_1.py new file mode 100644 index 000000000..45112b38e --- /dev/null +++ b/qcfractal/alembic/versions/8b0cd9accaf2_base_records_msgpack_1.py @@ -0,0 +1,42 @@ +"""Msgpack Base Results Phase 1 + +Revision ID: 8b0cd9accaf2 +Revises: 1134312ad4a3 +Create Date: 2019-08-11 22:30:27.613722 + +""" +from alembic import op +import sqlalchemy as sa + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) +from migration_helpers import msgpack_migrations +from qcelemental.util import msgpackext_dumps, msgpackext_loads + +# revision identifiers, used by Alembic. +revision = '8b0cd9accaf2' +down_revision = '1134312ad4a3' +branch_labels = None +depends_on = None + +block_size = 100 +table_name = "base_result" + +def transformer(old_data): + + extras = old_data["extras"] + extras.pop("_qcfractal_tags", None) # cleanup old tags + + return {"extras_": msgpackext_dumps(extras)} + + +update_columns = {"extras"} + + +def upgrade(): + msgpack_migrations.json_to_msgpack_table(table_name, block_size, update_columns, transformer) + + +def downgrade(): + msgpack_migrations.json_to_msgpack_table_dropcols(table_name, block_size, update_columns) diff --git a/qcfractal/alembic/versions/963822c28879_molecule_msgpack_1.py b/qcfractal/alembic/versions/963822c28879_molecule_msgpack_1.py new file mode 100644 index 000000000..02827777b --- /dev/null +++ b/qcfractal/alembic/versions/963822c28879_molecule_msgpack_1.py @@ -0,0 +1,56 @@ +"""Msgpack Molecule Phase 1 + +Revision ID: 963822c28879 +Revises: 4bb79efa9855 +Create Date: 2019-08-10 17:41:15.520300 + +""" +from alembic import op +import sqlalchemy as sa +import numpy as np + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) + +from migration_helpers import msgpack_migrations +from qcelemental.util import msgpackext_dumps, msgpackext_loads + +# revision identifiers, used by Alembic. +revision = '963822c28879' +down_revision = '4bb79efa9855' +branch_labels = None +depends_on = None + +block_size = 100 +table_name = "molecule" + +converters = { + "symbols": lambda arr: np.array(arr, dtype=str), + "geometry": lambda arr: np.array(arr, dtype=float), + "masses": lambda arr: np.array(arr, dtype=float), + "real": lambda arr: np.array(arr, dtype=bool), + "atom_labels": lambda arr: np.array(arr, dtype=str), + "atomic_numbers": lambda arr: np.array(arr, dtype=np.int16), + "mass_numbers": lambda arr: np.array(arr, dtype=np.int16), + "fragments": lambda list_arr: [np.array(x, dtype=np.int32) for x in list_arr], +} + +def transformer(old_data): + + row = {} + for k, v in old_data.items(): + if k == "id": + continue + d = msgpackext_dumps(converters[k](v)) + row[k + "_"] = d + + return row + + +def upgrade(): + msgpack_migrations.json_to_msgpack_table(table_name, block_size, converters.keys(), transformer) + + +def downgrade(): + msgpack_migrations.json_to_msgpack_table_dropcols(table_name, block_size, update_columns) diff --git a/qcfractal/alembic/versions/d56ac42b9a43_molecule_msgpack_2.py b/qcfractal/alembic/versions/d56ac42b9a43_molecule_msgpack_2.py new file mode 100644 index 000000000..503a34573 --- /dev/null +++ b/qcfractal/alembic/versions/d56ac42b9a43_molecule_msgpack_2.py @@ -0,0 +1,37 @@ +"""Msgpack Molecule Phase 2 + +Revision ID: d56ac42b9a43 +Revises: 963822c28879 +Create Date: 2019-08-11 16:17:23.856255 + +""" +from alembic import op +import sqlalchemy as sa + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) + +from migration_helpers import msgpack_migrations + +# revision identifiers, used by Alembic. +revision = 'd56ac42b9a43' +down_revision = '963822c28879' +branch_labels = None +depends_on = None + +table_name = "molecule" +update_columns = { + "symbols", "geometry", "masses", "real", "atom_labels", "atomic_numbers", "mass_numbers", "fragments" +} + +nullable = update_columns.copy() +nullable -= {"symbols", "geometry"} + + +def upgrade(): + msgpack_migrations.json_to_msgpack_table_altercolumns(table_name, update_columns, nullable_true=nullable) + + +def downgrade(): + raise ValueError("Cannot downgrade json to msgpack conversions") diff --git a/qcfractal/alembic/versions/da7c6f141bcb_extras_msgpack_1.py b/qcfractal/alembic/versions/da7c6f141bcb_extras_msgpack_1.py new file mode 100644 index 000000000..a0ab9d9f1 --- /dev/null +++ b/qcfractal/alembic/versions/da7c6f141bcb_extras_msgpack_1.py @@ -0,0 +1,71 @@ +"""Msgpack Remaining Phase 1 + +Revision ID: da7c6f141bcb +Revises: 05ceea11b78a +Create Date: 2019-08-12 10:12:46.478628 + +""" +from alembic import op +import sqlalchemy as sa + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) +from migration_helpers import msgpack_migrations +from qcelemental.util import msgpackext_dumps, msgpackext_loads + +# revision identifiers, used by Alembic. +revision = 'da7c6f141bcb' +down_revision = '05ceea11b78a' +branch_labels = None +depends_on = None + +block_size = 100 + + +def transformer(old_data): + + extras = old_data["extras"] + extras.pop("_qcfractal_tags", None) # cleanup old tags + + return {"extras_": msgpackext_dumps(extras)} + + +def upgrade(): + + ## Task Queue + table_name = "task_queue" + update_columns = {"spec"} + + def transformer(old_data): + + spec = old_data["spec"] + + return {"spec_": msgpackext_dumps(spec)} + + msgpack_migrations.json_to_msgpack_table(table_name, block_size, update_columns, transformer) + + ## Service Queue + table_name = "service_queue" + update_columns = {"extra"} + + def transformer(old_data): + + spec = old_data["extra"] + + return {"extra_": msgpackext_dumps(spec)} + + msgpack_migrations.json_to_msgpack_table(table_name, block_size, update_columns, transformer, {}) + + +def downgrade(): + + ## Task Queue + table_name = "task_queue" + update_columns = {"spec"} + msgpack_migrations.json_to_msgpack_table_dropcols(table_name, block_size, update_columns) + + ## Service Queue + table_name = "service_queue" + update_columns = {"extra"} + msgpack_migrations.json_to_msgpack_table_dropcols(table_name, block_size, update_columns) diff --git a/qcfractal/alembic/versions/e32b61e2516f_extras_msgpack_2.py b/qcfractal/alembic/versions/e32b61e2516f_extras_msgpack_2.py new file mode 100644 index 000000000..f1dd48334 --- /dev/null +++ b/qcfractal/alembic/versions/e32b61e2516f_extras_msgpack_2.py @@ -0,0 +1,41 @@ +"""Msgpack Remaining Phase 2 + +Revision ID: e32b61e2516f +Revises: da7c6f141bcb +Create Date: 2019-08-12 10:13:09.694643 + +""" +from alembic import op +import sqlalchemy as sa + +import os +import sys +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) +from migration_helpers import msgpack_migrations +from qcelemental.util import msgpackext_dumps, msgpackext_loads + +# revision identifiers, used by Alembic. +revision = 'e32b61e2516f' +down_revision = 'da7c6f141bcb' +branch_labels = None +depends_on = None + + +def upgrade(): + ## Task Queue + table_name = "task_queue" + update_columns = {"spec"} + + nullable = set() + msgpack_migrations.json_to_msgpack_table_altercolumns(table_name, update_columns, nullable_true=nullable) + + ## Service Queue + table_name = "service_queue" + update_columns = {"extra"} + + nullable = set() + msgpack_migrations.json_to_msgpack_table_altercolumns(table_name, update_columns, nullable_true=nullable) + + +def downgrade(): + raise ValueError("Cannot downgrade json to msgpack conversions") diff --git a/qcfractal/alembic/versions/migration_helpers/__init__.py b/qcfractal/alembic/versions/migration_helpers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qcfractal/alembic/versions/migration_helpers/msgpack_migrations.py b/qcfractal/alembic/versions/migration_helpers/msgpack_migrations.py new file mode 100644 index 000000000..d2618b307 --- /dev/null +++ b/qcfractal/alembic/versions/migration_helpers/msgpack_migrations.py @@ -0,0 +1,146 @@ +import logging +import sqlalchemy as sa +import tqdm +import numpy as np +from alembic import op +from sqlalchemy.dialects.postgresql import BYTEA +from sqlalchemy.sql.expression import func + +from qcelemental.util import msgpackext_dumps, msgpackext_loads +from qcelemental.testing import compare_recursive + +logger = logging.getLogger('alembic') + +old_type = sa.JSON +new_type = BYTEA + + +def _get_colnames(columns): + pairs = {(x, x + "_") for x in columns} + old = [x[0] for x in pairs] + new = [x[1] for x in pairs] + return (pairs, old, new) + + +def _intermediate_table(table_name, columns, read_columns=None): + + column_pairs, old_names, new_names = _get_colnames(columns) + table_data = [table_name, sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True)] + table_data.extend([sa.Column(x, old_type) for x in old_names]) + table_data.extend([sa.Column(x, new_type) for x in new_names]) + + cols = old_names + new_names + + if read_columns: + table_data.extend([sa.Column(k, v) for k, v in read_columns.items()]) + + cols.extend(list(read_columns)) + + table = sa.Table(*table_data) + return table, cols + + +def json_to_msgpack_table(table_name, block_size, update_columns, transformer, read_columns=None): + + if read_columns is None: + read_columns = {} + + update_columns = list(update_columns) + + logger.info(f"Converting {table_name} from JSON to msgpack.") + logger.info(f"Columns: {update_columns}") + column_pairs, old_names, new_names = _get_colnames(update_columns) + + # Schema migration: add all the new columns. + for col_old, col_new in column_pairs: + op.add_column(table_name, sa.Column(col_new, new_type, nullable=True)) + + # Declare a view of the table + table, cols = _intermediate_table(table_name, update_columns, read_columns=read_columns) + + connection = op.get_bind() + + num_records = connection.execute(f"select count(*) from {table_name}").scalar() + + read_names = ["id"] + old_names + list(read_columns) + read_columns = [getattr(table.c, x) for x in read_names] + + logger.info("Converting data, this may take some time...") + for block in tqdm.tqdm(range(0, num_records, block_size)): + + # Pull chunk to migrate + data = connection.execute(sa.select([ + *read_columns, + ], order_by=table.c.id.asc(), offset=block, limit=block_size)).fetchall() + + # Convert chunk to msgpack + for values in data: + data = {k: v for k, v in zip(read_names, values)} + row = transformer(data) + + connection.execute(table.update().where(table.c.id == data["id"]).values(**row)) + + connection.execute('commit;') + + +def json_to_msgpack_table_dropcols(table_name, block_size, update_columns): + + column_pairs, old_names, new_names = _get_colnames(update_columns) + for col in new_names: + op.drop_column(table_name, col) + + +def json_to_msgpack_table_altercolumns(table_name, update_columns, nullable_true=None): + + if nullable_true is None: + nullable_true = set() + + connection = op.get_bind() + table, cols = _intermediate_table(table_name, update_columns) + + column_pairs, old_names, new_names = _get_colnames(update_columns) + num_records = connection.execute(f"select count(*) from {table_name}").scalar() + + old_columns = [getattr(table.c, x) for x in old_names] + new_columns = [getattr(table.c, x) for x in new_names] + + logger.info(f"Checking converted columns...") + # Pull chunk to migrate + data = connection.execute(sa.select([ + table.c.id, + *old_columns, + *new_columns, + ], order_by=table.c.id.asc())).fetchall() + # ], limit=100, order_by=func.random())).fetchall() + + col_names = ["id"] + old_names + new_names + for values in data: + row = {k: v for k, v in zip(col_names, values)} + # print(row["id"]) + # print(row.keys()) + # for k, v in row.items(): + # print(k, v) + + for name in old_names: + comp_data = msgpackext_loads(row[name + "_"]) + # try: + # assert compare_recursive(comp_data, row[name]) + # except AssertionError: + # assert compare_recursive(comp_data.ravel(), row[name], quiet=True) + + + # try: + # print(name, comp_data.dtype, comp_data) + # except: + # print(name, comp_data[0].dtype, comp_data) + # pass + # raise Exception() + logger.info(f"Dropping old columns and renaming new.") + # Drop old tables and swamp new ones in. + for old_name, new_name in column_pairs: + nullable = False + if old_name in nullable_true: + nullable = True + + op.drop_column(table_name, old_name) + op.alter_column(table_name, new_name, new_column_name=old_name, nullable=nullable) diff --git a/qcfractal/cli/qcfractal_manager.py b/qcfractal/cli/qcfractal_manager.py index 2b2cbb02e..210bdf73b 100644 --- a/qcfractal/cli/qcfractal_manager.py +++ b/qcfractal/cli/qcfractal_manager.py @@ -527,6 +527,7 @@ def parse_args(): description='A CLI for a QCFractal QueueManager with a ProcessPoolExecutor, Dask, or Parsl backend. ' 'The Dask and Parsl backends *requires* a config file due to the complexity of its setup. If a config ' 'file is specified, the remaining options serve as CLI overwrites of the config.') + parser.add_argument('--version', action='version', version=f"{qcfractal.__version__}") parser.add_argument("--config-file", type=str, default=None) diff --git a/qcfractal/cli/qcfractal_server.py b/qcfractal/cli/qcfractal_server.py index bcac79fe4..71528e94a 100644 --- a/qcfractal/cli/qcfractal_server.py +++ b/qcfractal/cli/qcfractal_server.py @@ -28,6 +28,8 @@ def ensure_postgres_alive(psql): def parse_args(): parser = argparse.ArgumentParser(description='A CLI for the QCFractalServer.') + parser.add_argument('--version', action='version', version=f"{qcfractal.__version__}") + subparsers = parser.add_subparsers(dest="command") ### Init subcommands @@ -87,8 +89,9 @@ def parse_args(): help='Creates a local pool QueueManager attached to the server.') ### Config subcommands - config = subparsers.add_parser('config', help="Manage users and permissions on a QCFractal server instance.") - config.add_argument("--base-folder", **FractalConfig.help_info("base_folder")) + info = subparsers.add_parser('info', help="Manage users and permissions on a QCFractal server instance.") + info.add_argument("category", nargs="?", default="config", choices=["config", "alembic"], help="The config category to show.") + info.add_argument("--base-folder", **FractalConfig.help_info("base_folder")) ### User subcommands user = subparsers.add_parser('user', help="Configure a QCFractal server instance.") @@ -103,7 +106,7 @@ def parse_args(): user_add.add_argument("--permissions", nargs='+', default=None, type=str, required=True, help="Permissions for the user. Allowed values: read, write, queue, compute, admin.") - user_show = user_subparsers.add_parser("show", help="Show the user's current permissions.") + user_show = user_subparsers.add_parser("info", help="Show the user's current permissions.") user_show.add_argument("username", default=None, type=str, help="The username to show.") user_modify = user_subparsers.add_parser("modify", help="Change a user's password or permissions.") @@ -119,7 +122,6 @@ def parse_args(): user_remove = user_subparsers.add_parser("remove", help="Remove a user.") user_remove.add_argument("username", default=None, type=str, help="The username to remove.") - ### Move args around args = vars(parser.parse_args()) @@ -235,10 +237,16 @@ def server_init(args, config): print("\n>>> Success! Please run `qcfractal-server start` to boot a FractalServer!") -def server_config(args, config): +def server_info(args, config): + + psql = PostgresHarness(config, quiet=False, logger=print) - print(f"Displaying QCFractal configuration:\n") - print(yaml.dump(config.dict(), default_flow_style=False)) + if args["category"] == "config": + print(f"Displaying QCFractal configuration:\n") + print(yaml.dump(config.dict(), default_flow_style=False)) + elif args["category"] == "alembic": + print(f"Displaying QCFractal Alembic CLI configuration:\n") + print(" ".join(psql.alembic_commands())) def server_start(args, config): @@ -381,7 +389,7 @@ def server_user(args, config): else: print("\n>>> Failed to add user. Perhaps the username is already taken?") sys.exit(1) - elif args["user_command"] == "show": + elif args["user_command"] == "info": print(f"\n>>> Showing permissions for user '{args['username']}'...") permissions = storage.get_user_permissions(args["username"]) if permissions is None: @@ -452,8 +460,8 @@ def main(args=None): if command == "init": server_init(args, config) - elif command == "config": - server_config(args, config) + elif command == "info": + server_info(args, config) elif command == "start": server_start(args, config) elif command == 'upgrade': diff --git a/qcfractal/cli/tests/test_cli.py b/qcfractal/cli/tests/test_cli.py index 5d7a96e1b..f1f528bd0 100644 --- a/qcfractal/cli/tests/test_cli.py +++ b/qcfractal/cli/tests/test_cli.py @@ -7,6 +7,8 @@ import pytest +import qcfractal + from qcfractal import testing from qcfractal.cli.cli_utils import read_config_file import yaml @@ -17,14 +19,15 @@ @pytest.fixture(scope="module") -def qcfractal_base_init(postgres_server): +def qcfractal_base_init(): + storage = qcfractal.TemporaryPostgres() tmpdir = tempfile.TemporaryDirectory() args = [ "qcfractal-server", "init", "--base-folder", str(tmpdir.name), "--db-own=False", "--clear-database", - f"--db-port={postgres_server.config.database.port}" + f"--db-port={storage.config.database.port}" ] assert testing.run_process(args, **_options) @@ -65,10 +68,10 @@ def test_cli_user_show(qcfractal_base_init): args = ["qcfractal-server", "user", qcfractal_base_init, "add", "test_user_show", "--permissions", "admin"] assert testing.run_process(args, **_options) - args = ["qcfractal-server", "user", qcfractal_base_init, "show", "test_user_show"] + args = ["qcfractal-server", "user", qcfractal_base_init, "info", "test_user_show"] assert testing.run_process(args, **_options) - args = ["qcfractal-server", "user", qcfractal_base_init, "show", "badname_1234"] + args = ["qcfractal-server", "user", qcfractal_base_init, "info", "badname_1234"] assert testing.run_process(args, **_options) is False @@ -177,6 +180,7 @@ def test_manager_executor_manager_boot_from_file(active_server, tmp_path): assert testing.run_process(args, interupt_after=7, **_options) +@testing.mark_slow def cli_manager_runs(config_data, tmp_path): temp_config = tmp_path / "temp_config.yaml" temp_config.write_text(yaml.dump(config_data)) @@ -184,6 +188,7 @@ def cli_manager_runs(config_data, tmp_path): assert testing.run_process(args, **_options) +@testing.mark_slow def load_manager_config(adapter, scheduler): config = read_config_file(os.path.join(_pwd, "manager_boot_template.yaml")) config["common"]["adapter"] = adapter @@ -247,18 +252,21 @@ def test_cli_managers_none(adapter, tmp_path): cli_manager_runs(config, tmp_path) +@testing.mark_slow def test_cli_managers_help(): """Test that qcfractal_manager --help works""" args = ["qcfractal-manager", "--help"] testing.run_process(args, **_options) +@testing.mark_slow def test_cli_managers_schema(): """Test that qcfractal_manager --schema works""" args = ["qcfractal-manager", "--schema"] testing.run_process(args, **_options) +@testing.mark_slow def test_cli_managers_skel(tmp_path): """Test that qcfractal_manager --skeleton works""" config = tmp_path / "config.yaml" diff --git a/qcfractal/interface/client.py b/qcfractal/interface/client.py index 90997d421..428f3cc58 100644 --- a/qcfractal/interface/client.py +++ b/qcfractal/interface/client.py @@ -80,6 +80,7 @@ def __init__(self, self.username = username self._verify = verify self._headers = {} + self.encoding = "msgpack-ext" # Mode toggle for network error testing, not public facing self._mock_network_error = False @@ -95,7 +96,7 @@ def __init__(self, from . import __version__ # Import here to avoid circular import from . import _isportal - self._headers["content_type"] = 'application/json' + self._headers["Content-Type"] = f'application/{self.encoding}' self._headers["User-Agent"] = f"qcportal/{__version__}" # Try to connect and pull general data @@ -153,6 +154,10 @@ def _repr_html_(self) -> str: """ + def _set_encoding(self, encoding): + self.encoding = encoding + self._headers["Content-Type"] = f'application/{self.encoding}' + def _request(self, method: str, service: str, *, data: str = None, noraise: bool = False, timeout: int = None): addr = self.address + service @@ -219,8 +224,9 @@ def _automodel_request(self, except ValidationError as exc: raise TypeError(str(exc)) - r = self._request(rest, name, data=payload.json(), timeout=timeout) - response = response_model.parse_raw(r.text) + r = self._request(rest, name, data=payload.serialize(self.encoding), timeout=timeout) + encoding = r.headers["Content-Type"].split("/")[1] + response = response_model.parse_raw(r.content, encoding=encoding) if full_return: return response @@ -607,7 +613,7 @@ def query_results(self, # Add references back to the client if not projection: for result in response.data: - result.client = self + result.__dict__["client"] = self if full_return: return response diff --git a/qcfractal/interface/collections/__init__.py b/qcfractal/interface/collections/__init__.py index e79e20e7d..5a4538087 100644 --- a/qcfractal/interface/collections/__init__.py +++ b/qcfractal/interface/collections/__init__.py @@ -6,7 +6,6 @@ from .dataset import Dataset from .generic import Generic from .gridoptimization_dataset import GridOptimizationDataset -from .openffworkflow import OpenFFWorkflow from .optimization_dataset import OptimizationDataset from .reaction_dataset import ReactionDataset from .torsiondrive_dataset import TorsionDriveDataset diff --git a/qcfractal/interface/collections/collection.py b/qcfractal/interface/collections/collection.py index b95728dc6..6c666dd69 100644 --- a/qcfractal/interface/collections/collection.py +++ b/qcfractal/interface/collections/collection.py @@ -10,9 +10,8 @@ from typing import Any, Dict, List, Optional, Set, Union import pandas as pd -from pydantic import BaseModel -from ..models import ObjectId, json_encoders +from ..models import ObjectId, ProtoModel class Collection(abc.ABC): @@ -45,7 +44,7 @@ def __init__(self, name: str, client: 'FractalClient' = None, **kwargs: Dict[str # Create the data model self.data = self.DataModel(**kwargs) - class DataModel(BaseModel): + class DataModel(ProtoModel): """ Internal Data structure base model typed by PyDantic @@ -62,10 +61,6 @@ class DataModel(BaseModel): tags: List[str] = [] id: str = 'local' - class Config: - json_encoders = json_encoders - extra = "forbid" - def __str__(self) -> str: """ A simple string representation of the Collection. @@ -226,7 +221,10 @@ def save(self, client: 'FractalClient' = None) -> 'ObjectId': # Add the database if (self.data.id == self.data.fields['id'].default): - self.data.id = client.add_collection(self.data.dict(), overwrite=False) + response = client.add_collection(self.data.dict(), overwrite=False, full_return=True) + if response.meta.success is False: + raise KeyError(f"Error adding collection: \n{response.meta.error_description}") + self.data.__dict__["id"] = response.data else: client.add_collection(self.data.dict(), overwrite=True) diff --git a/qcfractal/interface/collections/dataset.py b/qcfractal/interface/collections/dataset.py index 0e671328f..c09b2333f 100644 --- a/qcfractal/interface/collections/dataset.py +++ b/qcfractal/interface/collections/dataset.py @@ -5,24 +5,23 @@ import numpy as np import pandas as pd -from pydantic import BaseModel from qcelemental import constants from .collection import Collection from .collection_utils import composition_planner, register_collection -from ..models import ComputeResponse, Molecule, ObjectId +from ..models import ComputeResponse, Molecule, ObjectId, ProtoModel from ..statistics import wrap_statistics from ..visualization import bar_plot, violin_plot -class MoleculeRecord(BaseModel): +class MoleculeRecord(ProtoModel): name: str molecule_id: ObjectId comment: Optional[str] = None local_results: Dict[str, Any] = {} -class ContributedValues(BaseModel): +class ContributedValues(ProtoModel): name: str doi: Optional[str] = None theory_level: Union[str, Dict[str, str]] @@ -649,7 +648,21 @@ def set_default_program(self, program: str) -> bool: The program to default to. """ - self.data.default_program = program.lower() + self.data.__dict__["default_program"] = program.lower() + return True + + def set_default_benchmark(self, benchmark: str) -> bool: + """ + Sets the default benchmark value. + + Parameters + ---------- + benchmark : str + The benchmark to default to. + """ + + self.data.__dict__["default_benchmark"] = benchmark + return True def add_keywords(self, alias: str, program: str, keyword: 'KeywordSet', default: bool=False) -> bool: """ @@ -784,7 +797,11 @@ def get_contributed_values_column(self, key: str) -> 'Series': if isinstance(next(iter(data.values.values())), (int, float)): values = data.values else: - values = {k: [v] for k, v in data.values.items()} + # TODO temporary patch until msgpack collections + if self.data.default_driver == "gradient": + values = {k: [np.array(v).reshape(-1, 3)] for k, v in data.values.items()} + else: + values = {k: [np.array(v)] for k, v in data.values.items()} tmp_idx = pd.DataFrame.from_dict(values, orient="index", columns=[data.name]) diff --git a/qcfractal/interface/collections/gridoptimization_dataset.py b/qcfractal/interface/collections/gridoptimization_dataset.py index 33a7690a2..99bca2e13 100644 --- a/qcfractal/interface/collections/gridoptimization_dataset.py +++ b/qcfractal/interface/collections/gridoptimization_dataset.py @@ -3,15 +3,13 @@ """ from typing import Any, Dict, List, Optional, Set -from pydantic import BaseModel - -from ..models import GridOptimizationInput, Molecule, ObjectId, OptimizationSpecification, QCSpecification +from ..models import GridOptimizationInput, Molecule, ObjectId, OptimizationSpecification, ProtoModel, QCSpecification from ..models.gridoptimization import GOKeywords, ScanDimension from .collection import BaseProcedureDataset from .collection_utils import register_collection -class GOEntry(BaseModel): +class GOEntry(ProtoModel): """Data model for the `reactions` list in Dataset""" name: str initial_molecule: ObjectId @@ -20,7 +18,7 @@ class GOEntry(BaseModel): object_map: Dict[str, ObjectId] = {} -class GOEntrySpecification(BaseModel): +class GOEntrySpecification(ProtoModel): name: str description: Optional[str] optimization_spec: OptimizationSpecification diff --git a/qcfractal/interface/collections/openffworkflow.py b/qcfractal/interface/collections/openffworkflow.py deleted file mode 100644 index 6064bb1a7..000000000 --- a/qcfractal/interface/collections/openffworkflow.py +++ /dev/null @@ -1,341 +0,0 @@ -"""Mongo QCDB Fragment object and helpers -""" - -import copy -from typing import Any, Dict - -from pydantic import BaseModel - -from .collection import Collection -from .collection_utils import register_collection -from ..models import (OptimizationRecord, OptimizationSpecification, QCSpecification, TorsionDriveInput, - TorsionDriveRecord) - - -class TorsionDriveStaticOptions(BaseModel): - - keywords: Dict[str, Any] - optimization_spec: OptimizationSpecification - qc_spec: QCSpecification - - class Config: - extra = "forbid" - allow_mutation = False - - -class OptimizationStaticOptions(BaseModel): - - program: str - keywords: Dict[str, Any] = {} - qc_spec: QCSpecification - - class Config: - extra = "forbid" - allow_mutation = False - - -class OpenFFWorkflow(Collection): - """ - This is a QCA OpenFFWorkflow class. - - Attributes - ---------- - client : client.FractalClient - A FractalClient connected to a server - """ - - def __init__(self, name, client=None, **kwargs): - """ - Initializer for the OpenFFWorkflow object. If no Portal is supplied or the database name - is not present on the server that the Portal is connected to a blank database will be - created. - - Parameters - ---------- - name : str - The name of the OpenFFWorkflow - client : client.FractalClient, optional - A FractalClient connected to a server - - """ - - if client is None: - raise KeyError("OpenFFWorkflow must have a client.") - super().__init__(name, client=client, **kwargs) - - self._torsiondrive_cache = {} - - # First workflow is saved - if self.data.id == self.data.fields['id'].default: - ret = self.save() - if len(ret) == 0: - raise ValueError("Attempted to insert duplicate Workflow with name '{}'".format(name)) - self.data.id = ret[0] - - class DataModel(Collection.DataModel): - """ - Internal Data structure base model typed by PyDantic - - This structure validates input, allows server-side validation and data security, - and will create the information to pass back and forth between server and client - """ - fragments: Dict[str, Any] = {} - enumerate_states: Dict[str, Any] = { - "version": "", - "options": { - "protonation": True, - "tautomers": False, - "stereoisomers": True, - "max_states": 200, - "level": 0, - "reasonable": True, - "carbon_hybridization": True, - "suppress_hydrogen": True - } - } - enumerate_fragments: Dict[str, Any] = {"version": "", "options": {}} - torsiondrive_input: Dict[str, Any] = { - "restricted": True, - "torsiondrive_options": { - "max_conf": 1, - "terminal_torsion_resolution": 30, - "internal_torsion_resolution": 30, - "scan_internal_terminal_combination": 0, - "scan_dimension": 1 - }, - "restricted_optimization_options": { - "maximum_rotation": 30, - "interval": 5 - } - } - torsiondrive_static_options: TorsionDriveStaticOptions - optimization_static_options: OptimizationStaticOptions - - # Valid options which can be fetched from the get_options method - # Kept as separate list to be easier to read for devs - __workflow_options = ("enumerate_states", "enumerate_fragments", "torsiondrive_input", - "torsiondrive_static_options", "optimization_static_options") - - def _pre_save_prep(self, client): - pass - - def get_options(self, key): - """ - Obtains "base" workflow options that do not change. - - Parameters - ---------- - key : str - The original workflow options. - - Returns - ------- - dict - The requested options dictionary. - """ - # Get the set of options unique to the Workflow data model - if key not in self.__workflow_options: - raise KeyError("Key `{}` not understood.".format(key)) - - return copy.deepcopy(getattr(self.data, key)) - - def list_fragments(self): - """ - List all fragments associated with this workflow. - - Returns - ------- - list of str - A list of fragment ID's. - """ - return list(self.data.fragments) - - def add_fragment(self, fragment_id, data, tag=None, priority=None): - """ - Adds a new fragment to the workflow along with the associated input required. - - Parameters - ---------- - fragment_id : str - The tag associated with fragment. In general this should be the canonical isomeric - explicit hydrogen mapped SMILES tag for this fragment. - data : dict - A dictionary of label : {type, intial_molecule, grid_spacing, dihedrals} for torsiondrive type and - label : {type, initial_molecule, contraints} for an optimization type - - Example - ------- - - data = { - "label1": { - "initial_molecule": ptl.data.get_molecule("butane.json"), - "grid_spacing": [60], - "dihedrals": [[0, 2, 3, 1]], - }, - ... - } - wf.add_fragment("CCCC", data) - """ - - if fragment_id not in self.data.fragments: - self.data.fragments[fragment_id] = {} - - frag_data = self.data.fragments[fragment_id] - for name, packet in data.items(): - if name in frag_data: - print("Already found label {} for fragment_ID {}, skipping.".format(name, fragment_id)) - continue - if packet['type'] == 'torsiondrive_input': - ret = self._add_torsiondrive(packet, tag, priority) - elif packet['type'] == 'optimization_input': - ret = self._add_optimize(packet, tag, priority) - else: - raise KeyError("{} is not an OpenFFWorkflow type job".format(packet['type'])) - - # add back to fragment data - frag_data[name] = ret - - # Push collection data back to server - self.save() - - def _add_torsiondrive(self, packet, tag, priority): - # Build out a new service - torsion_meta = self.data.torsiondrive_static_options.copy(deep=True).dict() - - for k in ["grid_spacing", "dihedrals"]: - torsion_meta["keywords"][k] = packet[k] - - # Get hash of torsion - inp = TorsionDriveInput(**torsion_meta, initial_molecule=packet["initial_molecule"]) - ret = self.client.add_service([inp], tag=tag, priority=priority) - - return ret.ids[0] - - def _add_optimize(self, packet, tag, priority): - meta = self.data.optimization_static_options.copy(deep=True).dict() - - for k in ["constraints"]: - meta["keywords"][k] = packet[k] - - # Get hash of optimization - ret = self.client.add_procedure( - "optimization", meta["program"], meta, [packet["initial_molecule"]], tag=tag, priority=priority) - - return ret.ids[0] - - def get_fragment_data(self, fragments=None, refresh_cache=False): - """Obtains fragment torsiondrives from server to local data. - - Parameters - ---------- - fragments : None, optional - A list of fragment ID's to query upon - refresh_cache : bool, optional - If True requery everything, otherwise use the cache to prevent extra lookups. - """ - - # If no fragments explicitly shown, grab all - if fragments is None: - fragments = self.data.fragments.keys() - - # Figure out the lookup - lookup = [] - for frag in fragments: - lookup.extend(list(self.data.fragments[frag].values())) - - if refresh_cache is False: - lookup = list(set(lookup) - self._torsiondrive_cache.keys()) - - # Grab the data and update cache - data = self.client.query_procedures(id=lookup) - self._torsiondrive_cache.update({x.id: x for x in data}) - - def list_final_energies(self, fragments=None, refresh_cache=False): - """ - Returns the final energies for the requested fragments. - - Parameters - ---------- - fragments : None, optional - A list of fragment ID's to query upon - refresh_cache : bool, optional - If True requery everything, otherwise use the cache to prevent extra lookups. - - Returns - ------- - dict - A dictionary structure with fragment and label fields available for access. - """ - - # If no fragments explicitly shown, grab all - if fragments is None: - fragments = self.data.fragments.keys() - - # Get the data if available - self.get_fragment_data(fragments=fragments, refresh_cache=refresh_cache) - - ret = {} - for frag in fragments: - tmp = {} - for k, v in self.data.fragments[frag].items(): - if v in self._torsiondrive_cache: - # TODO figure out a better solution here - obj = self._torsiondrive_cache[v] - if isinstance(obj, TorsionDriveRecord): - tmp[k] = obj.get_final_energies() - elif isinstance(obj, OptimizationRecord): - tmp[k] = obj.get_final_energy() - else: - raise TypeError("Internal type error encoured, buy a dev a coffee.") - else: - tmp[k] = None - - ret[frag] = tmp - - return ret - - def list_final_molecules(self, fragments=None, refresh_cache=False): - """ - Returns the final molecules for the requested fragments. - - Parameters - ---------- - fragments : None, optional - A list of fragment ID's to query upon - refresh_cache : bool, optional - If True requery everything, otherwise use the cache to prevent extra lookups. - - Returns - ------- - dict - A dictionary structure with fragment and label fields available for access. - """ - - # If no fragments explicitly shown, grab all - if fragments is None: - fragments = self.data.fragments.keys() - - # Get the data if available - self.get_fragment_data(fragments=fragments, refresh_cache=refresh_cache) - - ret = {} - for frag in fragments: - tmp = {} - for k, v in self.data.fragments[frag].items(): - if v in self._torsiondrive_cache: - obj = self._torsiondrive_cache[v] - if isinstance(obj, TorsionDriveRecord): - tmp[k] = obj.get_final_molecules() - elif isinstance(obj, OptimizationRecord): - tmp[k] = obj.get_final_molecule() - else: - raise TypeError("Internal type error encoured, buy a dev a coffee.") - else: - tmp[k] = None - - ret[frag] = tmp - - return ret - - -register_collection(OpenFFWorkflow) diff --git a/qcfractal/interface/collections/optimization_dataset.py b/qcfractal/interface/collections/optimization_dataset.py index 15ec0937e..5f1ebf9f7 100644 --- a/qcfractal/interface/collections/optimization_dataset.py +++ b/qcfractal/interface/collections/optimization_dataset.py @@ -4,14 +4,13 @@ from typing import Any, Dict, List, Optional, Set, Union import pandas as pd -from pydantic import BaseModel -from ..models import Molecule, ObjectId, OptimizationSpecification, QCSpecification +from ..models import Molecule, ObjectId, OptimizationSpecification, ProtoModel, QCSpecification from .collection import BaseProcedureDataset from .collection_utils import register_collection -class OptEntry(BaseModel): +class OptEntry(ProtoModel): """Data model for the optimizations in a Dataset""" name: str initial_molecule: ObjectId @@ -20,7 +19,7 @@ class OptEntry(BaseModel): object_map: Dict[str, ObjectId] = {} -class OptEntrySpecification(BaseModel): +class OptEntrySpecification(ProtoModel): name: str description: Optional[str] optimization_spec: OptimizationSpecification diff --git a/qcfractal/interface/collections/reaction_dataset.py b/qcfractal/interface/collections/reaction_dataset.py index ccbe7eb7b..fdac80023 100644 --- a/qcfractal/interface/collections/reaction_dataset.py +++ b/qcfractal/interface/collections/reaction_dataset.py @@ -7,12 +7,11 @@ import numpy as np import pandas as pd -from pydantic import BaseModel from .collection_utils import nCr, register_collection from .dataset import Dataset from ..util import replace_dict_keys -from ..models import ComputeResponse, Molecule +from ..models import ComputeResponse, Molecule, ProtoModel class _ReactionTypeEnum(str, Enum): @@ -21,12 +20,13 @@ class _ReactionTypeEnum(str, Enum): ie = 'ie' -class ReactionRecord(BaseModel): +class ReactionRecord(ProtoModel): """Data model for the `reactions` list in Dataset""" attributes: Dict[str, Union[int, float, str]] # Might be overloaded key types reaction_results: Dict[str, dict] name: str stoichiometry: Dict[str, Dict[str, float]] + extras: Dict[str, Any] = {} class ReactionDataset(Dataset): @@ -495,13 +495,13 @@ def parse_stoichiometry(self, stoichiometry): molecule_hash = qcf_mol.get_hash() if molecule_hash not in list(self._new_molecules): - self._new_molecules[molecule_hash] = qcf_mol.json_dict() + self._new_molecules[molecule_hash] = qcf_mol elif isinstance(mol, Molecule): molecule_hash = mol.get_hash() if molecule_hash not in list(self._new_molecules): - self._new_molecules[molecule_hash] = mol.json_dict() + self._new_molecules[molecule_hash] = mol else: raise TypeError("Dataset: Parse stoichiometry: first value must either be a molecule hash, " @@ -587,8 +587,7 @@ def add_rxn(self, name, stoichiometry, reaction_results=None, attributes=None, o if not isinstance(other_fields, dict): raise TypeError("Dataset:add_rxn: other_fields must be a dictionary, not '{}'".format(type(attributes))) - for k, v in other_fields.items(): - rxn_dict[k] = v + rxn_dict["extras"] = other_fields if "default" in list(reaction_results): rxn_dict["reaction_results"] = reaction_results diff --git a/qcfractal/interface/collections/torsiondrive_dataset.py b/qcfractal/interface/collections/torsiondrive_dataset.py index 926e011f6..edb2835c7 100644 --- a/qcfractal/interface/collections/torsiondrive_dataset.py +++ b/qcfractal/interface/collections/torsiondrive_dataset.py @@ -4,16 +4,15 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import pandas as pd -from pydantic import BaseModel -from ..models import Molecule, ObjectId, OptimizationSpecification, QCSpecification, TorsionDriveInput +from ..models import Molecule, ObjectId, OptimizationSpecification, ProtoModel, QCSpecification, TorsionDriveInput from ..models.torsiondrive import TDKeywords from ..visualization import custom_plot from .collection import BaseProcedureDataset from .collection_utils import register_collection -class TDEntry(BaseModel): +class TDEntry(ProtoModel): """Data model for the `reactions` list in Dataset""" name: str initial_molecules: Set[ObjectId] @@ -22,7 +21,7 @@ class TDEntry(BaseModel): object_map: Dict[str, ObjectId] = {} -class TDEntrySpecification(BaseModel): +class TDEntrySpecification(ProtoModel): name: str description: Optional[str] optimization_spec: OptimizationSpecification diff --git a/qcfractal/interface/data/molecules/helium_dimer.npy b/qcfractal/interface/data/molecules/helium_dimer.npy deleted file mode 100644 index 140d62717..000000000 Binary files a/qcfractal/interface/data/molecules/helium_dimer.npy and /dev/null differ diff --git a/qcfractal/interface/models/__init__.py b/qcfractal/interface/models/__init__.py index 359420850..ffe3b20be 100644 --- a/qcfractal/interface/models/__init__.py +++ b/qcfractal/interface/models/__init__.py @@ -4,7 +4,7 @@ from . import rest_models from .rest_models import rest_model, ComputeResponse -from .common_models import KeywordSet, Molecule, ObjectId, OptimizationSpecification, QCSpecification +from .common_models import KeywordSet, Molecule, ObjectId, OptimizationSpecification, ProtoModel, QCSpecification from .gridoptimization import GridOptimizationInput, GridOptimizationRecord from .model_builder import build_procedure from .model_utils import hash_dictionary, json_encoders, prepare_basis diff --git a/qcfractal/interface/models/common_models.py b/qcfractal/interface/models/common_models.py index 7cde354f9..94002bd3b 100644 --- a/qcfractal/interface/models/common_models.py +++ b/qcfractal/interface/models/common_models.py @@ -1,19 +1,19 @@ """ Common models for QCPortal/Fractal """ -import json from enum import Enum from typing import Any, Dict, Optional -from pydantic import BaseModel, validator, Schema -from qcelemental.models import Molecule, Provenance +from pydantic import Schema, validator + +from qcelemental.models import Molecule, Provenance, ProtoModel from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer __all__ = ["QCSpecification", "OptimizationSpecification", "KeywordSet", "ObjectId", "DriverEnum"] # Add in QCElemental models -__all__.extend(["Molecule", "Provenance"]) +__all__.extend(["Molecule", "Provenance", "ProtoModel"]) class ObjectId(str): @@ -48,7 +48,7 @@ class DriverEnum(str, Enum): properties = 'properties' -class QCSpecification(BaseModel): +class QCSpecification(ProtoModel): """ The quantum chemistry metadata specification for individual computations such as energy, gradient, and Hessians. """ @@ -63,8 +63,7 @@ class QCSpecification(BaseModel): basis: Optional[str] = Schema( None, description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for " - "methods without basis sets." - ) + "methods without basis sets.") keywords: Optional[ObjectId] = Schema( None, description="The Id of the :class:`KeywordSet` registered in the database to run this calculation with. This " @@ -73,8 +72,7 @@ class QCSpecification(BaseModel): program: str = Schema( ..., description="The quantum chemistry program to evaluate the computation with. Not all quantum chemistry programs" - " support all combinations of driver/method/basis." - ) + " support all combinations of driver/method/basis.") @validator('basis') def check_basis(cls, v): @@ -88,10 +86,6 @@ def check_program(cls, v): def check_method(cls, v): return v.lower() - class Config: - extra = "forbid" - allow_mutation = False - def form_schema_object(self, keywords: Optional['KeywordSet'] = None, checks=True) -> Dict[str, Any]: if checks and self.keywords: assert keywords.id == self.keywords @@ -114,14 +108,11 @@ def form_schema_object(self, keywords: Optional['KeywordSet'] = None, checks=Tru return ret -class OptimizationSpecification(BaseModel): +class OptimizationSpecification(ProtoModel): """ Metadata describing a geometry optimization. """ - program: str = Schema( - ..., - description="Optimization program to run the optimization with" - ) + program: str = Schema(..., description="Optimization program to run the optimization with") keywords: Optional[Dict[str, Any]] = Schema( None, description="Dictionary of keyword arguments to pass into the ``program`` when the program runs. " @@ -139,12 +130,8 @@ def check_keywords(cls, v): v = recursive_normalizer(v) return v - class Config: - extra = "forbid" - allow_mutation = False - -class KeywordSet(BaseModel): +class KeywordSet(ProtoModel): """ A key:value storage object for Keywords. """ @@ -155,13 +142,11 @@ class KeywordSet(BaseModel): hash_index: str = Schema( ..., description="The hash of this keyword set to store and check for collisions. This string is automatically " - "computed." - ) + "computed.") values: Dict[str, Any] = Schema( ..., description="The key-value pairs which make up this KeywordSet. There is no direct relation between this " - "dictionary and applicable program/spec it can be used on." - ) + "dictionary and applicable program/spec it can be used on.") lowercase: bool = Schema( True, description="String keys are in the ``values`` dict are normalized to lowercase if this is True. Assists in " @@ -175,12 +160,7 @@ class KeywordSet(BaseModel): comments: Optional[str] = Schema( None, description="Additional comments for this KeywordSet. Intended for pure human/user consumption " - "and clarity." - ) - - class Config: - extra = "forbid" - allow_mutation = False + "and clarity.") def __init__(self, **data): @@ -189,7 +169,7 @@ def __init__(self, **data): build_index = True data["hash_index"] = "placeholder" - BaseModel.__init__(self, **data) + ProtoModel.__init__(self, **data) # Overwrite options with massaged values kwargs = {"lowercase": self.lowercase} @@ -204,6 +184,3 @@ def __init__(self, **data): def get_hash_index(self): return hash_dictionary(self.values.copy()) - - def json_dict(self, *args, **kwargs): - return json.loads(self.json(*args, **kwargs)) diff --git a/qcfractal/interface/models/gridoptimization.py b/qcfractal/interface/models/gridoptimization.py index 11df7d6f8..bbb26626f 100644 --- a/qcfractal/interface/models/gridoptimization.py +++ b/qcfractal/interface/models/gridoptimization.py @@ -6,10 +6,10 @@ from enum import Enum from typing import Any, Dict, List, Tuple, Union -from pydantic import BaseModel, constr, validator, Schema +from pydantic import Schema, constr, validator -from .common_models import Molecule, ObjectId, OptimizationSpecification, QCSpecification -from .model_utils import json_encoders, recursive_normalizer +from .common_models import Molecule, ObjectId, OptimizationSpecification, ProtoModel, QCSpecification +from .model_utils import recursive_normalizer from .records import RecordBase __all__ = ["GridOptimizationInput", "GridOptimizationRecord"] @@ -34,7 +34,7 @@ class StepTypeEnum(str, Enum): relative = 'relative' -class ScanDimension(BaseModel): +class ScanDimension(ProtoModel): """ A full description of a dimension to scan over. """ @@ -58,10 +58,6 @@ class ScanDimension(BaseModel): description=str(StepTypeEnum.__doc__) ) - class Config: - extra = "forbid" - allow_mutation = False - @validator('type', 'step_type', pre=True) def check_lower_type_step_type(cls, v): return v.lower() @@ -85,7 +81,7 @@ def check_steps(cls, v): return v -class GOKeywords(BaseModel): +class GOKeywords(ProtoModel): """ GridOptimizationRecord options. """ @@ -99,16 +95,13 @@ class GOKeywords(BaseModel): "This is especially useful when combined with ``relative`` ``step_types``." ) - class Config: - extra = "forbid" - allow_mutation = False _gridopt_constr = constr(strip_whitespace=True, regex="gridoptimization") _qcfractal_constr = constr(strip_whitespace=True, regex="qcfractal") -class GridOptimizationInput(BaseModel): +class GridOptimizationInput(ProtoModel): """ The input to create a GridOptimization Service with. @@ -142,10 +135,6 @@ class GridOptimizationInput(BaseModel): "optimization." ) - class Config: - allow_mutation = False - json_encoders = json_encoders - class GridOptimizationRecord(RecordBase): """ @@ -212,10 +201,7 @@ class GridOptimizationRecord(RecordBase): ..., description="Initial grid point from which the Grid Optimization started. This grid point is the closest in " "structure to the ``starting_molecule``." - ) - - class Config(RecordBase.Config): - pass + ) # yapf: disable ## Utility diff --git a/qcfractal/interface/models/records.py b/qcfractal/interface/models/records.py index d2601db19..c3f5e57b7 100644 --- a/qcfractal/interface/models/records.py +++ b/qcfractal/interface/models/records.py @@ -4,17 +4,17 @@ import abc import datetime -import json -import numpy as np from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +import numpy as np +from pydantic import Schema, constr, validator + import qcelemental as qcel -from pydantic import BaseModel, constr, validator, Schema -from .common_models import DriverEnum, ObjectId, QCSpecification -from .model_utils import hash_dictionary, json_encoders, prepare_basis, recursive_normalizer from ..visualization import scatter_plot +from .common_models import DriverEnum, ObjectId, ProtoModel, QCSpecification +from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer __all__ = ["OptimizationRecord", "ResultRecord", "OptimizationRecord", "RecordBase"] @@ -29,7 +29,7 @@ class RecordStatusEnum(str, Enum): error = "ERROR" -class RecordBase(BaseModel, abc.ABC): +class RecordBase(ProtoModel, abc.ABC): """ A BaseRecord object for Result and Procedure records. Contains all basic fields common to the all records. @@ -122,9 +122,7 @@ class RecordBase(BaseModel, abc.ABC): "program which was involved in generating the data for this record." ) - class Config: - json_encoders = json_encoders - extra = "forbid" + class Config(ProtoModel.Config): build_hash_index = True @validator('program') @@ -141,7 +139,7 @@ def __init__(self, **data): # Set hash index if not present if self.Config.build_hash_index and (self.hash_index is None): - self.hash_index = self.get_hash_index() + self.__values__["hash_index"] = self.get_hash_index() def __str__(self) -> str: return f"{self.__class__.__name__}(id='{self.id}' status='{self.status}')" @@ -173,7 +171,7 @@ def get_hash_index(self) -> str: str The objects unique hash index. """ - data = self.json_dict(include=self.get_hash_fields()) + data = self.dict(include=self.get_hash_fields(), encoding="json") return hash_dictionary(data) @@ -182,9 +180,6 @@ def dict(self, *args, **kwargs): # kwargs["skip_defaults"] = True return super().dict(*args, **kwargs) - def json_dict(self, *args, **kwargs): - return json.loads(self.json(*args, **kwargs)) - ### Checkers def check_client(self, noraise: bool = False) -> bool: @@ -310,7 +305,7 @@ class ResultRecord(RecordBase): ) # Output data - return_result: Union[float, List[float], Dict[str, Any]] = Schema( + return_result: Union[float, qcel.models.types.Array[float], Dict[str, Any]] = Schema( None, description="The primary result of the calculation, output is a function of the specified ``driver``." ) @@ -365,21 +360,22 @@ def build_schema_input(self, molecule: 'Molecule', keywords: Optional['KeywordsS extras=self.extras) return model - def consume_output(self, data: Dict[str, Any], checks: bool = True): + def _consume_output(self, data: Dict[str, Any], checks: bool = True): assert self.method == data["model"]["method"] + values = self.__dict__ # Result specific - self.extras = data["extras"] - self.extras.pop("_qcfractal_tags", None) - self.return_result = data["return_result"] - self.properties = data["properties"] + values["extras"] = data["extras"] + values["extras"].pop("_qcfractal_tags", None) + values["return_result"] = data["return_result"] + values["properties"] = data["properties"] # Standard blocks - self.provenance = data["provenance"] - self.error = data["error"] - self.stdout = data["stdout"] - self.stderr = data["stderr"] - self.status = "COMPLETE" + values["provenance"] = data["provenance"] + values["error"] = data["error"] + values["stdout"] = data["stdout"] + values["stderr"] = data["stderr"] + values["status"] = "COMPLETE" ## QCSchema constructors diff --git a/qcfractal/interface/models/rest_models.py b/qcfractal/interface/models/rest_models.py index a586413dd..c383c4e78 100644 --- a/qcfractal/interface/models/rest_models.py +++ b/qcfractal/interface/models/rest_models.py @@ -3,11 +3,10 @@ """ from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import BaseConfig, BaseModel, constr, validator, Schema +from pydantic import Schema, constr, validator -from .common_models import KeywordSet, Molecule, ObjectId +from .common_models import KeywordSet, Molecule, ObjectId, ProtoModel from .gridoptimization import GridOptimizationInput -from .model_utils import json_encoders from .records import ResultRecord from .task_models import PriorityEnum, TaskRecord from .torsiondrive import TorsionDriveInput @@ -20,7 +19,7 @@ __rest_models = {} -def register_model(name: str, rest: str, body: 'BaseModel', response: 'BaseModel') -> None: +def register_model(name: str, rest: str, body: 'ProtoModel', response: 'ProtoModel') -> None: """ Register a REST model. @@ -30,9 +29,9 @@ def register_model(name: str, rest: str, body: 'BaseModel', response: 'BaseModel The REST endpoint name. rest : str The REST endpoint type. - body : BaseModel + body : ProtoModel The REST query body model. - response : BaseModel + response : ProtoModel The REST query response model. """ @@ -49,7 +48,7 @@ def register_model(name: str, rest: str, body: 'BaseModel', response: 'BaseModel __rest_models[name][rest] = (body, response) -def rest_model(name: str, rest: str) -> Tuple['BaseModel', 'BaseModel']: +def rest_model(name: str, rest: str) -> Tuple['ProtoModel', 'ProtoModel']: """Aquires a REST Model Parameters @@ -61,7 +60,7 @@ def rest_model(name: str, rest: str) -> Tuple['BaseModel', 'BaseModel']: Returns ------- - Tuple['BaseModel', 'BaseModel'] + Tuple['ProtoModel', 'ProtoModel'] The (body, response) models of the REST request. """ @@ -82,23 +81,20 @@ def rest_model(name: str, rest: str) -> Tuple['BaseModel', 'BaseModel']: QueryProjection = Optional[Dict[str, bool]] -class RESTConfig(BaseConfig): - json_encoders = json_encoders - extra = "forbid" +class EmptyMeta(ProtoModel): + pass -class EmptyMeta(BaseModel): +class EmptyMeta(ProtoModel): """ There is no metadata accepted, so an empty metadata is sent for completion. """ - class Config(RESTConfig): - pass auto_gen_docs_on_demand(EmptyMeta) -class ResponseMeta(BaseModel): +class ResponseMeta(ProtoModel): """ Standard Fractal Server response metadata """ @@ -117,9 +113,6 @@ class ResponseMeta(BaseModel): "of no errors." ) - class Config(RESTConfig): - pass - auto_gen_docs_on_demand(ResponseMeta) @@ -137,9 +130,6 @@ class ResponseGETMeta(ResponseMeta): description="The number of entries which were already found in the database from the set which was provided." ) - class Config(RESTConfig): - pass - auto_gen_docs_on_demand(ResponseGETMeta, force_reapply=True) @@ -162,14 +152,11 @@ class ResponsePOSTMeta(ResponseMeta): description="All errors with validating submitted objects will be documented here." ) - class Config(RESTConfig): - pass - auto_gen_docs_on_demand(ResponsePOSTMeta, force_reapply=True) -class QueryMeta(BaseModel): +class QueryMeta(ProtoModel): """ Standard Fractal Server metadata for Database queries containing pagination information """ @@ -182,9 +169,6 @@ class QueryMeta(BaseModel): description="The number of records to skip on the query." ) - class Config(RESTConfig): - pass - auto_gen_docs_on_demand(QueryMeta) @@ -198,14 +182,11 @@ class QueryMetaProjection(QueryMeta): description="Additional projection information to pass to the query. Expert-level object." ) - class Config(RESTConfig): - pass - auto_gen_docs_on_demand(QueryMetaProjection, force_reapply=True) -class ComputeResponse(BaseModel): +class ComputeResponse(ProtoModel): """ The response model from the Fractal Server when new Compute or Services are added. """ @@ -222,9 +203,6 @@ class ComputeResponse(BaseModel): description="The list of object Ids which already existed in the database." ) - class Config(RESTConfig): - pass - def __str__(self) -> str: return f"ComputeResponse(nsubmitted={len(self.submitted)} nexisting={len(self.existing)})" @@ -266,15 +244,12 @@ def merge(self, other: 'ComputeResponse') -> 'ComputeResponse': ### Information -class InformationGETBody(BaseModel): - - class Config(RESTConfig): - pass +class InformationGETBody(ProtoModel): + pass -class InformationGETResponse(BaseModel): - - class Config(RESTConfig): +class InformationGETResponse(ProtoModel): + class Config(ProtoModel.Config): extra = "allow" @@ -283,8 +258,8 @@ class Config(RESTConfig): ### KVStore -class KVStoreGETBody(BaseModel): - class Data(BaseModel): +class KVStoreGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="Id of the Key/Value Storage object to get." @@ -299,11 +274,8 @@ class Data(BaseModel): description="Data of the KV Get field: consists of a dict for Id of the Key/Value object to fetch." ) - class Config(RESTConfig): - pass - -class KVStoreGETResponse(BaseModel): +class KVStoreGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -313,9 +285,6 @@ class KVStoreGETResponse(BaseModel): description="The entries of Key/Value object requested." ) - class Config(RESTConfig): - pass - register_model("kvstore", "GET", KVStoreGETBody, KVStoreGETResponse) auto_gen_docs_on_demand(KVStoreGETBody) @@ -324,8 +293,8 @@ class Config(RESTConfig): ### Molecule response -class MoleculeGETBody(BaseModel): - class Data(BaseModel): +class MoleculeGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="Exact Id of the Molecule to fetch from the database." @@ -341,9 +310,6 @@ class Data(BaseModel): "contains no connectivity information." ) - class Config(RESTConfig): - pass - meta: QueryMeta = Schema( QueryMeta(), description=common_docs[QueryMeta] @@ -353,11 +319,8 @@ class Config(RESTConfig): description="Data fields for a Molecule query." # Because Data is internal, this may not document sufficiently ) - class Config(RESTConfig): - pass - -class MoleculeGETResponse(BaseModel): +class MoleculeGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -367,16 +330,13 @@ class MoleculeGETResponse(BaseModel): description="The List of Molecule objects found by the query." ) - class Config(RESTConfig): - pass - register_model("molecule", "GET", MoleculeGETBody, MoleculeGETResponse) auto_gen_docs_on_demand(MoleculeGETBody) auto_gen_docs_on_demand(MoleculeGETResponse) -class MoleculePOSTBody(BaseModel): +class MoleculePOSTBody(ProtoModel): meta: EmptyMeta = Schema( {}, description=common_docs[EmptyMeta] @@ -386,11 +346,8 @@ class MoleculePOSTBody(BaseModel): description="A list of :class:`Molecule` objects to add to the Database." ) - class Config(RESTConfig): - pass - -class MoleculePOSTResponse(BaseModel): +class MoleculePOSTResponse(ProtoModel): meta: ResponsePOSTMeta = Schema( ..., description=common_docs[ResponsePOSTMeta] @@ -402,9 +359,6 @@ class MoleculePOSTResponse(BaseModel): "existing Id (entries are not duplicated)." ) - class Config(RESTConfig): - pass - register_model("molecule", "POST", MoleculePOSTBody, MoleculePOSTResponse) auto_gen_docs_on_demand(MoleculePOSTBody) @@ -413,14 +367,11 @@ class Config(RESTConfig): ### Keywords -class KeywordGETBody(BaseModel): - class Data(BaseModel): +class KeywordGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = None hash_index: QueryStr = None - class Config(RESTConfig): - pass - meta: QueryMeta = Schema( QueryMeta(), description=common_docs[QueryMeta] @@ -430,11 +381,8 @@ class Config(RESTConfig): description="The formal query for a Keyword fetch, contains ``id`` or ``hash_index`` for the object to fetch." ) - class Config(RESTConfig): - pass - -class KeywordGETResponse(BaseModel): +class KeywordGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -444,16 +392,13 @@ class KeywordGETResponse(BaseModel): description="The :class:`KeywordSet` found from in the database based on the query." ) - class Config(RESTConfig): - pass - register_model("keyword", "GET", KeywordGETBody, KeywordGETResponse) auto_gen_docs_on_demand(KeywordGETBody) auto_gen_docs_on_demand(KeywordGETResponse) -class KeywordPOSTBody(BaseModel): +class KeywordPOSTBody(ProtoModel): meta: EmptyMeta = Schema( {}, description="There is no metadata with this, so an empty metadata is sent for completion." @@ -463,11 +408,8 @@ class KeywordPOSTBody(BaseModel): description="The list of :class:`KeywordSet` objects to add to the database." ) - class Config(RESTConfig): - pass - -class KeywordPOSTResponse(BaseModel): +class KeywordPOSTResponse(ProtoModel): data: List[Optional[ObjectId]] = Schema( ..., description="The Ids assigned to the added :class:`KeywordSet` objects. In the event of duplicates, the Id " @@ -478,9 +420,6 @@ class KeywordPOSTResponse(BaseModel): description=common_docs[ResponsePOSTMeta] ) - class Config(RESTConfig): - pass - register_model("keyword", "POST", KeywordPOSTBody, KeywordPOSTResponse) auto_gen_docs_on_demand(KeywordPOSTBody) @@ -489,8 +428,8 @@ class Config(RESTConfig): ### Collections -class CollectionGETBody(BaseModel): - class Data(BaseModel): +class CollectionGETBody(ProtoModel): + class Data(ProtoModel): collection: str = Schema( None, description="The specific collection to look up as its identified in the database." @@ -504,18 +443,12 @@ class Data(BaseModel): def cast_to_lower(cls, v): return v.lower() - class Config(RESTConfig): - pass - - class Meta(BaseModel): + class Meta(ProtoModel): projection: Dict[str, Any] = Schema( None, description="Additional projection information to pass to the query. Expert-level object." ) - class Config(RESTConfig): - pass - meta: Meta = Schema( None, description="Additional metadata to make with the query. Collections can only have a ``projection`` key in its " @@ -526,11 +459,8 @@ class Config(RESTConfig): description="Information about the Collection to search the database with." ) - class Config(RESTConfig): - pass - -class CollectionGETResponse(BaseModel): +class CollectionGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -547,27 +477,21 @@ def ensure_collection_name_in_data_get_res(cls, v): raise ValueError("Dicts in 'data' must have both 'collection' and 'name'") return v - class Config(RESTConfig): - pass - register_model("collection", "GET", CollectionGETBody, CollectionGETResponse) auto_gen_docs_on_demand(CollectionGETBody) auto_gen_docs_on_demand(CollectionGETResponse) -class CollectionPOSTBody(BaseModel): - class Meta(BaseModel): +class CollectionPOSTBody(ProtoModel): + class Meta(ProtoModel): overwrite: bool = Schema( False, description="The existing Collection in the database will be updated if this is True, otherwise will " "remain unmodified if it already exists." ) - class Config(RESTConfig): - pass - - class Data(BaseModel): + class Data(ProtoModel): id: str = Schema( "local", # Auto blocks overwriting in a socket description="The Id of the object to assign in the database. If 'local', then it will not overwrite " @@ -582,13 +506,13 @@ class Data(BaseModel): description="The common name of this Collection." ) + class Config(ProtoModel.Config): + extra = "allow" + @validator("collection") def cast_to_lower(cls, v): return v.lower() - class Config(RESTConfig): - extra = "allow" - meta: Meta = Schema( Meta(), description="Metadata to specify how the Database should handle adding this Collection if it already exists. " @@ -600,11 +524,8 @@ class Config(RESTConfig): description="The data associated with this Collection to add to the database." ) - class Config(RESTConfig): - pass - -class CollectionPOSTResponse(BaseModel): +class CollectionPOSTResponse(ProtoModel): data: Union[str, None] = Schema( ..., description="The Id of the Collection uniquely pointing to it in the Database. If the Collection was not added " @@ -615,9 +536,6 @@ class CollectionPOSTResponse(BaseModel): description=common_docs[ResponsePOSTMeta] ) - class Config(RESTConfig): - pass - register_model("collection", "POST", CollectionPOSTBody, CollectionPOSTResponse) auto_gen_docs_on_demand(CollectionPOSTBody) @@ -626,8 +544,8 @@ class Config(RESTConfig): ### Result -class ResultGETBody(BaseModel): - class Data(BaseModel): +class ResultGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="The exact Id to fetch from the database. If this is set as a search condition, there is no " @@ -673,9 +591,6 @@ class Data(BaseModel): ":class:`RecordStatusEnum` for valid statuses and more information." ) - class Config(RESTConfig): - pass - @validator('keywords', pre=True) def validate_keywords(cls, v): if v is None: @@ -697,11 +612,8 @@ def validate_basis(cls, v): description="The keys with data to search the database on for individual quantum chemistry computations." ) - class Config(RESTConfig): - pass - -class ResultGETResponse(BaseModel): +class ResultGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -720,9 +632,6 @@ def ensure_list_of_dict(cls, v): return [v] return v - class Config(RESTConfig): - pass - register_model("result", "GET", ResultGETBody, ResultGETResponse) auto_gen_docs_on_demand(ResultGETBody) @@ -731,8 +640,8 @@ class Config(RESTConfig): ### Procedures -class ProcedureGETBody(BaseModel): - class Data(BaseModel): +class ProcedureGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="The exact Id to fetch from the database. If this is set as a search condition, there is no " @@ -766,9 +675,6 @@ class Data(BaseModel): ":class:`RecordStatusEnum` for valid statuses." ) - class Config(RESTConfig): - pass - meta: QueryMetaProjection = Schema( QueryMetaProjection(), description=common_docs[QueryMetaProjection] @@ -778,11 +684,8 @@ class Config(RESTConfig): description="The keys with data to search the database on for Procedures." ) - class Config(RESTConfig): - pass - -class ProcedureGETResponse(BaseModel): +class ProcedureGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -792,9 +695,6 @@ class ProcedureGETResponse(BaseModel): description="The list of Procedure specs found based on the query." ) - class Config(RESTConfig): - pass - register_model("procedure", "GET", ProcedureGETBody, ProcedureGETResponse) auto_gen_docs_on_demand(ProcedureGETBody) @@ -803,8 +703,8 @@ class Config(RESTConfig): ### Task Queue -class TaskQueueGETBody(BaseModel): - class Data(BaseModel): +class TaskQueueGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="The exact Id to fetch from the database. If this is set as a search condition, there is no " @@ -833,9 +733,6 @@ class Data(BaseModel): "database, if it exists. See also :class:`ResultRecord`." ) - class Config(RESTConfig): - pass - meta: QueryMetaProjection = Schema( QueryMetaProjection(), description=common_docs[QueryMetaProjection] @@ -846,7 +743,7 @@ class Config(RESTConfig): ) -class TaskQueueGETResponse(BaseModel): +class TaskQueueGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -858,17 +755,14 @@ class TaskQueueGETResponse(BaseModel): "on the projection." ) - class Config(RESTConfig): - pass - register_model("task_queue", "GET", TaskQueueGETBody, TaskQueueGETResponse) auto_gen_docs_on_demand(TaskQueueGETBody) auto_gen_docs_on_demand(TaskQueueGETResponse) -class TaskQueuePOSTBody(BaseModel): - class Meta(BaseModel): +class TaskQueuePOSTBody(ProtoModel): + class Meta(ProtoModel): procedure: str = Schema( ..., description="Name of the procedure which the Task will execute." @@ -888,7 +782,7 @@ class Meta(BaseModel): description=str(PriorityEnum.__doc__) ) - class Config(RESTConfig): + class Config(ProtoModel.Config): allow_extra = "allow" @validator('priority', pre=True) @@ -907,11 +801,8 @@ def munge_priority(cls, v): "part of this Task." ) - class Config(RESTConfig): - pass - -class TaskQueuePOSTResponse(BaseModel): +class TaskQueuePOSTResponse(ProtoModel): meta: ResponsePOSTMeta = Schema( ..., @@ -922,17 +813,14 @@ class TaskQueuePOSTResponse(BaseModel): description="Data returned from the server from adding a Task." ) - class Config(RESTConfig): - pass - register_model("task_queue", "POST", TaskQueuePOSTBody, TaskQueuePOSTResponse) auto_gen_docs_on_demand(TaskQueuePOSTBody) auto_gen_docs_on_demand(TaskQueuePOSTResponse) -class TaskQueuePUTBody(BaseModel): - class Data(BaseModel): +class TaskQueuePUTBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="The exact Id to target in database. If this is set as a search condition, there is no " @@ -945,18 +833,12 @@ class Data(BaseModel): "database, if it exists. See also :class:`ResultRecord`." ) - class Config(RESTConfig): - pass - - class Meta(BaseModel): + class Meta(ProtoModel): operation: str = Schema( ..., description="The specific action you are taking as part of this update." ) - class Config(RESTConfig): - pass - @validator("operation") def cast_to_lower(cls, v): return v.lower() @@ -970,20 +852,14 @@ def cast_to_lower(cls, v): description="The information which contains the Task target in the database." ) - class Config(RESTConfig): - pass - -class TaskQueuePUTResponse(BaseModel): - class Data(BaseModel): +class TaskQueuePUTResponse(ProtoModel): + class Data(ProtoModel): n_updated: int = Schema( ..., description="The number of tasks which were changed." ) - class Config(RESTConfig): - pass - meta: ResponseMeta = Schema( ..., description=common_docs[ResponseMeta] @@ -993,9 +869,6 @@ class Config(RESTConfig): description="Information returned from attempting updates of Tasks." ) - class Config(RESTConfig): - pass - register_model("task_queue", "PUT", TaskQueuePUTBody, TaskQueuePUTResponse) auto_gen_docs_on_demand(TaskQueuePUTBody) @@ -1004,8 +877,8 @@ class Config(RESTConfig): ### Service Queue -class ServiceQueueGETBody(BaseModel): - class Data(BaseModel): +class ServiceQueueGETBody(ProtoModel): + class Data(ProtoModel): id: QueryObjectId = Schema( None, description="The exact Id to fetch from the database. If this is set as a search condition, there is no " @@ -1039,11 +912,8 @@ class Data(BaseModel): description="The keys with data to search the database on for Services." ) - class Config(RESTConfig): - pass - -class ServiceQueueGETResponse(BaseModel): +class ServiceQueueGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -1053,17 +923,14 @@ class ServiceQueueGETResponse(BaseModel): description="The return of Services found in the database mapping their Ids to the Service spec." ) - class Config(RESTConfig): - pass - register_model("service_queue", "GET", ServiceQueueGETBody, ServiceQueueGETResponse) auto_gen_docs_on_demand(ServiceQueueGETBody) auto_gen_docs_on_demand(ServiceQueueGETResponse) -class ServiceQueuePOSTBody(BaseModel): - class Meta(BaseModel): +class ServiceQueuePOSTBody(ProtoModel): + class Meta(ProtoModel): tag: Optional[str] = Schema( None, description="Tag to assign to the Tasks this Service will generate so that Queue Managers can pull only " @@ -1075,9 +942,6 @@ class Meta(BaseModel): description="Priority given to this Tasks created by this Service. Higher priority will be pulled first." ) - class Config(RESTConfig): - pass - meta: Meta = Schema( ..., description="Metadata information for the Service for the Tag and Priority of Tasks this Service will create." @@ -1087,11 +951,8 @@ class Config(RESTConfig): description="A list the specification for Procedures this Service will manage and generate Tasks for." ) - class Config(RESTConfig): - pass - -class ServiceQueuePOSTResponse(BaseModel): +class ServiceQueuePOSTResponse(ProtoModel): meta: ResponsePOSTMeta = Schema( ..., @@ -1102,9 +963,6 @@ class ServiceQueuePOSTResponse(BaseModel): description="Data returned from the server from adding a Service." ) - class Config(RESTConfig): - pass - register_model("service_queue", "POST", ServiceQueuePOSTBody, ServiceQueuePOSTResponse) auto_gen_docs_on_demand(ServiceQueuePOSTBody) @@ -1113,7 +971,7 @@ class Config(RESTConfig): ### Queue Manager -class QueueManagerMeta(BaseModel): +class QueueManagerMeta(ProtoModel): """ Validation and identification Meta information for the Queue Manager's communication with the Fractal Server. """ @@ -1163,17 +1021,14 @@ class QueueManagerMeta(BaseModel): description="Optional queue tag to pull Tasks from." ) - class Config(RESTConfig): - pass - # Add the new QueueManagerMeta to the docs auto_gen_docs_on_demand(QueueManagerMeta) common_docs[QueueManagerMeta] = str(get_base_docs(QueueManagerMeta)) -class QueueManagerGETBody(BaseModel): - class Data(BaseModel): +class QueueManagerGETBody(ProtoModel): + class Data(ProtoModel): limit: int = Schema( ..., description="Max number of Queue Managers to get from the server." @@ -1189,11 +1044,8 @@ class Data(BaseModel): "number of tasks to pull." ) - class Config(RESTConfig): - pass - -class QueueManagerGETResponse(BaseModel): +class QueueManagerGETResponse(ProtoModel): meta: ResponseGETMeta = Schema( ..., description=common_docs[ResponseGETMeta] @@ -1203,13 +1055,12 @@ class QueueManagerGETResponse(BaseModel): description="A list of tasks retrieved from the server to compute." ) - register_model("queue_manager", "GET", QueueManagerGETBody, QueueManagerGETResponse) auto_gen_docs_on_demand(QueueManagerGETBody) auto_gen_docs_on_demand(QueueManagerGETResponse) -class QueueManagerPOSTBody(BaseModel): +class QueueManagerPOSTBody(ProtoModel): meta: QueueManagerMeta = Schema( ..., description=common_docs[QueueManagerMeta] @@ -1219,11 +1070,8 @@ class QueueManagerPOSTBody(BaseModel): description="A Dictionary of tasks to return to the server." ) - class Config: - json_encoders = json_encoders - -class QueueManagerPOSTResponse(BaseModel): +class QueueManagerPOSTResponse(ProtoModel): meta: ResponsePOSTMeta = Schema( ..., description=common_docs[ResponsePOSTMeta] @@ -1239,8 +1087,8 @@ class QueueManagerPOSTResponse(BaseModel): auto_gen_docs_on_demand(QueueManagerPOSTResponse) -class QueueManagerPUTBody(BaseModel): - class Data(BaseModel): +class QueueManagerPUTBody(ProtoModel): + class Data(ProtoModel): operation: str meta: QueueManagerMeta = Schema( @@ -1254,7 +1102,7 @@ class Data(BaseModel): ) -class QueueManagerPUTResponse(BaseModel): +class QueueManagerPUTResponse(ProtoModel): meta: Dict[str, Any] = Schema( {}, description=common_docs[EmptyMeta] diff --git a/qcfractal/interface/models/task_models.py b/qcfractal/interface/models/task_models.py index 6a63399c6..416d3ad15 100644 --- a/qcfractal/interface/models/task_models.py +++ b/qcfractal/interface/models/task_models.py @@ -1,15 +1,15 @@ import datetime -import json from enum import Enum from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, validator, Schema +from pydantic import validator, Schema + from qcelemental.models import ComputeError -from .common_models import ObjectId +from .common_models import ObjectId, ProtoModel -class DBRef(BaseModel): +class DBRef(ProtoModel): """ Database locator reference object. Identifies an exact record in a database. """ @@ -55,7 +55,7 @@ class BaseResultEnum(str, Enum): procedure = "procedure" -class PythonComputeSpec(BaseModel): +class PythonComputeSpec(ProtoModel): function: str = Schema( ..., description="The module and function name of a Python-callable to call. Of the form 'module.function'." @@ -70,7 +70,7 @@ class PythonComputeSpec(BaseModel): ) -class TaskRecord(BaseModel): +class TaskRecord(ProtoModel): id: ObjectId = Schema( None, @@ -143,9 +143,6 @@ def __init__(self, **data): super().__init__(**data) - class Config: - extra = "forbid" - @validator('priority', pre=True) def munge_priority(cls, v): if isinstance(v, str): @@ -161,6 +158,3 @@ def check_program(cls, v): @validator('procedure') def check_procedure(cls, v): return v.lower() - - def json_dict(self, *args, **kwargs): - return json.loads(self.json(*args, **kwargs)) \ No newline at end of file diff --git a/qcfractal/interface/models/tests/test_hashes.py b/qcfractal/interface/models/tests/test_hashes.py index c5a687b47..2e6f38e1e 100644 --- a/qcfractal/interface/models/tests/test_hashes.py +++ b/qcfractal/interface/models/tests/test_hashes.py @@ -4,7 +4,7 @@ from ..common_models import KeywordSet, Molecule from ..gridoptimization import GridOptimizationRecord -from ..records import ResultRecord, OptimizationRecord +from ..records import OptimizationRecord, ResultRecord from ..torsiondrive import TorsionDriveRecord ## Molecule hashes @@ -337,23 +337,23 @@ def test_gridoptimization_canary_hash(data, hash_index): "keywords": { "dihedrals": [[0, 1, 2, 3]], "grid_spacing": [10], - "tol": 1.e-12 + "energy_upper_limit": 1.e-12 } - }, "cb3f9c9bd4eda742b0429ebea0c3d12719ab2582"), + }, "37b65cba19ec4fbd0d54c10fd74d0a27f627cdea"), ({ "keywords": { "dihedrals": [[0, 1, 2, 3]], "grid_spacing": [10], - "tol": 0 + "energy_upper_limit": 0 } - }, "cb3f9c9bd4eda742b0429ebea0c3d12719ab2582"), + }, "37b65cba19ec4fbd0d54c10fd74d0a27f627cdea"), ({ "keywords": { "dihedrals": [[0, 1, 2, 3]], "grid_spacing": [10], - "tol": 1.e-9 + "energy_upper_limit": 1.e-9 } - }, "903cc0deb4f0e7b8bc41a69cf5fbd0c9420176a4"), + }, "64b400229d3e5bff476e47c093c1a159c69d9fdc"), # Check opt keywords stability ({ diff --git a/qcfractal/interface/models/tests/test_model_utils.py b/qcfractal/interface/models/tests/test_model_utils.py index 534262035..dc7189814 100644 --- a/qcfractal/interface/models/tests/test_model_utils.py +++ b/qcfractal/interface/models/tests/test_model_utils.py @@ -1,7 +1,8 @@ -import pytest import numpy as np +import pytest + +from ..model_utils import hash_dictionary, recursive_normalizer -from ..model_utils import recursive_normalizer, hash_dictionary @pytest.mark.parametrize("unormalized, normalized", [ (5.0 + 1.e-12, 5.0), diff --git a/qcfractal/interface/models/torsiondrive.py b/qcfractal/interface/models/torsiondrive.py index 60c09245e..c845d0260 100644 --- a/qcfractal/interface/models/torsiondrive.py +++ b/qcfractal/interface/models/torsiondrive.py @@ -7,18 +7,18 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from pydantic import BaseModel, constr, validator, Schema +from pydantic import constr, validator, Schema from qcelemental import constants -from .common_models import Molecule, ObjectId, OptimizationSpecification, QCSpecification -from .model_utils import json_encoders, recursive_normalizer -from .records import RecordBase from ..visualization import scatter_plot +from .common_models import Molecule, ObjectId, OptimizationSpecification, QCSpecification, ProtoModel +from .model_utils import recursive_normalizer +from .records import RecordBase __all__ = ["TorsionDriveInput", "TorsionDriveRecord"] -class TDKeywords(BaseModel): +class TDKeywords(ProtoModel): """ TorsionDriveRecord options """ @@ -52,16 +52,12 @@ class TDKeywords(BaseModel): def __init__(self, **kwargs): super().__init__(**recursive_normalizer(kwargs)) - class Config: - extra = "allow" - allow_mutation = False - _td_constr = constr(strip_whitespace=True, regex="torsiondrive") _qcfractal_constr = constr(strip_whitespace=True, regex="qcfractal") -class TorsionDriveInput(BaseModel): +class TorsionDriveInput(ProtoModel): """ A TorsionDriveRecord Input base class """ @@ -100,11 +96,6 @@ def check_initial_molecules(cls, v): v = [v] return v - class Config: - extras = "forbid" - allow_mutation = False - json_encoders = json_encoders - class TorsionDriveRecord(RecordBase): """ @@ -185,7 +176,7 @@ def _deserialize_key(self, key: str) -> Tuple[int, ...]: return tuple(json.loads(key)) def _organize_return(self, data: Dict[str, Any], key: Union[int, str, None], - minimum: bool=False) -> Dict[str, Any]: + minimum: bool = False) -> Dict[str, Any]: if key is None: return {self._deserialize_key(k): copy.deepcopy(v) for k, v in data.items()} @@ -198,9 +189,11 @@ def _organize_return(self, data: Dict[str, Any], key: Union[int, str, None], else: return copy.deepcopy(data[key]) + ## Query - def get_history(self, key: Union[int, Tuple[int, ...], str] = None, minimum: bool = False) -> Dict[str, List['ResultRecord']]: + def get_history(self, key: Union[int, Tuple[int, ...], str] = None, + minimum: bool = False) -> Dict[str, List['ResultRecord']]: """Queries the server for all optimization trajectories. Parameters diff --git a/qcfractal/interface/statistics.py b/qcfractal/interface/statistics.py index 374a0afe0..747a2d0bd 100644 --- a/qcfractal/interface/statistics.py +++ b/qcfractal/interface/statistics.py @@ -19,6 +19,9 @@ def mean_signed_error(value, bench, **kwargs): def mean_unsigned_error(value, bench, **kwargs): + print() + print(value) + print(bench) return np.mean(np.abs(value - bench)) diff --git a/qcfractal/interface/tests/test_dataset.py b/qcfractal/interface/tests/test_dataset.py index f6fdfc4ec..2e0f8bc2f 100644 --- a/qcfractal/interface/tests/test_dataset.py +++ b/qcfractal/interface/tests/test_dataset.py @@ -87,7 +87,7 @@ def water_ds(): ds.add_ie_rxn("Water dimer", dimer.to_string("psi4")) # Add unverified records (requires a active server) - ds.data.records = ds._new_records + ds.data.__dict__["records"] = ds._new_records return ds @@ -189,7 +189,7 @@ def nbody_ds(): } # Add unverified records (requires a active server) - ds.data.records = ds._new_records + ds.data.__dict__["records"] = ds._new_records return ds diff --git a/qcfractal/interface/tests/test_molecule.py b/qcfractal/interface/tests/test_molecule.py index fec77acae..6431c53ad 100644 --- a/qcfractal/interface/tests/test_molecule.py +++ b/qcfractal/interface/tests/test_molecule.py @@ -2,57 +2,50 @@ Tests the imports and exports of the Molecule object. """ +import json + import numpy as np import pytest -from . import portal +import qcelemental as qcel + +from . import portal as ptl def test_molecule_constructors(): ### Water Dimer - water_psi = portal.data.get_molecule("water_dimer_minima.psimol") + water_psi = ptl.data.get_molecule("water_dimer_minima.psimol") ele = np.array([8, 1, 1, 8, 1, 1]).reshape(-1, 1) - npwater = np.hstack((ele, water_psi.geometry)) - water_from_np = portal.Molecule.from_data(npwater, name="water dimer", dtype="numpy", frags=[3]) + npwater = np.hstack((ele, water_psi.geometry * qcel.constants.conversion_factor("Bohr", "angstrom"))) + water_from_np = ptl.Molecule.from_data(npwater, name="water dimer", dtype="numpy", frags=[3]) - assert water_psi.compare(water_psi, water_from_np) + assert water_psi.compare(water_from_np) assert water_psi.get_molecular_formula() == "H4O2" # Check the JSON construct/deconstruct - water_from_json = portal.Molecule(**water_psi.json_dict()) - assert water_psi.compare(water_psi, water_from_json) + water_from_json = ptl.Molecule(**water_psi.dict()) + assert water_psi.compare(water_from_json) ### Neon Tetramer - neon_from_psi = portal.data.get_molecule("neon_tetramer.psimol") + neon_from_psi = ptl.data.get_molecule("neon_tetramer.psimol") ele = np.array([10, 10, 10, 10]).reshape(-1, 1) npneon = np.hstack((ele, neon_from_psi.geometry)) - neon_from_np = portal.Molecule.from_data( + neon_from_np = ptl.Molecule.from_data( npneon, name="neon tetramer", dtype="numpy", frags=[1, 2, 3], units="bohr") - assert neon_from_psi.compare(neon_from_psi, neon_from_np) + assert neon_from_psi.compare(neon_from_np) # Check the JSON construct/deconstruct - neon_from_json = portal.Molecule(**neon_from_psi.json_dict()) - assert neon_from_psi.compare(neon_from_psi, neon_from_json) + neon_from_json = ptl.Molecule(**neon_from_psi.dict()) + assert neon_from_psi.compare(neon_from_json) assert neon_from_json.get_molecular_formula() == "Ne4" - assert water_psi.compare(portal.Molecule.from_data(water_psi.to_string("psi4"))) - - -def test_molecule_file_constructors(): - - mol_psi = portal.data.get_molecule("helium_dimer.psimol") - mol_json = portal.data.get_molecule("helium_dimer.json") - mol_np = portal.data.get_molecule("helium_dimer.npy") - - assert mol_psi.compare(mol_json) - assert mol_psi.compare(mol_np) - assert mol_psi.get_molecular_formula() == "He2" + assert water_psi.compare(ptl.Molecule.from_data(water_psi.to_string("psi4"))) def test_water_minima_data(): - mol = portal.data.get_molecule("water_dimer_minima.psimol") + mol = ptl.data.get_molecule("water_dimer_minima.psimol") assert sum(x == y for x, y in zip(mol.symbols, ['O', 'H', 'H', 'O', 'H', 'H'])) == mol.geometry.shape[0] assert mol.molecular_charge == 0 @@ -75,7 +68,7 @@ def test_water_minima_data(): def test_water_minima_fragment(): - mol = portal.data.get_molecule("water_dimer_minima.psimol") + mol = ptl.data.get_molecule("water_dimer_minima.psimol") frag_0 = mol.get_fragment(0, orient=True) frag_1 = mol.get_fragment(1, orient=True) @@ -85,32 +78,32 @@ def test_water_minima_fragment(): frag_0_1 = mol.get_fragment(0, 1) frag_1_0 = mol.get_fragment(1, 0) - assert mol.symbols[:3] == frag_0.symbols + assert np.array_equal(mol.symbols[:3], frag_0.symbols) assert np.allclose(mol.masses[:3], frag_0.masses) - assert mol.symbols == frag_0_1.symbols + assert np.array_equal(mol.symbols, frag_0_1.symbols) assert np.allclose(mol.geometry, frag_0_1.geometry) - assert mol.symbols[3:] + mol.symbols[:3] == frag_1_0.symbols - assert np.allclose(mol.masses[3:] + mol.masses[:3], frag_1_0.masses) + assert np.array_equal(np.hstack((mol.symbols[3:], mol.symbols[:3])), frag_1_0.symbols) + assert np.allclose(np.hstack((mol.masses[3:], mol.masses[:3])), frag_1_0.masses) def test_pretty_print(): - mol = portal.data.get_molecule("water_dimer_minima.psimol") + mol = ptl.data.get_molecule("water_dimer_minima.psimol") assert isinstance(mol.pretty_print(), str) def test_to_string(): - mol = portal.data.get_molecule("water_dimer_minima.psimol") + mol = ptl.data.get_molecule("water_dimer_minima.psimol") assert isinstance(mol.to_string("psi4"), str) def test_water_orient(): # These are identical molecules, should find the correct results - mol = portal.data.get_molecule("water_dimer_stretch.psimol") + mol = ptl.data.get_molecule("water_dimer_stretch.psimol") frag_0 = mol.get_fragment(0, orient=True) frag_1 = mol.get_fragment(1, orient=True) @@ -123,7 +116,7 @@ def test_water_orient(): assert frag_0_1.get_hash() == frag_1_0.get_hash() - mol = portal.data.get_molecule("water_dimer_stretch2.psimol") + mol = ptl.data.get_molecule("water_dimer_stretch2.psimol") frag_0 = mol.get_fragment(0, orient=True) frag_1 = mol.get_fragment(1, orient=True) @@ -141,17 +134,17 @@ def test_water_orient(): def test_molecule_errors(): - mol = portal.data.get_molecule("water_dimer_stretch.psimol") + mol = ptl.data.get_molecule("water_dimer_stretch.psimol") - data = mol.json_dict() + data = mol.dict() data["whatever"] = 5 with pytest.raises(ValueError): - portal.Molecule(**data) + ptl.Molecule(**data) def test_molecule_repeated_hashing(): - mol = portal.Molecule( + mol = ptl.Molecule( **{ 'symbols': ['H', 'O', 'O', 'H'], 'geometry': [ @@ -165,8 +158,8 @@ def test_molecule_repeated_hashing(): h1 = mol.get_hash() assert mol.get_molecular_formula() == "H2O2" - mol2 = portal.Molecule(**mol.json_dict(), orient=False) + mol2 = ptl.Molecule(**json.loads(mol.json()), orient=False) assert h1 == mol2.get_hash() - mol3 = portal.Molecule(**mol2.json_dict(), orient=False) + mol3 = ptl.Molecule(**json.loads(mol2.json()), orient=False) assert h1 == mol3.get_hash() diff --git a/qcfractal/interface/tests/test_visualization.py b/qcfractal/interface/tests/test_visualization.py index 42c0371e4..f564ae034 100644 --- a/qcfractal/interface/tests/test_visualization.py +++ b/qcfractal/interface/tests/test_visualization.py @@ -8,7 +8,6 @@ from . import portal try: - import plotly _has_ploty = True except ModuleNotFoundError: _has_ploty = False diff --git a/qcfractal/interface/util.py b/qcfractal/interface/util.py index 4e854a9d1..f93d153d3 100644 --- a/qcfractal/interface/util.py +++ b/qcfractal/interface/util.py @@ -17,7 +17,6 @@ class AutoDocError(ValueError): Traps this very specific error and not other ValueErrors """ - pass def type_to_string(input_type): diff --git a/qcfractal/postgres_harness.py b/qcfractal/postgres_harness.py index 2b962c4f4..c2b448c13 100644 --- a/qcfractal/postgres_harness.py +++ b/qcfractal/postgres_harness.py @@ -206,15 +206,10 @@ def upgrade(self): The database data won't be deleted. """ - cmd = [shutil.which('alembic'), - '-c', self._alembic_ini, - '-x', 'uri='+self.config.database_uri(), - 'upgrade', 'head'] - - ret = self._run(cmd) + ret = self._run(self.alembic_commands() + ['upgrade', 'head']) if ret['retcode'] != 0: - self.logger(ret) + self.logger(ret["stderr"]) raise ValueError(f"\nFailed to Upgrade the database, make sure to init the database first before being able to upgrade it.\n") return True @@ -284,7 +279,7 @@ def shutdown(self) -> Any: ret = self.pg_ctl(["stop"]) return ret - def initialize_postgres(self): + def initialize_postgres(self) -> None: """Initializes and starts the current postgres instance. """ @@ -321,7 +316,12 @@ def initialize_postgres(self): self.logger("\nDatabase server successfully started!") - def init_database(self): + def alembic_commands(self) -> List[str]: + return [shutil.which('alembic'), + '-c', self._alembic_ini, + '-x', 'uri='+self.config.database_uri()] + + def init_database(self) -> None: # TODO: drop tables @@ -331,10 +331,7 @@ def init_database(self): # update alembic_version table with the current version self.logger(f'\nStamping Database with current version..') - ret = self._run([shutil.which('alembic'), - '-c', self._alembic_ini, - '-x', 'uri='+self.config.database_uri(), - 'stamp', 'head']) + ret = self._run(self.alembic_commands() + ['stamp', 'head']) if ret['retcode'] != 0: self.logger(ret) diff --git a/qcfractal/procedures/procedures.py b/qcfractal/procedures/procedures.py index 9bc01d432..fe0201cd4 100644 --- a/qcfractal/procedures/procedures.py +++ b/qcfractal/procedures/procedures.py @@ -131,7 +131,7 @@ def parse_input(self, data): task = TaskRecord(**{ "spec": { "function": "qcengine.compute", # todo: add defaults in models - "args": [inp.json_dict(), data.meta.program], # todo: json_dict should come from results + "args": [inp.dict(), data.meta.program], "kwargs": {} # todo: add defaults in models }, "parser": "single", @@ -164,7 +164,7 @@ def parse_output(self, result_outputs): rdata["stderr"] = stderr rdata["error"] = error - result.consume_output(rdata) + result._consume_output(rdata) updates.append(result) completed_tasks.append(data["task_id"]) @@ -292,7 +292,7 @@ def parse_input(self, data, duplicate_id="hash_index"): task = TaskRecord(**{ "spec": { "function": "qcengine.compute_procedure", - "args": [inp.json_dict(), data.meta.program], + "args": [inp.dict(), data.meta.program], "kwargs": {} }, "parser": "optimization", diff --git a/qcfractal/queue/handlers.py b/qcfractal/queue/handlers.py index 861761251..c4e6a7117 100644 --- a/qcfractal/queue/handlers.py +++ b/qcfractal/queue/handlers.py @@ -42,7 +42,7 @@ def post(self): response = response_model(**payload) self.logger.info("POST: TaskQueue - Added {} tasks.".format(response.meta.n_inserted)) - self.write(response.json()) + self.write(response) def get(self): """Posts new services to the service queue. @@ -55,7 +55,7 @@ def get(self): response = response_model(**tasks) self.logger.info("GET: TaskQueue - {} pulls.".format(len(response.data))) - self.write(response.json()) + self.write(response) def put(self): """Posts new services to the service queue. @@ -76,7 +76,7 @@ def put(self): response = response_model(data=data, meta={"errors": [], "success": True, "error_description": False}) self.logger.info(f"PUT: TaskQueue - Operation: {body.meta.operation} - {tasks_updated}.") - self.write(response.json()) + self.write(response) class ServiceQueueHandler(APIHandler): @@ -115,7 +115,7 @@ def post(self): response = response_model(**ret) self.logger.info("POST: ServiceQueue - Added {} services.\n".format(response.meta.n_inserted)) - self.write(response.json()) + self.write(response) def get(self): """Gets services from the service queue. @@ -128,7 +128,7 @@ def get(self): response = response_model(**ret) self.logger.info("GET: ServiceQueue - {} pulls.\n".format(len(response.data))) - self.write(response.json()) + self.write(response) class QueueManagerHandler(APIHandler): @@ -236,7 +236,7 @@ def get(self): }, "data": new_tasks }) - self.write(response.json()) + self.write(response) self.logger.info("QueueManager: Served {} tasks.".format(response.meta.n_found)) @@ -267,7 +267,7 @@ def post(self): }, "data": True }) - self.write(response.json()) + self.write(response) self.logger.info("QueueManager: Inserted {} complete tasks.".format(len(body.data))) # Update manager logs @@ -308,6 +308,6 @@ def put(self): raise tornado.web.HTTPError(status_code=400, reason=msg) response = response_model(**{"meta": {}, "data": ret}) - self.write(response.json()) + self.write(response) # Update manager logs diff --git a/qcfractal/queue/managers.py b/qcfractal/queue/managers.py index 7ec29a7f0..04f002163 100644 --- a/qcfractal/queue/managers.py +++ b/qcfractal/queue/managers.py @@ -653,7 +653,7 @@ def test(self, n=1) -> bool: "function": "qcengine.compute", "args": [{ - "molecule": get_molecule("hooh.json").json_dict(), + "molecule": get_molecule("hooh.json").dict(encoding="json"), "driver": "energy", "model": {}, "keywords": {}, diff --git a/qcfractal/services/gridoptimization_service.py b/qcfractal/services/gridoptimization_service.py index 4234942c5..d208ad282 100644 --- a/qcfractal/services/gridoptimization_service.py +++ b/qcfractal/services/gridoptimization_service.py @@ -9,7 +9,7 @@ from .service_util import BaseService, expand_ndimensional_grid from ..extras import get_information -from ..interface.models import GridOptimizationRecord, Molecule, json_encoders +from ..interface.models import GridOptimizationRecord, Molecule __all__ = ["GridOptimizationService"] @@ -21,6 +21,9 @@ class GridOptimizationService(BaseService): program: str = "qcfractal" procedure: str = "gridoptimization" + # Program info + optimization_program: str + # Output output: GridOptimizationRecord @@ -42,9 +45,6 @@ class GridOptimizationService(BaseService): # keyword_template: KeywordSet starting_molecule: Molecule - class Config: - json_encoders = json_encoders - @classmethod def initialize_from_api(cls, storage_socket, logger, service_input, tag=None, priority=None): diff --git a/qcfractal/services/service_util.py b/qcfractal/services/service_util.py index 1b9bc1112..1292d4303 100644 --- a/qcfractal/services/service_util.py +++ b/qcfractal/services/service_util.py @@ -4,12 +4,11 @@ import abc import datetime -import json from typing import Any, Dict, List, Set, Tuple, Optional -from pydantic import BaseModel, validator +from pydantic import validator -from ..interface.models import ObjectId +from ..interface.models import ObjectId, ProtoModel from ..interface.models.rest_models import TaskQueuePOSTBody from ..interface.models.task_models import PriorityEnum from ..procedures import get_procedure_parser @@ -17,7 +16,7 @@ from qcelemental.models import ComputeError -class TaskManager(BaseModel): +class TaskManager(ProtoModel): storage_socket: Any = None logger: Any = None @@ -26,9 +25,9 @@ class TaskManager(BaseModel): tag: Optional[str] = None priority: PriorityEnum = PriorityEnum.HIGH - def dict(self, *args, **kwargs) -> Dict[str, Any]: - kwargs["exclude"] = (kwargs.pop("exclude", None) or set()) | {"storage_socket", "logger"} - return BaseModel.dict(self, *args, **kwargs) + class Config(ProtoModel.Config): + allow_mutation = True + serialize_default_excludes = {"storage_socket", "logger"} def done(self) -> bool: """ @@ -100,7 +99,7 @@ def submit_tasks(self, procedure_type: str, tasks: Dict[str, Any]) -> bool: return True -class BaseService(BaseModel, abc.ABC): +class BaseService(ProtoModel, abc.ABC): # Excluded fields storage_socket: Any @@ -125,12 +124,17 @@ class BaseService(BaseModel, abc.ABC): status: str = "WAITING" error: Optional[ComputeError] = None + tag: Optional[str] = None # Sorting and priority priority: PriorityEnum = PriorityEnum.NORMAL modified_on: datetime.datetime = None created_on: datetime.datetime = None + class Config(ProtoModel.Config): + allow_mutation = True + serialize_default_excludes = {"storage_socket", "logger"} + def __init__(self, **data): dt = datetime.datetime.utcnow() @@ -159,13 +163,6 @@ def initialize_from_api(cls, storage_socket, meta, molecule, tag=None, priority= Initalizes a Service from the API. """ - def dict(self, *args, **kwargs) -> Dict[str, Any]: - kwargs["exclude"] = (kwargs.pop("exclude", None) or set()) | {"storage_socket", "logger"} - return BaseModel.dict(self, *args, **kwargs) - - def json_dict(self, *args, **kwargs) -> str: - return json.loads(self.json(*args, **kwargs)) - @abc.abstractmethod def iterate(self): """ diff --git a/qcfractal/services/torsiondrive_service.py b/qcfractal/services/torsiondrive_service.py index 9cb46b27d..5c4140d3a 100644 --- a/qcfractal/services/torsiondrive_service.py +++ b/qcfractal/services/torsiondrive_service.py @@ -9,7 +9,7 @@ import numpy as np from .service_util import BaseService, TaskManager -from ..interface.models import TorsionDriveRecord, json_encoders +from ..interface.models import TorsionDriveRecord from ..extras import find_module @@ -29,6 +29,9 @@ class TorsionDriveService(BaseService): program: str = "torsiondrive" procedure: str = "torsiondrive" + # Program info + optimization_program: str + # Output output: TorsionDriveRecord = None # added default @@ -45,9 +48,6 @@ class TorsionDriveService(BaseService): optimization_template: str molecule_template: str - class Config: - json_encoders = json_encoders - @classmethod def initialize_from_api(cls, storage_socket, logger, service_input, tag=None, priority=None): _check_td() @@ -70,7 +70,7 @@ def initialize_from_api(cls, storage_socket, logger, service_input, tag=None, pr meta = {"output": output} # Remove identity info from molecule template - molecule_template = copy.deepcopy(service_input.initial_molecule[0].json_dict()) + molecule_template = copy.deepcopy(service_input.initial_molecule[0].dict(encoding="json")) molecule_template.pop("id", None) molecule_template.pop("identifiers", None) meta["molecule_template"] = json.dumps(molecule_template) diff --git a/qcfractal/storage_sockets/sql_models.py b/qcfractal/storage_sockets/sql_models.py index b4d931090..d372006c0 100644 --- a/qcfractal/storage_sockets/sql_models.py +++ b/qcfractal/storage_sockets/sql_models.py @@ -2,6 +2,8 @@ # from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import (Column, Integer, String, DateTime, Boolean, ForeignKey, JSON, Enum, Float, Binary, Table, inspect, Index, UniqueConstraint) +from sqlalchemy.dialects.postgresql import BYTEA +from sqlalchemy.types import TypeDecorator from sqlalchemy.orm import relationship, object_session, column_property from qcfractal.interface.models.records import RecordStatusEnum, DriverEnum from qcfractal.interface.models.task_models import TaskStatusEnum, ManagerStatusEnum, PriorityEnum @@ -12,8 +14,20 @@ # from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.dialects.postgresql import aggregate_order_by +from qcelemental.util import msgpackext_dumps, msgpackext_loads + # Base = declarative_base() +class MsgpackExt(TypeDecorator): + '''Converts JSON-like data to msgpack with full NumPy Array support.''' + + impl = BYTEA + + def process_bind_param(self, value, dialect): + return msgpackext_dumps(value) + + def process_result_value(self, value, dialect): + return msgpackext_loads(value) @as_declarative() class Base: @@ -213,8 +227,8 @@ class MoleculeORM(Base): # Required data schema_name = Column(String) schema_version = Column(Integer, default=2) - symbols = Column(JSON) # Column(ARRAY(String)) - geometry = Column(JSON) # Column(ARRAY(Float)) + symbols = Column(MsgpackExt) + geometry = Column(MsgpackExt) # Molecule data name = Column(String, default="") @@ -224,15 +238,15 @@ class MoleculeORM(Base): molecular_multiplicity = Column(Integer, default=1) # Atom data - masses = Column(JSON) # Column(ARRAY(Float)) - real = Column(JSON) # Column(ARRAY(Boolean)) - atom_labels = Column(JSON) # Column(ARRAY(String)) - atomic_numbers = Column(JSON) # Column(ARRAY(Integer)) - mass_numbers = Column(JSON) # Column(ARRAY(Integer)) + masses = Column(MsgpackExt) + real = Column(MsgpackExt) + atom_labels = Column(MsgpackExt) + atomic_numbers = Column(MsgpackExt) + mass_numbers = Column(MsgpackExt) # Fragment and connection data connectivity = Column(JSON) - fragments = Column(JSON) + fragments = Column(MsgpackExt) fragment_charges = Column(JSON) # Column(ARRAY(Float)) fragment_multiplicities = Column(JSON) # Column(ARRAY(Integer)) @@ -308,7 +322,7 @@ class BaseResultORM(Base): id = Column(Integer, primary_key=True) # ondelete="SET NULL": when manger is deleted, set this field to None manager_name = Column(String, ForeignKey('queue_manager.name', ondelete="SET NULL"), - nullable=True,) + nullable=True) hash_index = Column(String) # TODO procedure = Column(String(100)) # TODO: may remove @@ -316,7 +330,7 @@ class BaseResultORM(Base): version = Column(Integer) # Extra fields - extras = Column(JSON) + extras = Column(MsgpackExt) stdout = Column(Integer, ForeignKey('kv_store.id')) stdout_obj = relationship(KVStoreORM, lazy='noload', @@ -381,7 +395,7 @@ class ResultORM(BaseResultORM): keywords_obj = relationship(KeywordsORM, lazy='select') # output related - return_result = Column(JSON) # one of 3 types + return_result = Column(MsgpackExt) properties = Column(JSON) # TODO: may use JSONB in the future # TODO: Do they still exist? @@ -735,7 +749,7 @@ class TaskQueueORM(Base): id = Column(Integer, primary_key=True) - spec = Column(JSON) + spec = Column(MsgpackExt) # others tag = Column(String, default=None) @@ -802,7 +816,7 @@ class ServiceQueueORM(Base): created_on = Column(DateTime, default=datetime.datetime.utcnow) modified_on = Column(DateTime, default=datetime.datetime.utcnow) - extra = Column(JSON) + extra = Column(MsgpackExt) __table_args__ = ( Index('ix_service_queue_status', "status"), diff --git a/qcfractal/storage_sockets/sqlalchemy_socket.py b/qcfractal/storage_sockets/sqlalchemy_socket.py index 2d08fe47e..66f280a1e 100644 --- a/qcfractal/storage_sockets/sqlalchemy_socket.py +++ b/qcfractal/storage_sockets/sqlalchemy_socket.py @@ -441,7 +441,10 @@ def add_molecules(self, molecules: List[Molecule]): with self.session_scope() as session: for dmol in molecules: - mol_dict = dmol.json_dict(exclude={"id"}) + if dmol.validated is False: + dmol = Molecule(**dmol.dicT(), validate=True) + + mol_dict = dmol.dict(exclude={"id", "validated"}) # TODO: can set them as defaults in the sql_models, not here mol_dict["fix_com"] = True @@ -503,7 +506,7 @@ def get_molecules(self, id=None, molecule_hash=None, molecular_formula=None, lim # ret["meta"]["errors"].extend(errors) - data = [Molecule(**d, validate=False) for d in rdata] + data = [Molecule(**d, validate=False, validated=True) for d in rdata] return {'meta': meta, 'data': data} @@ -559,7 +562,7 @@ def add_keywords(self, keyword_sets: List[KeywordSet]): with self.session_scope() as session: for kw in keyword_sets: - kw_dict = kw.json_dict(exclude={"id"}) + kw_dict = kw.dict(exclude={"id"}) # search by index keywords not by all keys, much faster found = session.query(KeywordsORM).filter_by(hash_index=kw_dict['hash_index']).first() @@ -878,7 +881,7 @@ def add_results(self, record_list: List[ResultRecord]): molecule=result.molecule) if get_count_fast(doc) == 0: - doc = ResultORM(**result.json_dict(exclude={"id"})) + doc = ResultORM(**result.dict(exclude={"id"})) session.add(doc) session.commit() # TODO: faster if done in bulk result_ids.append(str(doc.id)) @@ -922,7 +925,7 @@ def update_results(self, record_list: List[ResultRecord]): result_db = session.query(ResultORM).filter_by(id=result.id).first() - data = result.json_dict(exclude={'id'}) + data = result.dict(exclude={'id'}) for attr, val in data.items(): setattr(result_db, attr, val) @@ -1120,7 +1123,7 @@ def add_procedures(self, record_list: List['BaseRecord']): doc = session.query(procedure_class).filter_by(hash_index=procedure.hash_index) if get_count_fast(doc) == 0: - data = procedure.json_dict(exclude={"id"}) + data = procedure.dict(exclude={"id"}) proc_db = procedure_class(**data) session.add(proc_db) session.commit() @@ -1240,7 +1243,7 @@ def update_procedures(self, records_list: List['BaseRecord']): proc_db = session.query(className).filter_by(id=procedure.id).first() - data = procedure.json_dict(exclude={'id'}) + data = procedure.dict(exclude={'id'}) proc_db.update_relations(**data) for attr, val in data.items(): @@ -1330,8 +1333,9 @@ def add_services(self, service_list: List['BaseService']): service.procedure_id = proc_id if doc.count() == 0: - doc = ServiceQueueORM(**service.json_dict(include=set(ServiceQueueORM.__dict__.keys()))) - doc.extra = service.json_dict(exclude=set(ServiceQueueORM.__dict__.keys())) + doc = ServiceQueueORM(**service.dict(include=set(ServiceQueueORM.__dict__.keys()))) + doc.extra = service.dict(exclude=set(ServiceQueueORM.__dict__.keys())) + doc.priority = doc.priority.value # Must be an integer for sorting session.add(doc) session.commit() # TODO procedure_ids.append(proc_id) @@ -1421,8 +1425,8 @@ def update_services(self, records_list: List["BaseService"]) -> int: doc_db = session.query(ServiceQueueORM).filter_by(id=service.id).first() - data = service.json_dict(include=set(ServiceQueueORM.__dict__.keys())) - data['extra'] = service.json_dict(exclude=set(ServiceQueueORM.__dict__.keys())) + data = service.dict(include=set(ServiceQueueORM.__dict__.keys())) + data['extra'] = service.dict(exclude=set(ServiceQueueORM.__dict__.keys())) data['id'] = int(data['id']) for attr, val in data.items(): @@ -1448,7 +1452,7 @@ def services_completed(self, records_list: List["BaseService"]) -> int: with self.session_scope() as session: procedure = service.output - procedure.id = service.procedure_id + procedure.__dict__["id"] = service.procedure_id self.update_procedures([procedure]) session.query(ServiceQueueORM)\ @@ -1497,11 +1501,12 @@ def queue_submit(self, data: List[TaskRecord]): with self.session_scope() as session: for task_num, record in enumerate(data): try: - task_dict = record.json_dict(exclude={"id"}) + task_dict = record.dict(exclude={"id"}) # # for compatibility with mongoengine # if isinstance(task_dict['base_result'], dict): # task_dict['base_result'] = task_dict['base_result']['id'] task = TaskQueueORM(**task_dict) + task.priority = task.priority.value # Must be an integer for sorting session.add(task) session.commit() results.append(str(task.id)) @@ -1842,7 +1847,8 @@ def _copy_task_to_queue(self, record_list: List[TaskRecord]): doc = session.query(TaskQueueORM).filter_by(base_result_id=task.base_result.id) if get_count_fast(doc) == 0: - doc = TaskQueueORM(**task.json_dict(exclude={"id"})) + doc = TaskQueueORM(**task.dict(exclude={"id"})) + doc.priority = doc.priority.value if isinstance(doc.error, dict): doc.error = json.dumps(doc.error) diff --git a/qcfractal/testing.py b/qcfractal/testing.py index f06fab3d7..90839df7a 100644 --- a/qcfractal/testing.py +++ b/qcfractal/testing.py @@ -334,7 +334,9 @@ def postgres_server(): def reset_server_database(server): """Resets the server database for testing. """ - # server.storage._clear_db(server.storage._project_name) + if "QCFRACTAL_RESET_TESTING_DB" in os.environ: + server.storage._clear_db(server.storage._project_name) + server.storage._delete_DB_data(server.storage._project_name) diff --git a/qcfractal/tests/test_adapaters.py b/qcfractal/tests/test_adapaters.py index 954483f62..84edbeaff 100644 --- a/qcfractal/tests/test_adapaters.py +++ b/qcfractal/tests/test_adapaters.py @@ -13,11 +13,13 @@ @testing.using_rdkit def test_adapter_single(managed_compute_server): client, server, manager = managed_compute_server + reset_server_database(server) + manager.heartbeat() # Re-register with server after clear # Add compute hooh = ptl.data.get_molecule("hooh.json") - ret = client.add_compute("rdkit", "UFF", "", "energy", None, [hooh.json_dict()], tag="other") + ret = client.add_compute("rdkit", "UFF", "", "energy", None, [hooh], tag="other") # Force manager compute and get results manager.await_results() @@ -47,7 +49,7 @@ def test_keyword_args_passing(adapter_client_fixture, cores_per_task, memory_per "function": "qcengine.compute", "args": [{ - "molecule": ptl.data.get_molecule("hooh.json").json_dict(), + "molecule": ptl.data.get_molecule("hooh.json"), "driver": "energy", "model": { "method": "HF", @@ -93,10 +95,12 @@ def test_keyword_args_passing(adapter_client_fixture, cores_per_task, memory_per @testing.using_rdkit def test_adapter_error_message(managed_compute_server): client, server, manager = managed_compute_server + reset_server_database(server) + manager.heartbeat() # Re-register with server after clear # HOOH without connectivity, RDKit should fail - hooh = ptl.data.get_molecule("hooh.json").json_dict() + hooh = ptl.data.get_molecule("hooh.json").dict() del hooh["connectivity"] mol_ret = client.add_molecules([hooh]) @@ -124,10 +128,12 @@ def test_adapter_error_message(managed_compute_server): @testing.using_rdkit def test_adapter_raised_error(managed_compute_server): client, server, manager = managed_compute_server + reset_server_database(server) + manager.heartbeat() # Re-register with server after clear # HOOH without connectivity, RDKit should fail - hooh = ptl.data.get_molecule("hooh.json").json_dict() + hooh = ptl.data.get_molecule("hooh.json") ret = client.add_compute("rdkit", "UFF", "", "hessian", None, hooh) queue_id = ret.submitted[0] diff --git a/qcfractal/tests/test_client.py b/qcfractal/tests/test_client.py index b6dfc1e79..1312871a5 100644 --- a/qcfractal/tests/test_client.py +++ b/qcfractal/tests/test_client.py @@ -3,6 +3,7 @@ """ import pytest +import numpy as np import qcfractal.interface as ptl from qcfractal.testing import test_server @@ -10,12 +11,16 @@ # All tests should import test_server, but not use it # Make PyTest aware that this module needs the server +valid_encodings = ["json", "json-ext", "msgpack-ext"] -def test_client_molecule(test_server): +@pytest.mark.parametrize("encoding", valid_encodings) +def test_client_molecule(test_server, encoding): client = ptl.FractalClient(test_server) + client._set_encoding(encoding) water = ptl.data.get_molecule("water_dimer_minima.psimol") + water.geometry[:] += np.random.random(water.geometry.shape) # Test add ret = client.add_molecules([water]) @@ -26,14 +31,16 @@ def test_client_molecule(test_server): # Test molecular_formula get get_mol = client.query_molecules(molecular_formula="H4O2") - assert water.compare(get_mol[0]) + assert len(get_mol) -def test_client_keywords(test_server): +@pytest.mark.parametrize("encoding", valid_encodings) +def test_client_keywords(test_server, encoding): client = ptl.FractalClient(test_server) + client._set_encoding(encoding) - opt = ptl.models.KeywordSet(values={"one": "fish", "two": "fish"}) + opt = ptl.models.KeywordSet(values={"one": "fish", "two": encoding}) # Test add ret = client.add_keywords([opt]) @@ -46,13 +53,16 @@ def test_client_keywords(test_server): assert opt == get_kw[0] -def test_client_duplicate_keywords(test_server): +@pytest.mark.parametrize("encoding", valid_encodings) +def test_client_duplicate_keywords(test_server, encoding): client = ptl.FractalClient(test_server) + client._set_encoding(encoding) - opt1 = ptl.models.KeywordSet(values={"key": 1}) - opt2 = ptl.models.KeywordSet(values={"key": 2}) - opt3 = ptl.models.KeywordSet(values={"key": 3}) + key_name = f"key-{encoding}" + opt1 = ptl.models.KeywordSet(values={key_name: 1}) + opt2 = ptl.models.KeywordSet(values={key_name: 2}) + opt3 = ptl.models.KeywordSet(values={key_name: 3}) # Test add ret = client.add_keywords([opt1, opt1]) @@ -67,9 +77,11 @@ def test_client_duplicate_keywords(test_server): assert len(ret3) == 3 assert ret3[1] == ret[0] -def test_empty_query(test_server): +@pytest.mark.parametrize("encoding", valid_encodings) +def test_empty_query(test_server, encoding): client = ptl.FractalClient(test_server) + client._set_encoding(encoding) with pytest.raises(IOError) as error: client.query_procedures(limit=1) @@ -77,11 +89,14 @@ def test_empty_query(test_server): assert "ID is required" in str(error.value) -def test_collection_portal(test_server): +@pytest.mark.parametrize("encoding", valid_encodings) +def test_collection_portal(test_server, encoding): - db = {"collection": "torsiondrive", "name": "Torsion123", "something": "else", "array": ["54321"]} + db_name = f"Torsion123-{encoding}" + db = {"collection": "torsiondrive", "name": db_name, "something": "else", "array": ["12345"]} client = ptl.FractalClient(test_server) + client._set_encoding(encoding) # Test add _ = client.add_collection(db) @@ -106,12 +121,12 @@ def test_collection_portal(test_server): # Test that we cannot use a local key db['id'] = 'local' - db['array'] = ["12345"] + db['array'] = ["6789"] with pytest.raises(KeyError): _ = client.add_collection(db, overwrite=True) # Finally test that we can overwrite db['id'] = db_id - _ = client.add_collection(db, overwrite=True) + r = client.add_collection(db, overwrite=True) get_db = client.get_collection(db["collection"], db["name"], full_return=True) - assert get_db.data[0]['array'] == ["12345"] + assert get_db.data[0]['array'] == ["6789"] diff --git a/qcfractal/tests/test_collections.py b/qcfractal/tests/test_collections.py index 41b07c862..bdf2c86c8 100644 --- a/qcfractal/tests/test_collections.py +++ b/qcfractal/tests/test_collections.py @@ -186,7 +186,7 @@ def test_compute_reactiondataset_regression(fractal_compute_server): "units": "hartree" } ds.add_contributed_values(contrib) - ds.data.default_benchmark = "Benchmark" + ds.set_default_benchmark("Benchmark") # Save the DB and overwrite the result, reacquire via client r = ds.save() @@ -262,125 +262,6 @@ def test_compute_reactiondataset_keywords(fractal_compute_server): assert kw.values["scf_type"] == "df" -@mark_slow -@testing.using_torsiondrive -@testing.using_geometric -@testing.using_rdkit -def test_compute_openffworkflow(fractal_compute_server): - """ - Tests the openffworkflow collection - """ - - # Obtain a client and build a BioFragment - client = ptl.FractalClient(fractal_compute_server) - - openff_workflow_options = { - # Blank Fragmenter options - "enumerate_states": {}, - "enumerate_fragments": {}, - "torsiondrive_input": {}, - - # TorsionDriveRecord, Geometric, and QC options - "torsiondrive_static_options": { - "keywords": {}, - "optimization_spec": { - "program": "geometric", - "keywords": { - "coordsys": "tric", - } - }, - "qc_spec": { - "driver": "gradient", - "method": "UFF", - "basis": "", - "keywords": None, - "program": "rdkit", - } - }, - "optimization_static_options": { - "program": "geometric", - "keywords": { - "coordsys": "tric", - }, - "qc_spec": { - "driver": "gradient", - "method": "UFF", - "basis": "", - "keywords": None, - "program": "rdkit", - } - } - } - wf = ptl.collections.OpenFFWorkflow("Workflow1", client=client, **openff_workflow_options) - - # # Add a fragment and wait for the compute - hooh = ptl.data.get_molecule("hooh.json") - fragment_input = { - "label1": { - "type": "torsiondrive_input", - "initial_molecule": hooh.json_dict(), - "grid_spacing": [90], - "dihedrals": [[0, 1, 2, 3]], - }, - } - wf.add_fragment("HOOH", fragment_input) - assert set(wf.list_fragments()) == {"HOOH"} - fractal_compute_server.await_services(max_iter=5) - - final_energies = wf.list_final_energies() - assert final_energies.keys() == {"HOOH"} - assert final_energies["HOOH"].keys() == {"label1"} - - final_molecules = wf.list_final_molecules() - assert final_molecules.keys() == {"HOOH"} - assert final_molecules["HOOH"].keys() == {"label1"} - - optimization_input = { - "label2": { - "type": "optimization_input", - "initial_molecule": hooh.json_dict(), - "constraints": { - 'set': [{ - "type": 'dihedral', - "indices": [0, 1, 2, 3], - "value": 0 - }] - } - } - } - - wf.add_fragment("HOOH", optimization_input) - fractal_compute_server.await_services(max_iter=5) - - final_energies = wf.list_final_energies() - assert final_energies["HOOH"].keys() == {"label1", "label2"} - assert pytest.approx(0.00259754, 1.e-4) == final_energies["HOOH"]["label2"] - - final_molecules = wf.list_final_molecules() - assert final_molecules.keys() == {"HOOH"} - assert final_molecules["HOOH"].keys() == {"label1", "label2"} - - # Add a second fragment - butane = ptl.data.get_molecule("butane.json") - butane_id = butane.identifiers.canonical_isomeric_explicit_hydrogen_mapped_smiles - - fragment_input = { - "label1": { - "type": "torsiondrive_input", - "initial_molecule": butane.json_dict(), - "grid_spacing": [90], - "dihedrals": [[0, 2, 3, 1]], - }, - } - wf.add_fragment(butane_id, fragment_input) - assert set(wf.list_fragments()) == {butane_id, "HOOH"} - - final_energies = wf.list_final_energies() - assert final_energies.keys() == {butane_id, "HOOH"} - assert final_energies[butane_id].keys() == {"label1"} - assert final_energies[butane_id]["label1"] == {} - - def test_generic_collection(fractal_compute_server): client = ptl.FractalClient(fractal_compute_server) diff --git a/qcfractal/tests/test_compute.py b/qcfractal/tests/test_compute.py index b5dab47af..6f86e5ad9 100644 --- a/qcfractal/tests/test_compute.py +++ b/qcfractal/tests/test_compute.py @@ -114,9 +114,7 @@ def test_queue_error(fractal_compute_server): client = ptl.FractalClient(fractal_compute_server) - hooh = ptl.data.get_molecule("hooh.json").json_dict() - del hooh["connectivity"] - + hooh = ptl.data.get_molecule("hooh.json").copy(update={"connectivity": None}) compute_ret = client.add_compute("rdkit", "UFF", "", "energy", None, hooh) # Pull out a special iteration on the queue manager @@ -148,7 +146,7 @@ def test_queue_duplicate_compute(fractal_compute_server): client = ptl.FractalClient(fractal_compute_server) - hooh = ptl.data.get_molecule("hooh.json").json_dict() + hooh = ptl.data.get_molecule("hooh.json") mol_ret = client.add_molecules([hooh]) ret = client.add_compute("rdkit", "UFF", "", "energy", None, mol_ret) @@ -216,7 +214,7 @@ def test_queue_duplicate_procedure(fractal_compute_server): client = ptl.FractalClient(fractal_compute_server) - hooh = ptl.data.get_molecule("hooh.json").json_dict() + hooh = ptl.data.get_molecule("hooh.json") mol_ret = client.add_molecules([hooh]) geometric_options = { diff --git a/qcfractal/tests/test_managers.py b/qcfractal/tests/test_managers.py index c02ea076a..1e603caba 100644 --- a/qcfractal/tests/test_managers.py +++ b/qcfractal/tests/test_managers.py @@ -52,7 +52,7 @@ def test_queue_manager_single_tags(compute_adapter_fixture): # Add compute hooh = ptl.data.get_molecule("hooh.json") - ret = client.add_compute("rdkit", "UFF", "", "energy", None, [hooh.json_dict()], tag="other") + ret = client.add_compute("rdkit", "UFF", "", "energy", None, [hooh], tag="other") # Computer with the incorrect tag manager_stuff.await_results() @@ -87,7 +87,7 @@ def test_queue_manager_statistics(compute_adapter_fixture, caplog): manager = queue.QueueManager(client, adapter, verbose=True) hooh = ptl.data.get_molecule("hooh.json") - client.add_compute("rdkit", "UFF", "", "energy", None, [hooh.json_dict()], tag="other") + client.add_compute("rdkit", "UFF", "", "energy", None, [hooh], tag="other") # Set capture level with caplog_handler_at_level(caplog, logging.INFO): @@ -118,7 +118,7 @@ def test_queue_manager_shutdown(compute_adapter_fixture): manager = queue.QueueManager(client, adapter) hooh = ptl.data.get_molecule("hooh.json") - client.add_compute("rdkit", "UFF", "", "energy", None, [hooh.json_dict()], tag="other") + client.add_compute("rdkit", "UFF", "", "energy", None, [hooh], tag="other") # Pull job to manager and shutdown manager.update() @@ -145,7 +145,7 @@ def test_queue_manager_server_delay(compute_adapter_fixture): manager = queue.QueueManager(client, adapter, server_error_retries=1) hooh = ptl.data.get_molecule("hooh.json") - client.add_compute("rdkit", "UFF", "", "energy", None, [hooh.json_dict()], tag="other") + client.add_compute("rdkit", "UFF", "", "energy", None, [hooh], tag="other") # Pull job to manager and shutdown manager.update() diff --git a/qcfractal/tests/test_procedures.py b/qcfractal/tests/test_procedures.py index e04e47a94..02e069a5e 100644 --- a/qcfractal/tests/test_procedures.py +++ b/qcfractal/tests/test_procedures.py @@ -4,6 +4,7 @@ import pytest import requests +import numpy as np import qcfractal.interface as ptl from qcfractal import testing @@ -36,7 +37,7 @@ def test_compute_queue_stack(fractal_compute_server): "keywords": kw_id, "program": "psi4", }, - "data": [hydrogen_mol_id, helium.json_dict()], + "data": [hydrogen_mol_id, helium], } # Ask the server to compute a new computation @@ -119,7 +120,7 @@ def test_procedure_optimization(fractal_compute_server): traj = opt_result.get_trajectory() assert len(traj) == len(opt_result.energies) - assert opt_result.get_final_molecule().symbols == ["H", "H"] + assert np.array_equal(opt_result.get_final_molecule().symbols, ["H", "H"]) # Check individual elements for ind in range(len(opt_result.trajectory)): diff --git a/qcfractal/tests/test_server.py b/qcfractal/tests/test_server.py index 060e739bb..8bcbcacaf 100644 --- a/qcfractal/tests/test_server.py +++ b/qcfractal/tests/test_server.py @@ -3,6 +3,7 @@ """ import os +import json import threading import pytest @@ -29,8 +30,9 @@ def test_molecule_socket(test_server): mol_api_addr = test_server.get_address("molecule") water = ptl.data.get_molecule("water_dimer_minima.psimol") + water_json = json.loads(water.json()) # Add a molecule - r = requests.post(mol_api_addr, json={"meta": {}, "data": [water.json_dict()]}) + r = requests.post(mol_api_addr, json={"meta": {}, "data": [water_json]}) assert r.status_code == 200 pdata = r.json() diff --git a/qcfractal/tests/test_services.py b/qcfractal/tests/test_services.py index 4b3a257f9..30fac3419 100644 --- a/qcfractal/tests/test_services.py +++ b/qcfractal/tests/test_services.py @@ -111,7 +111,7 @@ def test_service_torsiondrive_duplicates(torsiondrive_fixture): # Augment the input for torsion drive to yield a new hash procedure hash, # but not a new task set - id2 = spin_up_test(keywords={"meaningless_entry_to_change_hash": "Waffles!"}).ids[0] + id2 = spin_up_test(keywords={"energy_upper_limit": 1000}).ids[0] assert id1 != id2 procedures = client.query_procedures(id=[id1, id2]) diff --git a/qcfractal/tests/test_sqlalchemy.py b/qcfractal/tests/test_sqlalchemy.py index 7265a7e18..dda8d384d 100644 --- a/qcfractal/tests/test_sqlalchemy.py +++ b/qcfractal/tests/test_sqlalchemy.py @@ -107,7 +107,7 @@ def test_molecule_sql(storage_socket, session): assert water_mol.molecular_formula == "H4O2" assert water_mol.molecular_charge == 0 - # print(water_mol.json_dict()) + # print(water_mol.dict()) # # Query with fields in the model result_list = session.query(MoleculeORM).filter_by(molecular_formula="H4O2").all() @@ -148,6 +148,8 @@ def test_services(storage_socket, session): "hash_index" : "123", "status": "COMPLETE", + "optimization_program": "gaussian", + # extra fields "torsiondrive_state": {}, @@ -167,9 +169,10 @@ def test_services(storage_socket, session): service_pydantic = TorsionDriveService(**service_data) - doc = ServiceQueueORM(**service_pydantic.json_dict(include=set(ServiceQueueORM.__dict__.keys()))) - doc.extra = service_pydantic.json_dict(exclude=set(ServiceQueueORM.__dict__.keys())) + doc = ServiceQueueORM(**service_pydantic.dict(include=set(ServiceQueueORM.__dict__.keys()))) + doc.extra = service_pydantic.dict(exclude=set(ServiceQueueORM.__dict__.keys())) doc.procedure_id = procedure.id + doc.priority = doc.priority.value # Special case where we need the value not the enum session.add(doc) session.commit() diff --git a/qcfractal/tests/test_storage.py b/qcfractal/tests/test_storage.py index 0bec04ce8..17825baf7 100644 --- a/qcfractal/tests/test_storage.py +++ b/qcfractal/tests/test_storage.py @@ -895,6 +895,8 @@ def test_services_sql(storage_results): "hash_index" : "123", "status": TaskStatusEnum.waiting, + "optimization_program": "gaussian", + # extra fields "torsiondrive_state": {}, diff --git a/qcfractal/web_handlers.py b/qcfractal/web_handlers.py index 30495a5ce..b8b277022 100644 --- a/qcfractal/web_handlers.py +++ b/qcfractal/web_handlers.py @@ -6,9 +6,15 @@ import tornado.web from pydantic import ValidationError +from qcelemental.util import serialize, deserialize from .interface.models.rest_models import rest_model +_valid_encodings = { + "application/json": "json", + "application/json-ext": "json-ext", + "application/msgpack-ext": "msgpack-ext", +} class APIHandler(tornado.web.RequestHandler): """ @@ -24,7 +30,17 @@ def initialize(self, **objects): Initializes the request to JSON, adds objects, and logging. """ - self.set_header("Content-Type", "application/json") + + self.content_type = "Not Provided" + try: + self.content_type = self.request.headers["Content-Type"] + self.encoding = _valid_encodings[self.content_type] + except KeyError: + raise tornado.web.HTTPError(status_code=401, reason=f"Did not understand 'Content-Type': {self.content_type}") + + # Always reply in the format sent + self.set_header("Content-Type", self.content_type) + self.objects = objects self.storage = self.objects["storage_socket"] self.logger = objects["logger"] @@ -35,7 +51,10 @@ def prepare(self): if self._required_auth: self.authenticate(self._required_auth) - self.json = json.loads(self.request.body.decode("UTF-8")) + try: + self.data = deserialize(self.request.body, self.encoding) + except: + raise tornado.web.HTTPError(status_code=401, reason="Could not deserialize body.") def on_finish(self): @@ -43,7 +62,7 @@ def on_finish(self): if self.api_logger and self.request.method == 'GET' \ and self.request.uri not in exclude_uris: - extra_params = self.json.copy() + extra_params = self.data.copy() if self._logging_param_counts: for key in self._logging_param_counts: if extra_params["data"].get(key, None): @@ -86,10 +105,16 @@ def authenticate(self, permission): def parse_bodymodel(self, model): try: - return model.parse_raw(self.request.body) + return model(**self.data) except ValidationError as exc: raise tornado.web.HTTPError(status_code=401, reason="Invalid REST") + def write(self, data): + if not isinstance(data, (str, bytes)): + data = serialize(data, self.encoding) + + return super().write(data) + class InformationHandler(APIHandler): """ @@ -142,7 +167,7 @@ def get(self): ret = response_model(**ret) self.logger.info("GET: KVStore - {} pulls.".format(len(ret.data))) - self.write(ret.json()) + self.write(ret) class MoleculeHandler(APIHandler): @@ -181,7 +206,7 @@ def get(self): ret = response_model(**molecules) self.logger.info("GET: Molecule - {} pulls.".format(len(ret.data))) - self.write(ret.json()) + self.write(ret) def post(self): """ @@ -211,7 +236,7 @@ def post(self): response = response_model(**ret) self.logger.info("POST: Molecule - {} inserted.".format(response.meta.n_inserted)) - self.write(response.json()) + self.write(response) class KeywordHandler(APIHandler): @@ -231,7 +256,7 @@ def get(self): response = response_model(**ret) self.logger.info("GET: Keywords - {} pulls.".format(len(response.data))) - self.write(response.json()) + self.write(response) def post(self): self.authenticate("write") @@ -243,7 +268,7 @@ def post(self): response = response_model(**ret) self.logger.info("POST: Keywords - {} inserted.".format(response.meta.n_inserted)) - self.write(response.json()) + self.write(response) class CollectionHandler(APIHandler): @@ -262,7 +287,7 @@ def get(self): response = response_model(**cols) self.logger.info("GET: Collections - {} pulls.".format(len(response.data))) - self.write(response.json()) + self.write(response) def post(self): self.authenticate("write") @@ -274,7 +299,7 @@ def post(self): response = response_model(**ret) self.logger.info("POST: Collections - {} inserted.".format(response.meta.n_inserted)) - self.write(response.json()) + self.write(response) class ResultHandler(APIHandler): @@ -294,7 +319,7 @@ def get(self): result = response_model(**ret) self.logger.info("GET: Results - {} pulls.".format(len(result.data))) - self.write(result.json()) + self.write(result) class ProcedureHandler(APIHandler): @@ -318,4 +343,4 @@ def get(self): response = response_model(**ret) self.logger.info("GET: Procedures - {} pulls.".format(len(response.data))) - self.write(response.json()) + self.write(response) diff --git a/setup.cfg b/setup.cfg index 470be4fc4..541111a0c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,6 +7,7 @@ omit = */tests/* qcfractal/_version.py qcfractal/dashboard/* + qcfractal/alembic/* [isort] line_length=120 diff --git a/setup.py b/setup.py index be3aa20e1..dfab3e083 100644 --- a/setup.py +++ b/setup.py @@ -26,11 +26,14 @@ 'bcrypt', 'cryptography', 'numpy>=1.7', + 'msgpack>=0.6.1', 'pandas', 'pydantic>=0.30.1', + 'msgpack>=0.6.1', + 'pyyaml>=5.1', 'requests', 'tornado', - 'pyyaml>=5.1', + 'tqdm', # Database 'sqlalchemy>=1.3', @@ -38,8 +41,8 @@ 'alembic', # QCArchive depends - 'qcengine>=0.8.2', - 'qcelemental>=0.5.0', + 'qcengine>=0.9.0', + 'qcelemental>=0.6.0', # Testing 'pytest',