Skip to content

Commit

Permalink
Revert "merge "
Browse files Browse the repository at this point in the history
  • Loading branch information
SpaceCondor authored Sep 27, 2024
1 parent 5765f12 commit c0b4a78
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 160 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ jobs:
tag: ${{ github.ref }}
overwrite: true
file_glob: true
- uses: pypa/[email protected].2
- uses: pypa/[email protected].1
with:
attestations: true
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.7
rev: v0.6.5
hooks:
- id: ruff
args: [--fix]
Expand Down
158 changes: 79 additions & 79 deletions poetry.lock

Large diffs are not rendered by default.

27 changes: 11 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ packages = [

[tool.poetry.dependencies]
python = ">=3.8"
faker = {version = "~=29.0", optional = true}
faker = {version = "~=28.1", optional = true}
psycopg2-binary = "2.9.9"
sqlalchemy = "~=2.0"
sshtunnel = "0.4.0"
Expand Down Expand Up @@ -85,21 +85,16 @@ target-version = "py38"

[tool.ruff.lint]
select = [
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"FA", # flake8-future-annotations
"I", # isort
"N", # pep8-naming
"D", # pydocsyle
"UP", # pyupgrade
"ICN", # flake8-import-conventions
"RET", # flake8-return
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"PL", # Pylint
"PERF", # Perflint
"RUF", # ruff
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"FA", # flake8-future-annotations
"I", # isort
"N", # pep8-naming
"D", # pydocsyle
"ICN", # flake8-import-conventions
"UP", # pyupgrade
"RUF", # ruff
]

[tool.ruff.lint.flake8-import-conventions]
Expand Down
95 changes: 56 additions & 39 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import atexit
import io
import itertools
import signal
import sys
import typing as t
from contextlib import contextmanager
from functools import cached_property
Expand Down Expand Up @@ -97,7 +95,7 @@ def interpret_content_encoding(self) -> bool:
"""
return self.config.get("interpret_content_encoding", False)

def prepare_table( # type: ignore[override] # noqa: PLR0913
def prepare_table( # type: ignore[override]
self,
full_table_name: str | FullyQualifiedName,
schema: dict,
Expand All @@ -123,7 +121,7 @@ def prepare_table( # type: ignore[override] # noqa: PLR0913
meta = sa.MetaData(schema=schema_name)
table: sa.Table
if not self.table_exists(full_table_name=full_table_name):
return self.create_empty_table(
table = self.create_empty_table(
table_name=table_name,
meta=meta,
schema=schema,
Expand All @@ -132,6 +130,7 @@ def prepare_table( # type: ignore[override] # noqa: PLR0913
as_temp_table=as_temp_table,
connection=connection,
)
return table
meta.reflect(connection, only=[table_name])
table = meta.tables[
full_table_name
Expand Down Expand Up @@ -179,17 +178,19 @@ def copy_table_structure(
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sa.MetaData(schema=schema_name)
new_table: sa.Table
columns = []
if self.table_exists(full_table_name=full_table_name):
raise RuntimeError("Table already exists")

columns = [column._copy() for column in from_table.columns]
for column in from_table.columns:
columns.append(column._copy())
if as_temp_table:
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
new_table.create(bind=connection)
return new_table
new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table
else:
new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table

@contextmanager
def _connect(self) -> t.Iterator[sa.engine.Connection]:
Expand All @@ -204,7 +205,14 @@ def clone_table(
self, new_table_name, table, metadata, connection, temp_table
) -> sa.Table:
"""Clone a table."""
new_columns = [sa.Column(column.name, column.type) for column in table.columns]
new_columns = []
for column in table.columns:
new_columns.append(
sa.Column(
column.name,
column.type,
)
)
if temp_table is True:
new_table = sa.Table(
new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
Expand Down Expand Up @@ -270,7 +278,7 @@ def to_sql_type(self, jsonschema_type: dict) -> sa.types.TypeEngine: # type: ig

return PostgresConnector.pick_best_sql_type(sql_type_array=sql_type_array)

def pick_individual_type(self, jsonschema_type: dict): # noqa: PLR0911
def pick_individual_type(self, jsonschema_type: dict):
"""Select the correct sql type assuming jsonschema_type has only a single type.
Args:
Expand Down Expand Up @@ -308,7 +316,11 @@ def pick_individual_type(self, jsonschema_type: dict): # noqa: PLR0911
return ARRAY(self.to_sql_type({"type": items_type}))

# Case 3: tuples
return ARRAY(JSONB()) if isinstance(items, list) else JSONB()
if isinstance(items, list):
return ARRAY(JSONB())

# All other cases, return JSONB
return JSONB()

# string formats
if jsonschema_type.get("format") == "date-time":
Expand All @@ -321,7 +333,9 @@ def pick_individual_type(self, jsonschema_type: dict): # noqa: PLR0911
):
return HexByteString()
individual_type = th.to_sql_type(jsonschema_type)
return TEXT() if isinstance(individual_type, VARCHAR) else individual_type
if isinstance(individual_type, VARCHAR):
return TEXT()
return individual_type

