diff --git a/mipdb/database.py b/mipdb/database.py
index f1a2ac5..341390c 100644
--- a/mipdb/database.py
+++ b/mipdb/database.py
@@ -433,6 +433,13 @@ def get_row_count(self, table):
         res = self.execute(f"select COUNT(*) from {table}").fetchone()
         return res[0]
 
+    def get_column_distinct(self, column, table):
+        datasets = list(self.execute(
+            f"SELECT DISTINCT({column}) FROM {table};"
+        ))
+        datasets = [dataset[0] for dataset in datasets]
+        return datasets
+
     def get_dataset(self, dataset_id, columns):
         columns_query = ", ".join(columns) if columns else "*"
 
diff --git a/mipdb/tables.py b/mipdb/tables.py
index a10bc52..ba09bb3 100644
--- a/mipdb/tables.py
+++ b/mipdb/tables.py
@@ -1,14 +1,14 @@
-import os
 from abc import ABC, abstractmethod
 import json
 from enum import Enum
 from typing import Union, List
 
 import sqlalchemy as sql
 from sqlalchemy import ForeignKey, MetaData
 from sqlalchemy.ext.compiler import compiles
 
 from mipdb.database import DataBase, Connection, credentials_from_config
+from mipdb.data_frame import DATASET_COLUMN_NAME
 from mipdb.database import METADATA_SCHEMA
 from mipdb.database import METADATA_TABLE
 from mipdb.dataelements import CommonDataElement
@@ -70,6 +70,9 @@ def delete(self, db: Union[DataBase, Connection]):
     def get_row_count(self, db):
         return db.get_row_count(self.table.fullname)
 
+    def get_column_distinct(self, column, db):
+        return db.get_column_distinct(column, self.table.fullname)
+
     def drop(self, db: Union[DataBase, Connection]):
         db.drop_table(self._table)
 
@@ -375,7 +378,7 @@ def validate_csv(self, csv_path, cdes_with_min_max, cdes_with_enumerations, db):
                 break
 
         validated_datasets = set(validated_datasets) | set(
-            self.get_unique_datasets(db)
+            self.get_column_distinct(DATASET_COLUMN_NAME, db)
         )
         self._validate_enumerations_restriction(cdes_with_enumerations, db)
         self._validate_min_max_restriction(cdes_with_min_max, db)
@@ -432,10 +435,5 @@ def _validate_enumerations_restriction(self, cdes_with_enumerations, db):
                 f"In the column: '{cde}' the following values are invalid: '{cde_invalid_values}'"
             )
 
-    def get_unique_datasets(self, db):
-        return db.execute(
-            f"SELECT DISTINCT(dataset) FROM {self.table.fullname};"
-        ).fetchone()
-
     def set_table(self, table):
         self._table = table
diff --git a/mipdb/usecases.py b/mipdb/usecases.py
index 14461df..cef4d21 100644
--- a/mipdb/usecases.py
+++ b/mipdb/usecases.py
@@ -31,9 +31,8 @@
     TemporaryTable,
     RECORDS_PER_COPY,
 )
-from mipdb.data_frame import DataFrame
+from mipdb.data_frame import DataFrame, DATASET_COLUMN_NAME
 
-DATASET_COLUMN_NAME = "dataset"
 LONGITUDINAL = "longitudinal"
 
 
@@ -350,7 +349,7 @@ def insert_csv_to_db(self, csv_path, temporary_table, data_model, db):
                 break
 
         imported_datasets = set(imported_datasets) | set(
-            temporary_table.get_unique_datasets(db)
+            temporary_table.get_column_distinct(DATASET_COLUMN_NAME, db)
         )
         db.copy_data_table_to_another_table(primary_data_table, temporary_table)
         temporary_table.delete(db)
diff --git a/pyproject.toml b/pyproject.toml
index a3f09e4..26194dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mipdb"
-version = "2.4.4"
+version = "2.4.5"
 description = ""
 authors = ["Your Name "]
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 23b06a5..4b40bfc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,6 +16,7 @@
 FAIL_DATA_FOLDER = DATA_FOLDER + "fail"
 ABSOLUTE_PATH_DATA_FOLDER = f"{os.path.dirname(os.path.realpath(__file__))}/data/"
 ABSOLUTE_PATH_DATASET_FILE = f"{os.path.dirname(os.path.realpath(__file__))}/data/success/data_model_v_1_0/dataset.csv"
