Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for database schemas into Data connection, Importer (backport to 0.7) #174

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 201 additions & 70 deletions spine_items/data_connection/data_connection.py

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions spine_items/data_connection/executable_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from spine_engine.utils.serialization import deserialize_path
from .item_info import ItemInfo
from .output_resources import scan_for_resources
from .utils import restore_database_references


class ExecutableItem(ExecutableItemBase):
"""The executable parts of Data Connection."""

def __init__(self, name, file_references, db_references, db_credentials, project_dir, logger):
def __init__(self, name, file_references, db_references, project_dir, logger):
"""
Args:
name (str): item's name
Expand All @@ -40,7 +41,6 @@ def __init__(self, name, file_references, db_references, db_credentials, project
data_files.append(entry.path)
self._file_paths = file_references + data_files
self._urls = db_references
self._url_credentials = db_credentials

@staticmethod
def item_type():
Expand All @@ -49,13 +49,14 @@ def item_type():

def _output_resources_forward(self):
"""See base class."""
return scan_for_resources(self, self._file_paths, self._urls, self._url_credentials, self._project_dir)
return scan_for_resources(self, self._file_paths, self._urls, self._project_dir)

@classmethod
def from_dict(cls, item_dict, name, project_dir, app_settings, specifications, logger):
"""See base class."""
file_references = item_dict["file_references"]
file_references = [deserialize_path(r, project_dir) for r in file_references]
db_references = item_dict.get("db_references", [])
db_credentials = item_dict.get("db_credentials", {})
return cls(name, file_references, db_references, db_credentials, project_dir, logger)
db_references = restore_database_references(
item_dict.get("db_references", []), item_dict.get("db_credentials", {}), project_dir
)
return cls(name, file_references, db_references, project_dir, logger)
16 changes: 9 additions & 7 deletions spine_items/data_connection/output_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
from pathlib import Path
from spine_engine.project_item.project_item_resource import file_resource, transient_file_resource, url_resource
from spine_engine.utils.serialization import path_in_dir
from ..utils import unsplit_url_credentials
from spinedb_api.helpers import remove_credentials_from_url
from ..utils import convert_to_sqlalchemy_url, unsplit_url_credentials


def scan_for_resources(provider, file_paths, urls, url_credentials, project_dir):
def scan_for_resources(provider, file_paths, urls, project_dir):
"""
Creates file and URL resources based on DC's references and data.

Args:
provider (ProjectItem or ExecutableItem): resource provider item
file_paths (list of str): file paths
urls (list of str): urls
url_credentials (dict): mapping url from urls to tuple (username, password)
urls (list of dict): urls
project_dir (str): absolute path to project directory

Returns:
Expand Down Expand Up @@ -55,8 +55,10 @@ def scan_for_resources(provider, file_paths, urls, url_credentials, project_dir)
continue
resources.append(resource)
for url in urls:
credentials = url_credentials.get(url)
full_url = unsplit_url_credentials(url, credentials) if credentials is not None else url
resource = url_resource(provider.name, full_url, f"<{provider.name}>" + url)
str_url = str(convert_to_sqlalchemy_url(url))
schema = url.get("schema")
resource = url_resource(
provider.name, str_url, f"<{provider.name}>" + remove_credentials_from_url(str_url), schema=schema
)
resources.append(resource)
return resources
63 changes: 63 additions & 0 deletions spine_items/data_connection/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# This file is part of Spine Items.
# Spine Items is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""This module contains utilities for Data Connection."""
import sys
import urllib.parse
from spine_engine.utils.serialization import deserialize_path
from spine_items.utils import convert_url_to_safe_string


def restore_database_references(references_list, credentials_dict, project_dir):
"""Restores data from serialized database references.

Args:
references_list (list of dict): serialized database references
credentials_dict (dict): mapping from safe URL to (username, password) tuple
project_dir (str): path to project directory

Returns:
list of dict: deserialized database references
"""
db_references = []
for reference_dict in references_list:
if isinstance(reference_dict, str):
# legacy db reference
url = urllib.parse.urlparse(reference_dict)
dialect = _dialect_from_scheme(url.scheme)
database = url.path[1:]
db_reference = {
"dialect": dialect,
"host": url.hostname,
"port": url.port,
"database": database,
}
else:
db_reference = dict(reference_dict)
if db_reference["dialect"] == "sqlite":
db_reference["database"] = deserialize_path(db_reference["database"], project_dir)

db_reference["username"], db_reference["password"] = credentials_dict.get(
convert_url_to_safe_string(db_reference), (None, None)
)
db_references.append(db_reference)
return db_references


def _dialect_from_scheme(scheme):
"""Parses dialect from URL scheme.

Args:
scheme (str): URL scheme

Returns:
str: dialect name
"""
return scheme.split("+")[0]
4 changes: 2 additions & 2 deletions spine_items/data_store/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def set_up(self):

def parse_url(self, url):
"""Return a complete url dictionary from the given dict or string"""
base_url = dict(dialect="", username="", password="", host="", port="", database="")
base_url = dict(dialect="", username="", password="", host="", port="", database="", schema="")
if isinstance(url, dict):
if url.get("dialect") == "sqlite" and "database" in url and url["database"] is not None:
# Convert relative database path back to absolute
Expand Down Expand Up @@ -202,7 +202,7 @@ def do_update_url(self, **kwargs):
old_url = convert_to_sqlalchemy_url(self._url, self.name)
new_dialect = kwargs.get("dialect")
if new_dialect == "sqlite":
kwargs.update({"username": "", "password": "", "host": "", "port": ""})
kwargs.update({"username": "", "password": "", "host": "", "port": "", "schema": ""})
self._url.update(kwargs)
new_url = convert_to_sqlalchemy_url(self._url, self.name)
self.load_url_into_selections(self._url)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def __init__(self, toolbox):
self._active_item = None
self.ui = Ui_Form()
self.ui.setupUi(self)
self.ui.url_selector_widget.setup(list(SUPPORTED_DIALECTS.keys()), self._select_sqlite_file, self._toolbox)
self.ui.url_selector_widget.setup(
list(SUPPORTED_DIALECTS.keys()), self._select_sqlite_file, True, self._toolbox
)