@staticmethod
def pick_best_sql_type(sql_type_array: list):
Expand Down Expand Up @@ -350,12 +364,13 @@ def pick_best_sql_type(sql_type_array: list):
NOTYPE,
]

for sql_type, obj in itertools.product(precedence_order, sql_type_array):
if isinstance(obj, sql_type):
return obj
for sql_type in precedence_order:
for obj in sql_type_array:
if isinstance(obj, sql_type):
return obj
return TEXT()

def create_empty_table( # type: ignore[override] # noqa: PLR0913
def create_empty_table( # type: ignore[override]
self,
table_name: str,
meta: sa.MetaData,
Expand All @@ -369,7 +384,7 @@ def create_empty_table( # type: ignore[override] # noqa: PLR0913
Args:
table_name: the target table name.
meta: the SQLAlchemy metadata object.
meta: the SQLAchemy metadata object.
schema: the JSON schema for the new table.
connection: the database connection.
primary_keys: list of key properties.
Expand All @@ -391,7 +406,7 @@ def create_empty_table( # type: ignore[override] # noqa: PLR0913
raise RuntimeError(
f"Schema for table_name: '{table_name}'"
f"does not define properties: {schema}"
) from None
)

for property_name, property_jsonschema in properties.items():
is_primary_key = property_name in primary_keys
Expand Down Expand Up @@ -525,7 +540,7 @@ def get_column_add_ddl( # type: ignore[override]
},
)

def _adapt_column_type( # type: ignore[override] # noqa: PLR0913
def _adapt_column_type( # type: ignore[override]
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -568,7 +583,7 @@ def _adapt_column_type( # type: ignore[override] # noqa: PLR0913
return

# Not the same type, generic type or compatible types
# calling merge_sql_types for assistance
# calling merge_sql_types for assistnace
compatible_sql_type = self.merge_sql_types([current_type, sql_type])

if str(compatible_sql_type) == str(current_type):
Expand Down Expand Up @@ -638,16 +653,17 @@ def get_sqlalchemy_url(self, config: dict) -> str:
if config.get("sqlalchemy_url"):
return cast(str, config["sqlalchemy_url"])

sqlalchemy_url = URL.create(
drivername=config["dialect+driver"],
username=config["user"],
password=config["password"],
host=config["host"],
port=config["port"],
database=config["database"],
query=self.get_sqlalchemy_query(config),
)
return cast(str, sqlalchemy_url)
else:
sqlalchemy_url = URL.create(
drivername=config["dialect+driver"],
username=config["user"],
password=config["password"],
host=config["host"],
port=config["port"],
database=config["database"],
query=self.get_sqlalchemy_query(config),
)
return cast(str, sqlalchemy_url)

def get_sqlalchemy_query(self, config: dict) -> dict:
"""Get query values to be used for sqlalchemy URL creation.
Expand All @@ -663,7 +679,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
# ssl_enable is for verifying the server's identity to the client.
if config["ssl_enable"]:
ssl_mode = config["ssl_mode"]
query["sslmode"] = ssl_mode
query.update({"sslmode": ssl_mode})
query["sslrootcert"] = self.filepath_or_certificate(
value=config["ssl_certificate_authority"],
alternative_name=config["ssl_storage_directory"] + "/root.crt",
Expand Down Expand Up @@ -709,11 +725,12 @@ def filepath_or_certificate(
"""
if path.isfile(value):
return value
with open(alternative_name, "wb") as alternative_file:
alternative_file.write(value.encode("utf-8"))
if restrict_permissions:
chmod(alternative_name, 0o600)
return alternative_name
else:
with open(alternative_name, "wb") as alternative_file:
alternative_file.write(value.encode("utf-8"))
if restrict_permissions:
chmod(alternative_name, 0o600)
return alternative_name

def guess_key_type(self, key_data: str) -> paramiko.PKey:
"""Guess the type of the private key.
Expand All @@ -738,7 +755,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
):
try:
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined]
except paramiko.SSHException: # noqa: PERF203
except paramiko.SSHException:
continue
else:
return key
Expand All @@ -758,7 +775,7 @@ def catch_signal(self, signum, frame) -> None:
signum: The signal number
frame: The current stack frame
"""
sys.exit(1) # Calling this to be sure atexit is called, so clean_up gets called
exit(1) # Calling this to be sure atexit is called, so clean_up gets called

def _get_column_type( # type: ignore[override]
self,
Expand Down
Loading

0 comments on commit c0b4a78

Please sign in to comment.