+ABSOLUTE_PATH_DATASET_FILE_MULTIPLE_DATASET = f"{os.path.dirname(os.path.realpath(__file__))}/data/success/data_model_v_1_0/dataset123.csv"
 ABSOLUTE_PATH_SUCCESS_DATA_FOLDER = ABSOLUTE_PATH_DATA_FOLDER + "success"
 ABSOLUTE_PATH_FAIL_DATA_FOLDER = ABSOLUTE_PATH_DATA_FOLDER + "fail"
 IP = "127.0.0.1"
diff --git a/tests/data/success/data_model_v_1_0/dataset10.csv b/tests/data/success/data_model_v_1_0/dataset10.csv
index ee4c72e..29f411f 100644
--- a/tests/data/success/data_model_v_1_0/dataset10.csv
+++ b/tests/data/success/data_model_v_1_0/dataset10.csv
@@ -3,4 +3,4 @@ subjectcode,var1,var3,dataset
 2,2,22,dataset10
 2,1,23,dataset10
 5,1,24,dataset10
-5,2,25,dataset2
+5,2,25,dataset10
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 3ade077..cdb2b3b 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -25,7 +25,7 @@
     ABSOLUTE_PATH_SUCCESS_DATA_FOLDER,
     SUCCESS_DATA_FOLDER,
     ABSOLUTE_PATH_FAIL_DATA_FOLDER,
-    DEFAULT_OPTIONS,
+    DEFAULT_OPTIONS, ABSOLUTE_PATH_DATASET_FILE_MULTIPLE_DATASET,
 )
 
 from tests.conftest import DATA_MODEL_FILE
@@ -1072,13 +1072,13 @@ def test_list_datasets(db):
     runner.invoke(
         add_dataset,
         [
-            DATASET_FILE,
+            ABSOLUTE_PATH_DATASET_FILE_MULTIPLE_DATASET,
             "--data-model",
             "data_model",
             "-v",
             "1.0",
             "--copy_from_file",
-            False,
+            True,
         ]
         + DEFAULT_OPTIONS,
     )
@@ -1089,12 +1089,20 @@
     assert result.stdout == "There are no datasets.\n"
     assert result_with_dataset.exit_code == ExitCode.OK
     assert (
-        "dataset_id data_model_id code label status count"
-        in result_with_dataset.stdout
+        "dataset_id data_model_id code label status count".strip(" ")
+        in result_with_dataset.stdout.strip(" ")
+    )
+    assert (
+        "dataset2 Dataset 2 ENABLED 2".strip(" ")
+        in result_with_dataset.stdout.strip(" ")
+    )
+    assert (
+        "dataset1 Dataset 1 ENABLED 2".strip(" ")
+        in result_with_dataset.stdout.strip(" ")
     )
     assert (
-        "0 1 1 dataset Dataset ENABLED 5"
-        in result_with_dataset.stdout
+        "dataset Dataset ENABLED 1".strip(" ")
+        in result_with_dataset.stdout.strip(" ")
     )
 
 
diff --git a/tests/test_usecases.py b/tests/test_usecases.py
index 2d90828..ff2172c 100644
--- a/tests/test_usecases.py
+++ b/tests/test_usecases.py
@@ -360,14 +360,14 @@ def test_add_dataset_with_db_with_multiple_datasets(db, data_model_metadata):
 
     # Test
     ImportCSV(db).execute(
-        csv_path="tests/data/success/data_model_v_1_0/dataset10.csv",
+        csv_path="tests/data/success/data_model_v_1_0/dataset123.csv",
         copy_from_file=False,
         data_model_code="data_model",
         data_model_version="1.0",
     )
     datasets = db.get_values(columns=["data_model_id", "code"])
-    assert len(datasets) == 2
-    assert all(code in ["dataset2", "dataset10"] for dmi, code in datasets)
+    assert len(datasets) == 3
+    assert all(code in ["dataset", "dataset1", "dataset2"] for dmi, code in datasets)
 
 
 @pytest.mark.database