def set_item(self, data_store):
"""Sets the active project item for the properties widget.
Expand Down
2 changes: 1 addition & 1 deletion spine_items/exporter/widgets/export_list_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _emit_out_label_changed(self):
@Slot(bool)
def _show_url_dialog(self, _=False):
"""Opens the URL selector dialog."""
dialog = UrlSelectorDialog(self._app_settings, self._logger, self)
dialog = UrlSelectorDialog(self._app_settings, True, self._logger, self)
if self._out_url is not None:
dialog.set_url_dict(self._out_url)
dialog.exec_()
Expand Down
76 changes: 28 additions & 48 deletions spine_items/importer/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Contains ConnectionManager class.

"""
""" Contains ConnectionManager class. """

from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot
from PySide6.QtWidgets import QFileDialog
Expand Down Expand Up @@ -60,11 +56,10 @@ def __init__(self, connection, connection_settings, parent):
super().__init__(parent)
self._thread = None
self._worker = None
self._source = None
self._current_table = None
self._table_options = {}
self._table_types = {}
self._defaul_table_column_type = {}
self._default_table_column_type = {}
self._table_row_types = {}
self._connection = connection
self._connection_settings = connection_settings
Expand Down Expand Up @@ -92,20 +87,12 @@ def table_types(self):

@property
def table_default_column_type(self):
return self._defaul_table_column_type
return self._default_table_column_type

@property
def table_row_types(self):
return self._table_row_types

@property
def source(self):
return self._source

@source.setter
def source(self, source):
self._source = source

@property
def source_type(self):
return self._connection.__name__
Expand Down Expand Up @@ -143,28 +130,19 @@ def request_default_mapping(self):
if self.is_connected:
self.default_mapping_requested.emit()

def connection_ui(self):
"""
Launches a modal ui that prompts the user to select source.

ex: fileselect if source is a file.
"""
ext = self._connection.FILE_EXTENSIONS
source, action = QFileDialog.getOpenFileName(None, "", ext)
if not source or not action:
return False
self._source = source
return True

def init_connection(self):
def init_connection(self, source, **source_extras):
"""Creates a Worker and a new thread to read source data.
If there is an existing thread close that one.

Args:
source (str): source file name or URL
**source_extras: source specific additional connection settings
"""
# close existing thread
self.close_connection()
# create new thread and worker
self._thread = QThread()
self._worker = ConnectionWorker(self._source, self._connection, self._connection_settings)
self._worker = ConnectionWorker(self._connection, self._connection_settings)
self._worker.moveToThread(self._thread)
# connect worker signals
self._worker.connectionReady.connect(self._handle_connection_ready)
Expand All @@ -181,7 +159,7 @@ def init_connection(self):
self.connection_closed.connect(self._worker.disconnect, type=Qt.ConnectionType.BlockingQueuedConnection)

# when thread is started, connect worker to source
self._thread.started.connect(self._worker.init_connection)
self._thread.started.connect(lambda: self._worker.init_connection(source, dict(source_extras)))
self._thread.start()

@Slot()
Expand Down Expand Up @@ -260,15 +238,15 @@ def update_table_default_column_type(self, column_type):
Args:
column_type (dict): mapping from table name to column type name
"""
self._defaul_table_column_type.update(column_type)
self._default_table_column_type.update(column_type)

def clear_table_default_column_type(self, table_name):
"""Clears default column type.

Args:
table_name (str): table name
"""
self._defaul_table_column_type.pop(table_name, None)
self._default_table_column_type.pop(table_name, None)

def set_table_row_types(self, types):
"""Sets connection manager types for current connector
Expand All @@ -293,12 +271,7 @@ def close_connection(self):


class ConnectionWorker(QObject):
"""A class for delegating SourceConnection operations to another QThread.

Args:
source (str): path of the source file
connection (class): A class derived from `SourceConnection` for connecting to the source file
"""
"""A class for delegating SourceConnection operations to another QThread."""

connectionFailed = Signal(str)
"""Signal with error message if connection fails"""
Expand All @@ -313,19 +286,26 @@ class ConnectionWorker(QObject):
defaultMappingReady = Signal(dict)
"""Signal when default mapping is ready"""

def __init__(self, source, connection, connection_settings, parent=None):
def __init__(self, connection, connection_settings, parent=None):
"""
Args:
connection (class): A class derived from `SourceConnection` for connecting to the source file
connection_settings (dict): settings passed to the connection constructor
parent (QObject): parent object
"""
super().__init__(parent)
self._source = source
self._connection = connection(connection_settings)

@Slot()
def init_connection(self):
"""
Connect to data source
def init_connection(self, source, source_extras):
"""Connect to data source.

Args:
source (str): source file path or URL
source_extras (dict): source specific additional connection settings
"""
if self._source:
if source:
try:
self._connection.connect_to_source(self._source)
self._connection.connect_to_source(source, **source_extras)
self.connectionReady.emit()
except Exception as error:
self.connectionFailed.emit(f"Could not connect to source: {error}")
Expand Down
Loading
Loading